gpt4 book ai didi

tensorflow - TF数据API : How to produce tensorflow input to object set recognition

转载 作者:行者123 更新时间:2023-12-02 20:23:51 26 4
gpt4 key购买 nike

考虑这个问题:从图像数据集(如 ImageNet)中的随机主题中选择随机数量的样本作为 Tensorflow 图的输入元素,该图充当对象集识别器。对于每个批处理,每个类都有相同数量的样本以方便计算。但是不同的批处理对于一个类会有不同数量的图像,即batch_0:num_imgs_per_cls=2; batch_1000:num_imgs_per_cls=3。

如果 Tensorflow 中有现有功能,我们将非常感谢从头开始解释整个过程(例如从图像目录)。

最佳答案

@mrry here 有一个非常相似的答案.

平衡批处理抽样

在人脸识别中,我们经常使用三元组损失(或类似的损失)来训练模型。对三元组进行采样以计算损失的常用方法是创建一批平衡的图像,其中我们有 10 个不同的类别(即 10 个不同的人),每个类别有 5 张图像。在此示例中,总批量大小为 50。

更普遍的问题是对num_classes_per_batch(示例中为10个)类进行采样,然后对每个类的num_images_per_class(示例中为5个)图像进行采样。总批量大小为:

batch_size = num_classes_per_batch * num_images_per_class

每个类别都有一个数据集

处理大量不同类别(MS-Celeb 中为 100,000 个)的最简单方法是为每个类别创建一个数据集。
例如,您可以为每个类拥有一个 tfrecord 并创建如下数据集:

# Build one dataset per class.
filenames = ["class_0.tfrecords", "class_1.tfrecords"...]
per_class_datasets = [tf.data.TFRecordDataset(f).repeat(None) for f in filenames]

数据集中的示例

现在我们希望能够从这些数据集中进行采样。例如,我们希望批处理中包含以下标签:

1 1 1 3 3 3 9 9 9 4 4 4

这对应于 num_classes_per_batch=4num_images_per_class=3

为此,我们需要使用将在 r1.9 中发布的功能。该函数应名为 tf.contrib.data.choose_from_datasets(有关此问题的讨论,请参阅 here)。
它应该看起来像:

def choose_from_datasets(datasets, selector):
"""Chooses elements with indices from selector among the datasets in `datasets`."""

因此,我们创建这个选择器,它将输出1 1 1 3 3 3 9 9 9 4 4 4并将其与数据集结合起来获取将输出平衡批处理的最终数据集:

def generator(_):
# Sample `num_classes_per_batch` classes for the batch
sampled = tf.random_shuffle(tf.range(num_classes))[:num_classes_per_batch]
# Repeat each element `num_images_per_class` times
batch_labels = tf.tile(tf.expand_dims(sampled, -1), [1, num_images_per_class])
return tf.to_int64(tf.reshape(batch_labels, [-1]))

selector = tf.contrib.data.Counter().map(generator)
selector = selector.apply(tf.contrib.data.unbatch())

dataset = tf.contrib.data.choose_from_datasets(datasets, selector)

# Batch
batch_size = num_classes_per_batch * num_images_per_class
dataset = dataset.batch(batch_size)

您可以使用夜间 TensorFlow 构建并使用 DirectedInterleaveDataset 作为解决方法来测试这一点:

# The working option right now is 
from tensorflow.contrib.data.python.ops.interleave_ops import DirectedInterleaveDataset
dataset = DirectedInterleaveDataset(selector, datasets)

我还写过有关此解决方法的文章 here .

关于tensorflow - TF数据API : How to produce tensorflow input to object set recognition,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50550126/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com