gpt4 book ai didi

python - 在 tensorflow-r1.2 中正确使用 `tf.scatter_nd`

转载 作者:太空宇宙 更新时间:2023-11-03 14:51:48 28 4
gpt4 key购买 nike

给定形状为[batch_size, sequence_len]indices,形状为[batch_size, sequence_len, sampled_size]updates >,to_shape 形状为 [batch_size, sequence_len, vocab_size],其中 vocab_size >> sampled_size,我会喜欢使用 tf.scatterupdates 映射到带有 to_shape 的巨大张量,这样 to_shape[bs, indices[bs , sz]] = 更新[bs, sz]。也就是说,我想逐行将 updates 映射到 to_shape。请注意,sequence_lensampled_size 是标量张量,而其他的是固定的。我尝试执行以下操作:

new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)

但是我得到一个错误:

ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]

你能告诉我如何正确使用 scatter_nd 吗?提前致谢!

最佳答案

假设你有:

  • 张量更新,形状为[batch_size, sequence_len, sampled_size]
  • 张量指数,形状为[batch_size, sequence_len, sampled_size]

然后你做:

import tensorflow as tf

# Create updates and indices...

# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)

tf.scatter_nd采用 indices 张量、updates 张量和某种形状。 updates是原始张量,shape就是想要的输出shape,所以[batch_size, sequence_len, vocab_size]。现在,indices 更复杂了。由于您的输出有 3 个维度(等级 3),对于 updates 中的每个元素,您需要 3 个索引来确定每个元素在输出中的放置位置。因此 indices 参数的形状应该与 updates 相同,但增加了一个大小为 3 的维度。在这种情况下,我们希望第一个维度相同,但我们仍然必须指定 3 个索引。所以我们使用 tf.meshgrid生成我们需要的索引,并沿着第三个维度平铺它们(updates 最后一个维度中每个元素向量的第一个和第二个索引是相同的)。最后,我们将这些索引与之前创建的映射索引堆叠起来,我们就有了完整的 3 维索引。

关于python - 在 tensorflow-r1.2 中正确使用 `tf.scatter_nd`,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45162998/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com