gpt4 book ai didi

python - tf.data.Dataset : how to get the dataset size (number of elements in a epoch)?

转载 作者:行者123 更新时间:2023-12-04 00:49:23 24 4
gpt4 key购买 nike

假设我以这种方式定义了一个数据集:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

我怎样才能获得数据集中的元素数量(因此,组成一个时代的单个元素的数量)?

我知道 tf.data.Dataset已经知道数据集的维度,因为 repeat()方法允许在指定数量的时期内重复输入管道。所以它一定是一种获取这些信息的方法。

最佳答案

tf.data.Dataset.list_files创建一个名为 MatchingFiles:0 的张量(如果适用,带有适当的前缀)。

你可以评估

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

获取文件数。

当然,这仅适用于简单的情况,特别是如果每​​个图像只有一个样本(或已知数量的样本)。

在更复杂的情况下,例如当您不知道每个文件中的样本数时,您只能在一个 epoch 结束时观察样本数。

为此,您可以查看由您的 Dataset 计算的 epoch 数。 . repeat()创建一个名为 _count 的成员,这计算时代的数量。通过在迭代期间观察它,您可以发现它何时发生变化并从那里计算您的数据集大小。

这个计数器可能被埋在 Dataset的层次结构中。 s是在连续调用成员函数时创建的,所以我们要这样挖出来。
d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
while not isinstance(d, RepeatDataset):
d = d._input_dataset
except AttributeError:
warnings.warn('no epoch counter found')
epoch_counter = None
else:
epoch_counter = d._count

请注意,使用此技术,数据集大小的计算并不准确,因为在此期间 epoch_counter递增通常混合来自两个连续时期的样本。所以这个计算精确到你的批次长度。

关于python - tf.data.Dataset : how to get the dataset size (number of elements in a epoch)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50737192/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com