gpt4 book ai didi

python - tensorflow map_fn TensorArray 具有不一致的形状

转载 作者:太空宇宙 更新时间:2023-11-03 13:11:09 24 4
gpt4 key购买 nike

我正在研究 map_fn 函数,注意到它输出一个 TensorArray,这应该意味着它能够输出“锯齿状”张量(其中内部的张量具有不同的第一维)。

我试着用这段代码看看这个:

import tensorflow as tf
import numpy as np

NUM_ARRAYS = 1000
MAX_LENGTH = 1000

lengths = tf.placeholder(tf.int32)
tArray = tf.map_fn(lambda x: tf.random_normal((x,), 0, 1),
lengths,
dtype=tf.float32) # Should return a TensorArray.

# startTensor = tf.random_normal((tf.reduce_sum(lengths),), 0, 1)
# tArray = tf.TensorArray(tf.float32, NUM_ARRAYS)
# tArray = tArray.split(startTensor, lengths)
# outArray = tArray.concat()


with tf.Session() as sess:
outputArray, l = sess.run(
[tArray, lengths],
feed_dict={lengths: np.random.randint(MAX_LENGTH, size=NUM_ARRAYS)})
print outputArray.shape, l

但是出现错误:

“TensorArray 的形状不一致。索引 0 的形状:[259],但索引 1 的形状:[773]”

这当然让我感到惊讶,因为我的印象是 TensorArrays 应该能够处理它。我错了吗?

最佳答案

虽然 tf.map_fn()确实使用 tf.TensorArray 内部对象,并且 tf.TensorArray 可以容纳不同大小的对象,这个程序不会按原样运行,因为 tf.map_fn() 通过将元素堆叠在一起,将其 tf.TensorArray 结果转换回 tf.Tensor,而这个操作失败了。

但是,您可以使用较低级别的 tf.while_loop() 实现基于 tf.TensorArray 的op 而不是:

lengths = tf.placeholder(tf.int32)
num_elems = tf.shape(lengths)[0]
init_array = tf.TensorArray(tf.float32, size=num_elems)

def loop_body(i, ta):
return i + 1, ta.write(i, tf.random_normal((lengths[i],), 0, 1))

_, result_array = tf.while_loop(
lambda i, ta: i < num_elems, loop_body, [0, init_array])

关于python - tensorflow map_fn TensorArray 具有不一致的形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43270849/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com