gpt4 book ai didi

python - 在每个类的数据集中拆分 tensorflow 数据集

转载 作者:太空宇宙 更新时间:2023-11-04 02:17:32 26 4
gpt4 key购买 nike

我有一个从一个 tfrecord 文件创建的数据集。该数据集包含 5 个不同的类。

现在我想从每个批处理中创建具有固定数量元素(例如 8 个)的批处理。因此它应该创建包含每个类 8 个元素的 40 个元素的批处理。

这可以用 tf.data 实现吗?

最佳答案

最简单的事情是(也许不是很方便):

a) 准备 5 个不同的 TFRecords,每个只包含一个特定类的元素。

b) 创建 5 个不同的 tf.data.TFRecordDataset 实例,因此创建 5 个不同的迭代器。

c) 然后在主代码中:

iterators =  [....] # Store your iterators in a list
data = list(map(lambda x : x.get_next(), iterators))
data_to_use = tf.concat(....) # Concat your data in one single batch of `40` elements.

另一种方法(不创建单独的数据集)

a) 仅使用一个 TFRecord。但是创建 5 它的不同实例

b) 在每个实例中,使用tf.data API 的tf.data.filter(predicate) 方法来过滤属于一个特定类的记录。为此,您必须编写一个函数,它可以检查每条记录的类。

c) 然后按照前面的解决方案中的步骤 c) 进行操作。

关于python - 在每个类的数据集中拆分 tensorflow 数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52299505/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com