gpt4 book ai didi

tensorflow - 如何打乱连接的 Tensorflow 数据集

转载 作者:行者123 更新时间:2023-12-04 06:27:54 24 4
gpt4 key购买 nike

我有多个具有相同结构的 tensorflow 数据集。
我想将它们组合成一个单一的数据集。使用
tf.dataset.concatenate

但我发现在对这个组合数据集进行混洗时,数据集不会在整个数据集的规模上混洗。但是在每个分离的数据集中进行了洗牌。

有没有办法解决这个问题?

最佳答案

当您连接两个 Dataset s,你得到第一个元素,然后是第二个元素。如果对结果进行混洗,如果混洗缓冲区小于 Dataset 的大小,您将无法获得良好的混合效果。 .

相反,您需要的是交错数据集中的样本。如果您使用 TF >= 1.9,最好的方法是使用专用的 tf.contrib.data.choose_from_datasets 功能。直接来自文档的示例:

datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
tf.data.Dataset.from_tensors("bar").repeat(),
tf.data.Dataset.from_tensors("baz").repeat()]

# Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
choice_dataset = tf.data.Dataset.range(3).repeat(3)

result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)

如果在批处理中保留样本顺序和/或它们的比率很重要,则对输入数据集进行混洗可能会更好。

如果您使用的是较早版本的 TF,则可以依赖 zip 的组合。 , flat_mapconcatenate像这样:
a = tf.data.Dataset.range(3).repeat()
b = tf.data.Dataset.range(100, 105).repeat()

value = (tf.data.Dataset
.zip((a, b))
.flat_map(lambda x, y: tf.data.Dataset.concatenate(
tf.data.Dataset.from_tensors([x]),
tf.data.Dataset.from_tensors([y])))
.make_one_shot_iterator()
.get_next())

sess = tf.InteractiveSession()

for _ in range(10):
print(value.eval())

关于tensorflow - 如何打乱连接的 Tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51764893/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com