gpt4 book ai didi

python - tf.tensor_scatter_nd_add 函数确实有效

转载 作者:太空宇宙 更新时间:2023-11-04 01:45:06 26 4
gpt4 key购买 nike

我想在具有两个新值矩阵的张量的第一个维度中插入两个切片,我正在使用方法 tensor_scatter_add 但它给了我一个错误

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 5, 4])
updated = tf.tensor_scatter_add(tensor, indices, updates)
with tf.Session() as se:
print(ses.run(scatter))

最佳答案

tensor 的内部 2 个维度必须匹配 updates 的内部 2 个维度。 两个形状中的维度 0 必须相等,但 5 和 4

tensor 必须与 updates 具有相同的 dtype 但在您的代码中不同。

存在以下错误:

with tf.Session() as se:
print(ses.run(scatter))

您将 tf.Session() 别名为 se 但调用 ses 而不是 se 和您传递的散点图到 ses.run() 但它没有在任何地方定义; se.run(updated) 应该是正确的函数调用。

带有代码修复的片段:
这应该适合您。

indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 4, 4], dtype=tf.int32)
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
with tf.Session() as se:
print(se.run(updated))

关于python - tf.tensor_scatter_nd_add 函数确实有效,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59246382/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com