作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我想知道是否有一种方法可以对要在 Tensorflow 中生成的批处理应用约束。特别是,我想生产包含不同标签的批处理。
假设我有五个可能的标签,{A, B, C, D, E}
,我想要像 (A, C, E, D, B)
这样的批处理或(B,E,D,C,A)
。基本上,我想避免具有相同标签的批处理,例如 (A, A, D, E, C)
或(A, B, B, B, E)
.
最佳答案
批处理只是从输入的任何内容中提取 BATCH_SIZE
样本并将它们打包在一起,所以从技术上讲,是的,这是可能的。不过,这取决于您,确保 batch()
的输入按照您想要的方式排序。
最有效的方法可能是拥有 5 个 tf.data.Dataset,每个都有一个特定的标签,将它们压缩在一起以获得一个“批处理”数据集,其标签始终按相同顺序,然后在其上使用 .map
tf.random_shuffle
获得批处理的随机排列并将其馈送到您的网络。
我还会在随机排列后加入 .shuffle
,只是确保网络不会始终以相同的顺序看到相同的批处理。
代码如下:
data = [ tf.constant([chr(ord('A')+i), chr(ord('a')+i) ]) for i in range(5) ]
per_label_datasets = [tf.data.Dataset.from_tensor_slices(d) for d in data]
dataset = tf.data.Dataset.zip(tuple(per_label_datasets)) # now an item has shape len(per_label_datasets) and one item from each
dataset = dataset.map(lambda *args : tf.random_shuffle(args)) # lambda needed because random_shuffle takes only one argument
dataset = dataset.shuffle(10) # optional
it = dataset.make_one_shot_iterator()
batch = it.get_next()
sess = tf.Session()
print(sess.run(batch))
print(sess.run(batch))
示例输出:
[b'a' b'c' b'd' b'e' b'b']
[b'C' b'B' b'A' b'D' b'E']
我不知道您使用的是什么模型,并且我假设有一些模型对此有意义,但是在大多数模型中,批处理中的样本顺序毫无意义,因为结果是平均的计算损失时在一批中一起计算损失。因此,如果您确实需要它,可以通过多种方法来实现,但在开始对管道进行编码之前请确保您确实需要它。
关于python - Tensorflow:如何确保每批中的所有样本都有不同的标签?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49147811/
我是一名优秀的程序员,十分优秀!