gpt4 book ai didi

python - 在 Tensorflow 2.0 中迭代无限重复的 tf.data 数据集的正确方法是什么

转载 作者:行者123 更新时间:2023-12-04 13:57:01 25 4
gpt4 key购买 nike

TF2.0 文档建议使用 python for 循环迭代数据集:

for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
# do training

问题是,如果数据集无限重复(据我所知,出于性能原因,这是有道理的)这个循环将永远不会结束。

我目前正在做的是设置一些我想要迭代的时期和训练步骤:
train_iter = iter(train_dataset)
for i in range(num_epochs):
# do some setup
for step in range(num_batches):
(x_batch, y_batch) = next(train_iter)
# do training
# log metrics

我不确定的是这是否会对我的训练过程的表现产生负面影响。这是否会使我的训练运行速度变慢,或者我是否会通过像这样运行我的训练来阻止 Tensorflow 优化我的代码?
最重要的是,设置一个时期内要处理的批次数量可能有点烦人,因为我想在我的数据管道中进行随机扩充。因此,我的数据集中唯一样本的数量在不同的训练类(class)之间可能会有所不同。不过这也不是什么大问题。

我试图通过谷歌找到答案,但不幸的是没有运气。

最佳答案

代码的问题,

train_iter = iter(train_dataset)
for i in range(num_epochs):
# do some setup
for step in range(num_batches):
(x_batch, y_batch) = next(train_iter)

是每个 epoch , model看到 batches以相同的顺序,这是效率不高的。

此类代码的示例输出如下所示:
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)

如上所示,每个 Epoch 对应的值是相同的,或者换句话说, Batches为每个 epoch 重复( 4, 0, 8, 6, 73,1,2,9,5 重复三次)。

优化高效的方式通过 batches不同的顺序是使用参数, reshuffle_each_iteration=True .示例代码如下所示:
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(buffer_size=5, reshuffle_each_iteration=True)
iter(dataset)

buffer_size = 10
batch_size = 2

for epoch in range(num_epochs):
dataset_epoch = dataset.batch(batch_size)
for x, y in dataset_epoch:
print(x,y)

上面代码的输出如下所示,可以观察到与任何批次对应的值都没有重复:
tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64) tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64) tf.Tensor(9, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64) tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64) tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64) tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64) tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64) tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(0, shape=(), dtype=int64) tf.Tensor(4, shape=(), dtype=int64)

希望这可以帮助。快乐学习!

关于python - 在 Tensorflow 2.0 中迭代无限重复的 tf.data 数据集的正确方法是什么,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60266064/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com