gpt4 book ai didi

neural-network - TensorFlow 训练

转载 作者:行者123 更新时间:2023-12-03 10:14:40 26 4
gpt4 key购买 nike

假设我有一个非常简单的神经网络,比如多层感知器。对于每一层,激活函数都是 sigmoid 并且网络是全连接的。

在 TensorFlow 中,这可能是这样定义的:

    sess = tf.InteractiveSession()

# Training Tensor
x = tf.placeholder(tf.float32, shape = [None, n_fft])
# Label Tensor
y_ = tf.placeholder(tf.float32, shape = [None, n_fft])

# Declaring variable buffer for weights W and bias b
# Layer structure [n_fft, n_fft, n_fft, n_fft]
# Input -> Layer 1
struct_w = [n_fft, n_fft]
struct_b = [n_fft]
W1 = weight_variable(struct_w, 'W1')
b1 = bias_variable(struct_b, 'b1')
h1 = tf.nn.sigmoid(tf.matmul(x, W1) + b1)

# Layer1 -> Layer 2
W2 = weight_variable(struct_w, 'W2')
b2 = bias_variable(struct_b, 'b2')
h2 = tf.nn.sigmoid(tf.matmul(h1, W2) + b2)

# Layer2 -> output
W3 = weight_variable(struct_w, 'W3')
b3 = bias_variable(struct_b, 'b3')
y = tf.nn.sigmoid(tf.matmul(h2, W3) + b3)

# Calculating difference between label and output using mean square error
mse = tf.reduce_mean(tf.square(y - y_))

# Train the Model
# Gradient Descent
train_step = tf.train.GradientDescentOptimizer(0.3).minimize(mse)

此模型的设计目标是映射一个 n_fft将 fft 频谱图指向另一个 n_fft目标频谱图。让我们假设训练数据和目标数据的大小都是 [3000, n_fft] .它们存储在变量 spec_train 中和 spec_target .

现在问题来了。对于 TensorFlow,这两种训练有什么区别吗?

培训一:

for i in xrange(200):
train_step.run(feed_dict = {x: spec_train, y_: spec_target})

培训2:

for i in xrange(200):
for j in xrange(3000):
train = spec_train[j, :].reshape(1, n_fft)
label = spec_target[j, :].reshape(1, n_fft)
train_step.run(feed_dict = {x: train, y_: label})

非常感谢!

最佳答案

在第一个训练版本中,您一次训练整批训练数据,这意味着 spec_train 的第一个和第 3000 个元素将在一个步骤中使用相同的模型参数进行处理。这被称为 (批量)梯度下降 .

在第二个训练版本中,您一次从训练数据中训练一个示例,这意味着 spec_train 的第 3000 个元素将使用自最近处理第一个元素以来已更新 2999 次的模型参数进行处理。这被称为 随机梯度下降 (或者如果元素是随机选择的)。

通常,TensorFlow 用于过大而无法一次性处理的数据集,因此小批量 SGD(在一个步骤中处理示例的子集)受到青睐。一次处理单个元素在理论上是可取的,但本质上是顺序的,并且具有很高的固定成本,因为矩阵乘法和其他操作在计算上并不那么密集。因此,一次处理小批量(例如 32 或 128 个)示例是常用的方法,在不同批次上并行训练多个副本。

看到这个 Stats StackExchange question有关何时应该使用一种方法与另一种方法的更多理论讨论。

关于neural-network - TensorFlow 训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34097457/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com