gpt4 book ai didi

python - 如何正确使用 tf.scatter_update 进行 N 维更新?

转载 作者:行者123 更新时间:2023-12-01 08:28:03 26 4
gpt4 key购买 nike

我一直在尝试使用 tf.scatter_update 进行 N 维更新(在 tf.scatter_nd 由于形状不匹配而失败之后)。一般来说,这些将用于创建用于过滤传入张量切片的掩码。

假设输入张量 A 的形状为 (batch, i, j, k(深度))。我只对修改所有 k所有 bi,j 值感兴趣。

MWE:

import tensorflow as tf

b, i, j, k = 64, 128, 128, 256
A = tf.random_uniform(shape=(64, 128, 128, 256), dtype='int32', seed=1234) # Batch, i, j, k

mask = tf.ones(shape=(b,i,j,k), dtype='int32')

# Placeholder for more complicated index Tensor. GPU Ignores OOB indices.
indices = tf.random_uniform(shape=(b, 25, k, 2), dtype='int32', seed=4321) # Index number, k, i-j coord.

updates = tf.random_uniform(shape=(i, j, k), dtype='int32', seed=1111)
scatter = tf.scatter_update(mask, indices, updates)

with tf.Session() as sess:
sess.run(scatter)

结果:

AttributeError: 'Tensor' object has no attribute '_lazy_read'

我已经通过 Python 脚本、Python Notebook 以及使用/不使用 Eager Execution 进行了尝试。运气不好。

输入绝对必须是张量,因为其想法是通过一系列操作中途稀疏地更新该张量。

关于tf.scatter_update我是否缺少一些基本的东西? tf.scatter_nd 会更合适吗?如果是这样,有什么区别,特别是更新索引。

当引用tf.scatter_update时文档,示例是基本的并且使用常量;我很难将其应用到更现实的情况和问题中。

最佳答案

Tensorflow 的文档通过将 ref 参数输入为 tf.Variable 来使用所有scatter 操作(例如 scatter_nd_add 等) :

ref: A mutable Tensor. Must be one of the following types: blablabla. A mutable Tensor. Should be from a Variable node.

我遇到了同样的问题,并且在 ref 的 tf 变量上使用时效果很好。所有其他论据我想都可以保留,但我没有彻底调查。

关于python - 如何正确使用 tf.scatter_update 进行 N 维更新?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54096360/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com