gpt4 book ai didi

tensorflow - tf.gradients 会通过 tf.cond 吗?

转载 作者:行者123 更新时间:2023-12-04 13:42:14 24 4
gpt4 key购买 nike

我想创建一对循环神经网络,比如 NN1 和 NN2,其中 NN2 从前一个时间步再现其输出,并且只要 NN1 输出与前一个时间步不同的值,就不会在当前时间步更新其权重。

为此,我计划使用 tf.cond()连同tf.stop_gradients() .但是,在我运行的所有玩具示例中,我无法获得 tf.gradients()穿越 tf.cond() :tf.gradients()只需返回 [None] .

这是一个简单的玩具示例:

import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)

mult = tf.multiply(x, y)
cond = tf.cond(pred = tf.constant(True),
true_fn = lambda: mult,
false_fn = lambda: mult)

grad = tf.gradients(cond, x) # Returns [None]

这是我定义的另一个简单的玩具示例 true_fnfalse_fntf.cond() (仍然没有骰子):

import tensorflow as tf

x = tf.constant(5)
y = tf.constant(3)
z = tf.constant(8)

cond = tf.cond(pred = x < y,
true_fn = lambda: tf.add(x, z),
false_fn = lambda: tf.square(y))

tf.gradients(cond, z) # Returns [None]

我原本认为梯度应该流经 true_fnfalse_fn ,但显然根本没有梯度流动。这是通过 tf.cond() 计算的梯度的预期行为吗? ?可能有办法解决这个问题吗?

最佳答案

是的,梯度将通过 tf.cond() .您只需要使用浮点数而不是整数,并且(最好)使用变量而不是常量:


import tensorflow as tf

x = tf.Variable(5.0, dtype=tf.float32)
y = tf.Variable(6.0, dtype=tf.float32)
z = tf.Variable(8.0, dtype=tf.float32)

cond = tf.cond(pred = x < y,
true_fn = lambda: tf.add(x, z),
false_fn = lambda: tf.square(y))

op = tf.gradients(cond, z)
# Returns [<tf.Tensor 'gradients_1/cond_1/Add/Switch_1_grad/cond_grad:0' shape=() dtype=float32>]

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(op)) # [1.0]

关于tensorflow - tf.gradients 会通过 tf.cond 吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55381342/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com