gpt4 book ai didi

python - 相对于重复函数的梯度

转载 作者:行者123 更新时间:2023-12-01 02:10:21 24 4
gpt4 key购买 nike

我在计算重复调用的转换函数的梯度时遇到问题。

相对于 Action 计算的梯度是“无”,即使损失取决于由重复转换调用生成的最大值之和所选择的所选 Action 。如果我们将损失函数的值更改为 v 而不是 a 的总和,那么我们会收到过渡的梯度。

为什么当我们的损失是根据 a 的总和来计算时,没有计算过渡的梯度?

下面是一段示例代码,您可以在其中复制该问题。

import tensorflow as tf
import numpy as np

ACTION_DIM = 1

# random input
x = tf.Variable(np.random.rand(1, 5)) # [b branches, state_dim]

depth = 3
b = 4
v_list, a_list = [], [] # value and action store
# make value estimates 3 steps into the future by predicting intermediate states
for i in range(depth):
reuse = True if i > 0 else False
x = tf.tile(x, [b, 1]) # copy the state to be used for b different actions
mu = tf.layers.dense(x, ACTION_DIM, name='mu', reuse=reuse)
action_distribution = tf.distributions.Normal(loc=mu, scale=tf.ones_like(mu))
a = tf.reshape(action_distribution.sample(1), [-1, ACTION_DIM])
x_a = tf.concat([x, a], axis=1) # concatenate action and state
x = tf.layers.dense(x_a, x.shape[-1], name='transition', reuse=reuse) # next state s'
v = tf.layers.dense(x, 1, name='value', reuse=reuse) # value of s'
v_list.append(tf.reshape(v, [-1, b ** i]))
a_list.append(tf.reshape(a, [-1, b ** i]))

# backup our sum of max values along trajectory
sum_v = [None]*depth
sum_v[-1] = v_list[-1]
for i in reversed(range(depth)):
max_v_i = tf.reduce_max(v_list[i], axis=1)
if i > 0:
sum_v[i-1] = tf.reduce_max(v_list[i-1], axis=1) + max_v_i

max_idx = tf.reshape(tf.argmax(sum_v[0]), [-1, 1])
v = tf.gather_nd(v_list[0], max_idx)
a = tf.gather_nd(a_list[0], max_idx)
loss = -tf.reduce_sum(a)
opt = tf.train.AdamOptimizer()
grads = opt.compute_gradients(loss)

最佳答案

我认为问题源于您在定义 col_idx 时的 arg_max 调用。 Arg_max 是一个位置参数,因此没有渐变。这是有道理的,因为列表中最大值的位置不会随着最大值的变化而变化。

我也不相信对 tf.contrib.distributions.Normal 的调用将具有相对于其输入变量的导数,但这只是因为它位于 contrib 中。如果修复arg_max后问题仍然存在,也许您可​​以尝试使用默认的tensorflow。

关于python - 相对于重复函数的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48740685/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com