gpt4 book ai didi

python - Tensorflow:针对每个样本计算 Hessian

转载 作者:太空宇宙 更新时间:2023-11-04 09:44:34 26 4
gpt4 key购买 nike

我有一个张量 X,大小为 M x D。我们可以将 X 的每一行解释为训练样本,将每一列解释为一个特征。

X 用于计算大小为 M x 1 的张量 u(换句话说,u 依赖于计算图中的 X)。我们可以将其解释为预测向量;每个 sample 一个。特别是,u 的第 m 行仅使用 X 的第 m 行计算。

现在,如果我运行 tensor.gradients(u, X)[0],我将获得一个 M x D 张量对应于u 相对于 X 的“每个样本”梯度。

我怎样才能类似地计算“每个样本”的 Hessian 张量? (即,M x D x D 数量)


附录:彼得在下面的回答是正确的。我还发现了一种使用堆叠和拆堆叠的不同方法(使用 Peter 的表示法):

hess2 = tf.stack([
tf.gradients( tmp, a )[ 0 ]
for tmp in tf.unstack( grad, num=5, axis=1 )
], axis = 2)

在 Peter 的示例中,D=5 是特征数。 我怀疑(但我没有检查过) 上面的方法对于 M 比较大,因为它跳过了 Peter 的回答中提到的零条目.

最佳答案

tf.hessians()正在为提供的 ysxs 计算 Hessian,而不考虑维度。因为你有维度 M x D 和维度 M x 的 xs 结果D 因此结果的维度将是 M x D x M x D。但是由于每个样本的输出彼此独立,因此大部分 Hessian 矩阵将为零,即三维中只有一个切片具有任何值。因此,为了获得您想要的结果,您应该取两个 M 维度的对角线,或者更简单,您应该简单地求和并消除第三个维度,如下所示:

hess2 = tf.reduce_sum( hess, axis = 2 )

示例代码(已测试):

import tensorflow as tf

a = tf.constant( [ [ 1.0, 1, 1, 1, 1 ], [ 2, 2, 2, 2, 2 ], [ 3, 3, 3, 3, 3 ] ] )
b = tf.constant( [ [ 1.0 ], [ 2 ], [ 3 ], [ 4 ], [ 5 ] ] )
c = tf.matmul( a, b )
c_sq = tf.square( c )

grad = tf.gradients( c_sq, a )[ 0 ]

hess = tf.hessians( c_sq, a )[ 0 ]
hess2 = tf.reduce_sum( hess, axis = 2 )


with tf.Session() as sess:
res = sess.run( [ c_sq, grad, hess2 ] )

for v in res:
print( v.shape )
print( v )
print( "=======================")

将输出:

(3, 1)
[[ 225.]
[ 900.]
[2025.]]
=======================
(3, 5)
[[ 30. 60. 90. 120. 150.]
[ 60. 120. 180. 240. 300.]
[ 90. 180. 270. 360. 450.]]
=======================
(3, 5, 5)
[[[ 2. 4. 6. 8. 10.]
[ 4. 8. 12. 16. 20.]
[ 6. 12. 18. 24. 30.]
[ 8. 16. 24. 32. 40.]
[10. 20. 30. 40. 50.]]

[[ 2. 4. 6. 8. 10.]
[ 4. 8. 12. 16. 20.]
[ 6. 12. 18. 24. 30.]
[ 8. 16. 24. 32. 40.]
[10. 20. 30. 40. 50.]]

[[ 2. 4. 6. 8. 10.]
[ 4. 8. 12. 16. 20.]
[ 6. 12. 18. 24. 30.]
[ 8. 16. 24. 32. 40.]
[10. 20. 30. 40. 50.]]]
=======================

关于python - Tensorflow:针对每个样本计算 Hessian,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50310532/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com