gpt4 book ai didi

tensorflow - 如何停止 tensorflow 中张量某些条目的梯度

转载 作者:行者123 更新时间:2023-12-03 02:12:17 26 4
gpt4 key购买 nike

我正在尝试实现嵌入层。将使用预先训练的手套嵌入来初始化嵌入。对于可以在手套中找到的单词,它将被修复。对于那些没有出现在手套中的单词,它会被随机初始化,并且是可训练的。我如何在 tensorflow 中做到这一点?我知道整个张量有一个 tf.stop_gradient ,有没有适合这种场景的 stop_gradient api?或者,有什么解决方法吗?任何建议表示赞赏

最佳答案

所以想法是使用masktf.stop_gradient来解决这个问题:

res_matrix = tf.stop_gradient(mask_h*E) + mask*E,

在矩阵mask中,1表示我想要应用渐变的条目,0表示我不想应用渐变的条目(将渐变设置为0), mask_hmask 的倒数(1 翻转为 0,0 翻转为 1)。然后我们可以从 res_matrix 中获取。这是测试代码:

import tensorflow as tf
import numpy as np

def entry_stop_gradients(target, mask):
mask_h = tf.abs(mask-1)
return tf.stop_gradient(mask_h * target) + mask * target

mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)

emb = tf.constant(np.ones([10, 5]))

matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))

parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)

loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
print sess.run(loss)
print sess.run([grad1, grad2])

关于tensorflow - 如何停止 tensorflow 中张量某些条目的梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43364985/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com