gpt4 book ai didi

Tensorflow cond 不会停止错误分支上的梯度

转载 作者:行者123 更新时间:2023-12-02 03:01:52 25 4
gpt4 key购买 nike

我正在构建一个 RNN 模型,其中 init_state 可能来自两种情况之一。 1) 通过 feed_dict 从之前的时间步长输出状态输入的静态 init_state。 2) 变量的一些函数,我称之为分数。

init_state = cell.zero_state(batch,tf.float32)
with tf.name_scope('hidden1'):
weights_h1 = tf.Variable(
tf.truncated_normal([T, cells_dim],
stddev=1.0 / np.sqrt(T)),
name='weights')
biases_h1 = tf.Variable(tf.zeros([cells_dim]),
name='biases')
hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

init_state2 = tf.cond(is_start, lambda: hidden1, lambda: init_state)

init_state2 然后用作 static_rnn 的输入,最终用于计算损失和 train_op。当 is_start 为 False 时,我希望 train_op 对 weights_h1 没有影响。但是,权重在每次更新后都会发生变化。非常感谢任何帮助。

最佳答案

这应该有效:

def return_init_state():
init_state = cell.zero_state(batch,tf.float32)
return init_state

def return_hidden_1():
with tf.name_scope('hidden1'):
weights_h1 = tf.Variable(
tf.truncated_normal([T, cells_dim],
stddev=1.0 / np.sqrt(T)),
name='weights')
biases_h1 = tf.Variable(tf.zeros([cells_dim]),
name='biases')
hidden1 = tf.nn.relu(tf.matmul(score, weights_h1) + biases_h1)

return hidden1

init_state2 = tf.cond(is_start, lambda: return_hidden_1, lambda: return_init_state)

注意如何在 tf.cond 的上下文中调用这些方法。因此,无论创建什么操作,都将在 tf.cond 的上下文中。否则,在您的情况下,操作将以任何一种方式运行。

关于Tensorflow cond 不会停止错误分支上的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45391008/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com