gpt4 book ai didi

python - Tensorflow:如何确保每批中的所有样本都有不同的标签?

转载 作者:行者123 更新时间:2023-12-01 02:05:10 27 4
gpt4 key购买 nike

我想知道是否有一种方法可以对要在 Tensorflow 中生成的批处理应用约束。特别是,我想生产包含不同标签的批处理。

假设我有五个可能的标签,{A, B, C, D, E} ,我想要像 (A, C, E, D, B) 这样的批处理或(B,E,D,C,A) 。基本上,我想避免具有相同标签的批处理,例如 (A, A, D, E, C)(A, B, B, B, E) .

最佳答案

实现您的要求

批处理只是从输入的任何内容中提取 BATCH_SIZE 样本并将它们打包在一起,所以从技术上讲,是的,这是可能的。不过,这取决于您,确保 batch() 的输入按照您想要的方式排序。

最有效的方法可能是拥有 5 个 tf.data.Dataset,每个都有一个特定的标签,将它们压缩在一起以获得一个“批处理”数据集,其标签始终按相同顺序,然后在其上使用 .map tf.random_shuffle获得批处理的随机排列并将其馈送到您的网络。

我还会在随机排列后加入 .shuffle,只是确保网络不会始终以相同的顺序看到相同的批处理。

代码如下:

data = [ tf.constant([chr(ord('A')+i), chr(ord('a')+i) ]) for i in range(5) ]

per_label_datasets = [tf.data.Dataset.from_tensor_slices(d) for d in data]
dataset = tf.data.Dataset.zip(tuple(per_label_datasets)) # now an item has shape len(per_label_datasets) and one item from each
dataset = dataset.map(lambda *args : tf.random_shuffle(args)) # lambda needed because random_shuffle takes only one argument
dataset = dataset.shuffle(10) # optional

it = dataset.make_one_shot_iterator()
batch = it.get_next()

sess = tf.Session()
print(sess.run(batch))
print(sess.run(batch))

示例输出:

[b'a' b'c' b'd' b'e' b'b']
[b'C' b'B' b'A' b'D' b'E']

个人备注

我不知道您使用的是什么模型,并且我假设有一些模型对此有意义,但是在大多数模型中,批处理中的样本顺序毫无意义,因为结果是平均的计算损失时在一批中一起计算损失。因此,如果您确实需要它,可以通过多种方法来实现,但在开始对管道进行编码之前请确保您确实需要它。

关于python - Tensorflow:如何确保每批中的所有样本都有不同的标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49147811/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com