gpt4 book ai didi

tensorflow - tf.Print() 导致梯度错误

转载 作者:行者123 更新时间:2023-12-04 02:00:16 27 4
gpt4 key购买 nike

我试图使用 tf.Print调试语句以更好地理解从 compute_gradients() 报告的梯度和变量的格式,但遇到了意外问题。训练例程和调试例程(gvdebug)如下:

def gvdebug(g, v):
#g = tf.Print(g,[g],'G: ')
#v = tf.Print(v,[v],'V: ')
g2 = tf.zeros_like(g, dtype=tf.float32)
v2 = tf.zeros_like(v, dtype=tf.float32)
g2 = g
v2 = v
return g2,v2

# Define training operation
def training(loss, global_step, learning_rate=0.1):
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
gv2 = [gvdebug(gv[0], gv[1]) for gv in grads_and_vars]
train_op = optimizer.apply_gradients(gv2, global_step=global_step)
return train_op

此代码工作正常(但不打印),但如果我取消注释 gvdebug() 中的两个 tf.Print 行,我会从 apply_gradients 收到一条错误消息:'TypeError: Variable must be a tf.Variable'。我以为 tf.Print 刚刚通过了张量——我做错了什么?

最佳答案

TL;博士

不要试图tf.Print gv[1]因为它是 tf.Variable .它就像一个指向创建 gradient 的变量的指针。在 gv[0] .

更多信息

当您运行时 compute_gradients它返回一个列表 gradients 以及它们对应的 tf.Variable s .
grads_and_vars的每个元素是 Tensor和一个 tf.Variable .需要注意的是不是 变量的值。

删除 v = tf.Print(v,[v],'V: ') 后,您的代码对我有用

关于tensorflow - tf.Print() 导致梯度错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34756903/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com