gpt4 book ai didi

python - 在 tensorflow 中使用 scatter_update 和馈送的数据

转载 作者:行者123 更新时间:2023-12-01 08:27:49 24 4
gpt4 key购买 nike

我正在尝试使用 scatter_update 来更新张量的切片。我的第一个熟悉该函数的代码片段效果非常好。

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
init_val = tf.Variable(tf.zeros((3, 2)))
indices = tf.constant([0, 1])
update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))

init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(update))

但是当我尝试将初始值输入图表时

with tf.Session() as sess:
x = tf.placeholder(tf.float32, shape=(3, 2))
init_val = x
indices = tf.constant([0, 1])
update = tf.scatter_update(init_val, indices, tf.ones((2, 2)))

init = tf.global_variables_initializer()
sess.run(init)
print(sess.run(update, feed_dict={x: np.zeros((3, 2))}))

我收到奇怪的错误

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [3,2]
[[{{node Placeholder_1}} = Placeholder[dtype=DT_FLOAT, shape=[3,2], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

tf.Variable 分配给 init_val 时将其放在 x 周围也没有帮助,因为我收到错误

AttributeError: 'Tensor' object has no attribute '_lazy_read'

(请参阅 Github 上的 this entry)。有人有想法吗?提前致谢!

我在 CPU 上使用 Tensorflow 1.12。

最佳答案

您可以通过构建和更新张量以及掩模张量来通过散射替换张量:

import tensorflow as tf
import numpy as np

with tf.Session() as sess:
x = tf.placeholder(tf.float32, shape=(3, 2))
init_val = x
indices = tf.constant([0, 1])
x_shape = tf.shape(x)
indices = tf.expand_dims(indices, 1)
replacement = tf.ones((2, 2))
update = tf.scatter_nd(indices, replacement, x_shape)
mask = tf.scatter_nd(indices, tf.ones_like(replacement, dtype=tf.bool), x_shape)
result = tf.where(mask, update, x)
print(sess.run(result, feed_dict={x: np.arange(6).reshape((3, 2))}))

输出:

[[1. 1.]
[1. 1.]
[4. 5.]]

关于python - 在 tensorflow 中使用 scatter_update 和馈送的数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54110085/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com