gpt4 book ai didi

python - TF2.0 Data API 从每个类标签中获取 n_i 个样本

转载 作者:太空宇宙 更新时间:2023-11-04 04:05:00 25 4
gpt4 key购买 nike

我必须使用 TF2 Keras 模型将形状为 32x32 的输入分为 3 类。我的训练集有 7000 个例子

>>> X_train.shape # (7000, 32, 32)
>>> Y_train.shape # (7000, 3)

每个类别的示例数量各不相同(例如,class_0 有大约 2500 个示例,而 class_1 有大约 800 个,等等)

我想使用 tf.data API 创建一个数据集对象,该对象返回不带编号的训练数据批处理。来自 [n_0, n_1, n_2] 指定的每个类的示例。

我想从每个类中随机抽取这些 n_i 个样本,并替换 X_train, Y_train

例如,如果我调用 get_batch([100, 150, 125]),它应该从 class_0 的 X_batch 返回 100 个随机样本,从 class_1 返回 150 个,以及 125来自 class_2。

我如何使用 TF2.0 数据 API 实现这一点,以便我可以使用它来训练 Keras 模型?

最佳答案

一种可能的方法是按如下方式进行:

  1. 将来自 X_trainY_train 的数据加载到单个 tf.data 数据集中,以便我们确保保留每个 X 与正确的 Y
  2. 匹配
  3. .shuffle() 然后使用 filter()
  4. 将数据集拆分为每个 n_i
  5. 编写我们的 get_batch 函数以从每个数据集中返回正确数量的样本,shuffle() 样本然后将其拆分回 X & Y

像这样:

# 1: Load the data into a Dataset
raw_data = tf.data.Dataset.zip(
(
tf.data.Dataset.from_tensor_slices(X_train),
tf.data.Dataset.from_tensor_slices(Y_train)
)
).shuffle(7000)


# 2: Split for each category
def get_filter_fn(n):
def filter_fn(x, y):
return tf.equal(1.0, y[n])
return filter_fn

n_0s = raw_data.filter(get_filter_fn(0))
n_1s = raw_data.filter(get_filter_fn(1))
n_2s = raw_data.filter(get_filter_fn(2))

# 3:
def get_batch(n_0,n_1,n_2):
sample = n_0s.take(n_0).concatenate(n_1s.take(n_1)).concatenate(n_2s.take(n_2))
shuffled = sample.shuffle(n_0 + n_1 + n_2)
return shuffled.map(lambda x,y: x),shuffled.map(lambda x,y: y)

现在我们可以做:

x_batch, y_batch = get_batch(100, 150, 125)

请注意,我在这里使用了一些潜在的浪费操作来追求一种我认为直观和直接的方法(特别是读取 raw_data 数据集 3 次以进行过滤操作)所以我没有声称这是完成您需要的最有效方法,但对于像您描述的那样适合内存的数据集,我相信这种低效率可以忽略不计

关于python - TF2.0 Data API 从每个类标签中获取 n_i 个样本,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57530480/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com