gpt4 book ai didi

python - Tensorflow:tf.data.Dataset,无法在组件0中批量处理具有不同形状的张量

转载 作者:行者123 更新时间:2023-12-04 23:12:14 27 4
gpt4 key购买 nike

输入管道中出现以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot batch tensors with different shapes in component 0. First element had shape [2,48,48,3] and element 1 had shape [27,48,48,3].



用这个代码
dataset = tf.data.Dataset.from_generator(generator,
(tf.float32, tf.int64, tf.int64, tf.float32, tf.int64, tf.float32))

dataset = dataset.batch(max_buffer_size)

这是完全合乎逻辑的,因为批处理方法尝试创建(batch_size,?,48,48,3)张量。但是我希望它为这种情况创建一个[29,48,48,3]张量。因此串联而不是堆栈。 tf.data有可能吗?

我可以在Python的generator函数中进行串联,但是我想知道tf.data管道是否也可以实现

最佳答案

第一种情况:我们希望输出具有固定的批量大小

在这种情况下,生成器生成[None, 48, 48, 3]形状的值,其中第一维可以是任何值。我们要对此进行批处理,以便输出为[batch_size, 48, 48, 3]。如果直接使用tf.data.Dataset.batch,则会出现错误,因此我们需要先取消对的批处理。

为此,我们可以在批处理之前像这样使用 tf.contrib.data.unbatch :

dataset = dataset.apply(tf.contrib.data.unbatch())
dataset = dataset.batch(batch_size)

这是一个完整的示例,其中生成器生成 [1][2, 2][3, 3, 3][4, 4, 4, 4]

我们无法直接批处理这些输出值,因此我们先取消批处理然后再批处理它们:

def gen():
for i in range(1, 5):
yield [i] * i

# Create dataset from generator
# The output shape is variable: (None,)
dataset = tf.data.Dataset.from_generator(gen, tf.int64, tf.TensorShape([None]))

# The issue here is that we want to batch the data
dataset = dataset.apply(tf.contrib.data.unbatch())
dataset = dataset.batch(2)

# Create iterator from dataset
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next() # shape (None,)

sess = tf.Session()
for i in range(5):
print(sess.run(x))

这将打印以下输出:
[1 2]
[2 3]
[3 3]
[4 4]
[4 4]

第二种情况:我们要串联可变大小的批次

更新(03/30/2018):我删除了以前使用分片的答案,这大大降低了性能(请参阅评论)。

在这种情况下,我们要串联一定数量的批次。问题是这些批次的大小可变。例如,数据集产生 [1][2, 2],我们希望获得 [1, 2, 2]作为输出。

解决此问题的一种快速方法是创建一个环绕原始发电机的新发电机。新的生成器将产生批处理数据。 (感谢 Guillaume的想法)

这是一个完整的示例,其中生成器生成 [1][2, 2][3, 3, 3][4, 4, 4, 4]

def gen():
for i in range(1, 5):
yield [i] * i

def get_batch_gen(gen, batch_size=2):
def batch_gen():
buff = []
for i, x in enumerate(gen()):
if i % batch_size == 0 and buff:
yield np.concatenate(buff, axis=0)
buff = []
buff += [x]

if buff:
yield np.concatenate(buff, axis=0)

return batch_gen

# Create dataset from generator
batch_size = 2
dataset = tf.data.Dataset.from_generator(get_batch_gen(gen, batch_size),
tf.int64, tf.TensorShape([None]))

# Create iterator from dataset
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next() # shape (None,)


with tf.Session() as sess:
for i in range(2):
print(sess.run(x))

这将打印以下输出:
[1 2 2]
[3 3 3 4 4 4 4]

关于python - Tensorflow:tf.data.Dataset,无法在组件0中批量处理具有不同形状的张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49531286/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com