gpt4 book ai didi

python - 了解 tf.scatter_nd_update : How to update column values?

转载 作者:太空宇宙 更新时间:2023-11-04 04:18:49 25 4
gpt4 key购买 nike

我正在尝试将切片更新的 NumPy 操作转换为 TensorFlow。我想重现以下最小示例:

input = np.arange(3 * 5).reshape((3, 5))

array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])

input[:, [0, 2]] = -1

array([[-1, 1, -1, 3, 4],
[-1, 6, -1, 8, 9],
[-1, 11, -1, 13, 14]])

因此,我想为数组中某些列的所有元素设置一个常量值。

现在,我有张量而不是 NumPy 数组,列索引也被动态计算并存储在张量中。我找到了如何使用 tf.scatter_nd_update 更新给定 中的所有值:

input = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))                                                                                                                                                                                          
indices = tf.constant([[0], [2]])
updates = tf.constant([[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]])

scatter = tf.scatter_nd_update(input, indices, updates)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(scatter))

输出:

[[-1 -1 -1 -1 -1]
[ 5 6 7 8 9]
[-1 -1 -1 -1 -1]]

但是我怎样才能对某些列执行此操作?

最佳答案

你可以这样做:

import tensorflow as tf

def update_columns(variable, columns, value):
columns = tf.convert_to_tensor(columns)
rows = tf.range(tf.shape(variable)[0], dtype=columns.dtype)
ii, jj = tf.meshgrid(rows, columns, indexing='ij')
value = tf.broadcast_to(value, tf.shape(ii))
return tf.scatter_nd_update(variable, tf.stack([ii, jj], axis=-1), value)

inp = tf.Variable(tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5]))
updated = update_columns(inp, [0, 2], -1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(updated))

输出:

[[-1  1 -1  3  4]
[-1 6 -1 8 9]
[-1 11 -1 13 14]]

但是请注意,您应该只使用 tf.scatter_nd_update如果您真的想使用一个变量(并为其分配一个新值)。如果你想得到一个等于另一个张量但更新了一些值的张量,你应该使用常规张量操作而不是将其转换为变量。例如,对于这种情况,您可以这样做:

import tensorflow as tf

def update_columns_tensor(tensor, columns, value):
columns = tf.convert_to_tensor(columns)
shape = tf.shape(tensor)
num_rows, num_columns = shape[0], shape[1]
mask = tf.equal(tf.range(num_columns, dtype=columns.dtype), tf.expand_dims(columns, 1))
mask = tf.tile(tf.expand_dims(tf.reduce_any(mask, axis=0), 0), (num_rows, 1))
value = tf.broadcast_to(value, shape)
return tf.where(mask, value, tensor)

inp = tf.reshape(tf.range(3 * 5, dtype=tf.int32), [3, 5])
updated = update_columns_tensor(inp, [0, 2], -1)
with tf.Session() as sess:
print(sess.run(updated))
# Same output

关于python - 了解 tf.scatter_nd_update : How to update column values?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54890741/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com