gpt4 book ai didi

numpy - 计算 Hessians w.r.t 高阶变量既不能通过 tf.hessians() 也不能通过 tf.gradients() 工作

转载 作者:行者123 更新时间:2023-11-30 08:46:16 27 4
gpt4 key购买 nike

当我们需要计算双梯度或Hessian时,在tensorflow中,我们可以使用tf.hessians(F(x),x) ,或使用tf.gradient(tf.gradients(F(x),x)[0], x)[0] 。然而,当 x不是排名第一,我在使用 tf.hessians() 时被告知以下错误.

ValueError: Cannot compute Hessian because element 0 of xs does not have rank one.. Tensor model_inputs/action:0 must have rank 1. Received rank 2, shape (?, 1)

在下面的代码中:

with tf.name_scope("1st scope"):
self.states = tf.placeholder(tf.float32, (None, self.state_dim), name="states")
self.action = tf.placeholder(tf.float32, (None, self.action_dim), name="action")

with tf.name_scope("2nd scope"):
with tf.variable_scope("3rd scope"):
self.policy_outputs = self.policy_network(self.states)
# use tf.gradients twice
self.actor_action_gradients = tf.gradients(self.policy_outputs, self.action)[0]
self.actor_action_hessian = tf.gradients(self.actor_action_gradients, self.action)[0]
# or use tf.hessians
self.actor_action_hessian = tf.hessian(self.policy_outputs, self.action)

当使用tf.gradients()时,也会导致错误:

in create_variables self.actor_action_hessian = tf.gradients(self.actor_action_gradients, self.action)[0]

AttributeError: 'NoneType' object has no attribute 'dtype'

我该如何解决这个问题,tf.gradients()也不tf.hessians()在这种情况下可以使用吗?

最佳答案

第二种方法很好,错误在其他地方,即你的图没有连接。

self.actor_action_gradients = tf.gradients(self.policy_outputs, self.action)[0]
self.actor_action_hessian = tf.gradients(self.actor_action_gradients, self.action)[0]

第二行抛出错误,因为 self.actor_action_gradients 为 None,因此您无法计算其梯度。您的代码中没有任何内容表明 self.policy_outputs 依赖于 self.action (而且它不应该依赖于 self.action,因为它的操作依赖于策略,而不是基于操作的策略)。

一旦你解决了这个问题,你就会注意到,“hessian”并不是真正的 hessian,而是一个向量,以形成 f wrt 的正确 hessian。 x 您必须迭代 tf.gradients 返回的所有值,并独立计算每个值的 tf.gradients。这是 TF 中的一个已知限制,目前没有更简单的方法可用。

关于numpy - 计算 Hessians w.r.t 高阶变量既不能通过 tf.hessians() 也不能通过 tf.gradients() 工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45865829/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com