gpt4 book ai didi

tensorflow - 使用数据集API生成平衡的迷你批处理

转载 作者:行者123 更新时间:2023-12-04 04:15:14 25 4
gpt4 key购买 nike

我对新的数据集API(tensorflow 1.4rc1)有疑问。
我有一个不平衡的数据集wrt来标记01。我的目标是在预处理过程中创建平衡的微型批次。

假设我有两个过滤后的数据集:

ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()

有没有一种方法可以合并这两个数据集,使结果数据集看起来像 ds = [0, 1, 0, 1, 0, 1]:

像这样的东西:
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.apply(...)
# dataset looks like [0, 1, 0, 1, 0, 1, ...]
dataset = dataset.batch(20)

我当前的方法是:
def _concat(x, y):
return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.map(_concat)

但是我感觉有一种更优雅的方式。

提前致谢!

最佳答案

您走在正确的轨道上。以下示例使用Dataset.flat_map()将每对正例和负例转换为结果中的两个连续例:

dataset = tf.data.Dataset.zip((ds_pos, ds_neg))

# Each input element will be converted into a two-element `Dataset` using
# `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
# will flatten the resulting `Dataset`s into a single `Dataset`.
dataset = dataset.flat_map(
lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(ex_pos).concatenate(
tf.data.Dataset.from_tensors(ex_neg)))

dataset = dataset.batch(20)

关于tensorflow - 使用数据集API生成平衡的迷你批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46938530/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com