gpt4 book ai didi

python - tf.GradientTape() 的位置对模型训练时间的影响

转载 作者:太空宇宙 更新时间:2023-11-04 04:04:15 25 4
gpt4 key购买 nike

我试图在每个时期更新权重,但我正在分批处理数据。问题是,为了规范化损失,我需要在训练循环之外记录 TensorFlow 变量(以进行跟踪和规范化)。但是当我这样做时,训练时间很长。

我认为,它将所有批处理的变量累积到图中并在最后计算梯度。

我已经开始跟踪 for 循环外和 for 循环内的变量,后者比第一个更快。我对为什么会发生这种情况感到困惑,因为无论我做什么,我的模型的可训练变量和损失都保持不变。

# Very Slow

loss_value = 0
batches = 0

with tf.GradientTape() as tape:
for inputs, min_seq in zip(dataset, minutes_sequence):
temp_loss_value = my_loss_function(inputs, min_seq)
batches +=1
loss_value = loss_value + temp_loss_value

# The following line takes huge time.
grads = tape.gradient(loss_value, model.trainable_variables)

# Very Fast

loss_value = 0
batches = 0

for inputs, min_seq in zip(dataset, minutes_sequence):
with tf.GradientTape() as tape:
temp_loss_value = my_loss_function(inputs, min_seq)
batches +=1
loss_value = loss_value + temp_loss_value

# If I do the following line, the graph will break because this are out of tape's scope.
loss_value = loss_value / batches

# the following line takes huge time
grads = tape.gradient(loss_value, model.trainable_variables)

当我在 for 循环内声明 tf.GradientTape() 时,它非常快,但在外面它很慢。

附言- 这是针对自定义损失的,架构仅包含一个大小为 10 的隐藏层。

我想知道 tf.GradientTape() 的位置所造成的差异以及它应该如何用于批处理数据集中每个时期的权重更新。

最佳答案

磁带变量主要用于观察可训练的张量变量(记录变量的先前值和变化值),以便我们可以根据损失函数计算训练一个epoch的梯度。它是此处用于记录变量状态的 python 上下文管理器构造的实现。关于 python 上下文管理器的优秀资源是 here .因此,如果在循环内部,它将记录前向传递的变量(权重),以便我们可以一次性计算所有这些变量的梯度(而不是像没有像 tensorflow 这样的库的天真的实现中基于堆栈的梯度传递) .如果它在循环之外,它将记录所有时期的状态,并且根据 Tensorflow 源代码,如果使用 TF2.0,它也会刷新,这与模型开发人员必须负责刷新的 TF1.x 不同。在您的示例中,您没有设置任何作者,但如果设置了任何作者,它也会这样做。因此,对于上面的代码,它将继续记录(内部使用 Graph.add_to_collection 方法)所有权重,并且随着时代的增加,您应该会看到减速。减速率将与网络规模(可训练变量)和当前纪元数成正比。

所以放在循环里面是正确的。此外,梯度应应用于 for 循环内部而不是外部(与 with 处于相同的缩进级别),否则您仅在训练循环结束时(最后一个纪元之后)应用梯度。我发现您的训练对于梯度检索的当前位置可能不是那么好(尽管您在代码片段中省略了它,但之后它被应用到您的代码中)。

再多一个好resource在我刚找到的渐变带上。

关于python - tf.GradientTape() 的位置对模型训练时间的影响,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57658294/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com