gpt4 book ai didi

python - 如何修改 tf.nn.embedding_lookup() 的返回张量?

转载 作者:行者123 更新时间:2023-11-30 09:06:43 25 4
gpt4 key购买 nike

我想使用scatter_nd_update来更改从tf.nn.embedding_lookup()返回的张量的内容。但是,返回的张量不可变,并且 scatter_nd_update() 需要可变张量作为输入。我花了很多时间试图找到解决方案,包括使用 gen_state_ops._temporary_variable 和使用 tf.sparse_to_dense,不幸的是都失败了。

我想知道是否有一个完美的解决方案?

with tf.device('/cpu:0'), tf.name_scope("embedding"):
self.W = tf.Variable(
tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
name="W")
self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
updates = tf.constant(0,shape=[embedding_size])
for i in range(1,sequence_length - 2):
indices = [None,i]
tf.scatter_nd_update(self.embedded_chars,indices,updates)
self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

最佳答案

tf.nn.embedding_lookup只是返回较大矩阵的切片,因此最简单的解决方案是更新矩阵本身的值,在您的情况下它是self.W:

self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)

由于它是一个变量,因此它符合 tf.scatter_nd_update 。请注意,您不能只更新任何张量,只能更新变量

另一个选项是仅为所选切片创建一个新变量,为其分配 self.embedded_chars 并随后执行更新。

<小时/>

警告:在这两种情况下,您都会阻止梯度来训练嵌入矩阵,因此请仔细检查覆盖学习值是否确实是您想要的。

关于python - 如何修改 tf.nn.embedding_lookup() 的返回张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50259009/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com