gpt4 book ai didi

python - 如何在 TensorFlow 中循环遍历张量的元素?

转载 作者:太空宇宙 更新时间:2023-11-04 00:34:27 26 4
gpt4 key购买 nike

我想使用 TensorFlow 的 tf.image.rot90() 创建一个函数 batch_rot90(batch_of_images),后者一次只获取一张图像,前者应该一次拍摄一批 n 张图像 (shape = [n,x,y,f])。

很自然地,我们应该遍历批处理中的所有图像并一张一张地旋转它们。在 numpy 中,这看起来像:

def batch_rot90(batch):
for i in range(batch.shape[0]):
batch_of_images[i] = rot90(batch[i,:,:,:])
return batch

这在 TensorFlow 中是如何完成的?使用 tf.while_loop 我得到了他的远方:

batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])    
def batch_rot90(batch, k, name=''):
i = tf.constant(0)
def cond(batch, i):
return tf.less(i, tf.shape(batch)[0])
def body(im, i):
batch[i] = tf.image.rot90(batch[i], k)
i = tf.add(i, 1)
return batch, i
r = tf.while_loop(cond, body, [batch, i])
return r

但是不允许对 im[i] 赋值,我对 r 返回的内容感到困惑。

我意识到使用 tf.batch_to_space() 可能有针对这种特殊情况的解决方法,但我相信它也应该可以通过某种循环实现。

最佳答案

更新的答案:

x = tf.placeholder(tf.float32, shape=[2, 3])

def cond(batch, output, i):
return tf.less(i, tf.shape(batch)[0])

def body(batch, output, i):
output = output.write(i, tf.add(batch[i], 10))
return batch, output, i + 1

# TensorArray is a data structure that support dynamic writing
output_ta = tf.TensorArray(dtype=tf.float32,
size=0,
dynamic_size=True,
element_shape=(x.get_shape()[1],))
_, output_op, _ = tf.while_loop(cond, body, [x, output_ta, 0])
output_op = output_op.stack()

with tf.Session() as sess:
print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]}))

我认为您应该考虑使用 tf.scatter_update 来批量更新一个图像,而不是使用 batch[i] = ...。引用this link详细信息。对于您的情况,我建议将 body 的第一行更改为:

tf.scatter_update(batch, i, tf.image.rot90(batch[i], k))

关于python - 如何在 TensorFlow 中循环遍历张量的元素?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44690865/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com