gpt4 book ai didi

python - tensorflow.get_collection()中的collection是否清空?

转载 作者:太空宇宙 更新时间:2023-11-04 05:23:08 25 4
gpt4 key购买 nike

我正在通过斯坦福类(class)学习使用 Tensorflow 的神经网络。我在实现 RNN 时发现了这一点,但不太明白为什么会累积损失:

# This adds a loss operation to the Graph for batch training
def add_loss_op(self, output):
all_ones = [tf.ones([self.config.batch_size * self.config.num_steps])]
cross_entropy = sequence_loss(
[output], [tf.reshape(self.labels_placeholder, [-1])], all_ones, len(self.vocab))
tf.add_to_collection('total_loss', cross_entropy)
# Doesn't this increase in size every batch of training?
loss = tf.add_n(tf.get_collection('total_loss'))
return loss

get_collection() 的文档 here没有提及有关清除变量的任何内容。由于这是针对每个训练步骤运行的,因此损失是否会在每个时期/小批量训练中增加并结转?

我对神经网络还是个新手,所以请纠正我对此的任何误解!

最佳答案

我相信这里的 add_n 实际上只是为了确保将“total_loss”集合中任何预先存在的损失添加到最终结果中。它不会更改任何变量,只是将其输入相加并返回总数。

关于python - tensorflow.get_collection()中的collection是否清空?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39657063/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com