gpt4 book ai didi

python - 具有部分洗牌功能的 Tensorflow 数据集

转载 作者:行者123 更新时间:2023-11-30 22:15:34 25 4
gpt4 key购买 nike

我正在使用 TensorFlow 的数据集 API,根据文档,我对 shuffle() 方法感到困惑:

The Dataset.shuffle() transformation randomly shuffles the input dataset using a similar algorithm to tf.RandomShuffleQueue: it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.

如果我仅“部分”打乱数据集(例如 buffer_size <= 元素数量),我预计只会打乱前 buffer_size 元素,但事实并非如此,参见示例:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8])
.shuffle(buffer_size=4, seed=42)
.batch(2)
iter = dataset.make_initializable_iterator() # create the iterator
el = iter.get_next()
with tf.Session() as sess:
sess.run(iter.initializer)
print('batch:', sess.run(el))

输出:

batch: [2 5]

为什么这里是5?因为缓冲区大小只有 4?前 2 个元素应该在 1~4 之内,对吗?我在这里缺少什么?

谢谢

最佳答案

简短的回答是,可以随时补充洗牌缓冲区,包括在创建批处理的过程中。

您的观察可能是这样发生的:

  • 数据集读取数据中的前 4 个元素。随机播放缓冲区现在包含 [1, 2, 3, 4]
  • 您请求两个元素(通过数据集上的 get_next() 来创建 2 个批处理)
  • 随机数据集选择 2 并将下一个元素读取到随机缓冲区中,该缓冲区现在包含 [1, 3, 4, 5]。
  • 随机数据集从缓冲区中选取 5 个。
  • 您的批处理 [2, 5] 已退回。

关于python - 具有部分洗牌功能的 Tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50255035/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com