
tensorflow - How to shuffle an entire dataset with TensorFlow?


Currently I shuffle my data with the following function:

from tensorflow.contrib import data

def input_pipeline(filenames, batch_size):
    # Define a `tf.contrib.data.Dataset` for iterating over one epoch of the data.
    dataset = data.TextLineDataset(filenames)
    dataset = dataset.map(decode_func)
    dataset = dataset.shuffle(buffer_size=10000)  # Equivalent to min_after_dequeue=10000.
    dataset = dataset.batch(batch_size)

    # Return an *initializable* iterator over the dataset, which will allow us to
    # re-initialize it at the beginning of each epoch.
    return dataset.make_initializable_iterator()

However, this only shuffles the data within a window of buffer_size elements, and the buffer is filled in sequential order.

My dataset is very large, so I cannot make buffer_size large enough. Is there another way to shuffle the whole dataset?
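For illustration only (not part of the original question), here is a minimal sketch of why a small buffer gives only a local shuffle. It uses the newer tf.data API rather than tf.contrib.data, and the sizes are made-up values: the buffer is filled with the first buffer_size elements and each output is sampled from that window, so early elements can never move far from the front of the stream.

import tensorflow as tf

# Shuffle 0..99 with a buffer of only 10 elements (assumed values for illustration).
dataset = tf.data.Dataset.range(100).shuffle(buffer_size=10)

# The first outputs are drawn almost entirely from the first ~10-20 elements,
# because the buffer is filled sequentially and sampled as it refills.
print(list(dataset.as_numpy_iterator())[:20])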

Best answer

At present, the Dataset API does not support shuffling an entire dataset that is larger than the shuffle buffer (more than 10k examples here). According to this thread, the common approach is (a code sketch follows the list):

  1. Randomly shuffle the entire data once using a MapReduce/Spark/Beam/etc. job to create a set of roughly equal-sized files ("shards").
  2. In each epoch:

    a. Randomly shuffle the list of shard filenames, using Dataset.list_files(...).shuffle(num_shards).

    b. Use dataset.interleave(lambda filename: tf.data.TextLineDataset(filename), cycle_length=N) to mix together records from N different shards.

    c. Use dataset.shuffle(B) to shuffle the resulting dataset. Setting B might require some experimentation, but you will probably want to set it to some value larger than the number of records in a single shard.
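A minimal sketch of steps (a)-(c) above, using the tf.data API. The file pattern, shard count, cycle_length N, buffer size B, and decode_func below are placeholders for your own values, not values from the original answer.

import tensorflow as tf

NUM_SHARDS = 100        # assumed number of shard files produced in step 1
CYCLE_LENGTH = 8        # N: how many shards to read from concurrently
SHUFFLE_BUFFER = 50000  # B: larger than the record count of a single shard

def decode_func(line):
    return line  # placeholder: parse one text line into features

# (a) Randomly shuffle the list of shard filenames.
filenames = tf.data.Dataset.list_files("data/shard-*.txt").shuffle(NUM_SHARDS)

# (b) Interleave records from N different shards.
dataset = filenames.interleave(
    lambda filename: tf.data.TextLineDataset(filename),
    cycle_length=CYCLE_LENGTH)

# (c) Shuffle the interleaved records with a buffer larger than one shard.
dataset = dataset.map(decode_func)
dataset = dataset.shuffle(SHUFFLE_BUFFER)
dataset = dataset.batch(32)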

Regarding "tensorflow - How to shuffle an entire dataset with TensorFlow?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/44792761/
