gpt4 book ai didi

tensorflow - 交错 tf.data.Datasets

转载 作者:行者123 更新时间:2023-12-03 01:41:24 27 4
gpt4 key购买 nike

我正在尝试使用 tf.data.Dataset 来交错两个数据集,但这样做时遇到问题。给出这个简单的例子:

ds0 = tf.data.Dataset()
ds0 = ds0.range(0, 10, 2)
ds1 = tf.data.Dataset()
ds1 = ds1.range(1, 10, 2)
dataset = ...
iter = dataset.make_one_shot_iterator()
val = iter.get_next()

什么是 ... 来产生像 0, 1, 2, 3...9 这样的输出?

看起来 dataset.interleave() 是相关的,但我无法以不产生错误的方式制定该语句。

最佳答案

MattScarpino 在 his comment 中走在正确的轨道上。您可以使用Dataset.zip()以及Dataset.flat_map()展平多元素数据集:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
tf.data.Dataset.from_tensors(x1)))

iter = dataset.make_one_shot_iterator()
val = iter.get_next()

话虽如此,您对使用 Dataset.interleave() 的直觉是相当明智的。我们正在研究可以让您更轻松地完成此操作的方法。

<小时/>PS。作为替代方案,如果您更改 ds0ds1 的方式,您可以使用 Dataset.interleave() 来解决问题> 定义:

dataset = tf.data.Dataset.range(2).interleave(
lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)

关于tensorflow - 交错 tf.data.Datasets,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47343228/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com