gpt4 book ai didi

tensorflow - 数据集 API、迭代器和 tf.contrib.data.rejection_resample

转载 作者:行者123 更新时间:2023-12-04 22:18:36 27 4
gpt4 key购买 nike

[在@mrry 评论后编辑#1]
我正在使用(伟大而惊人的)数据集 API 以及 tf.contrib.data.rejection_resample
为输入训练管道设置特定的分布函数。

在将 tf.contrib.data.rejection_resample 添加到 input_fn 之前,我使用了一次性迭代器。唉,当开始使用后者时,我尝试使用 dataset.make_initializable_iterator() - 这是因为我们正在引入管道状态变量,并且需要在输入管道中的所有变量都初始化之后初始化迭代器。
正如@mrry 所写 here.

我将 input_fn 传递给估计器并由实验包装。

问题是 - 在哪里 Hook 迭代器的 init?
如果我尝试:

dataset = dataset.batch(batch_size)
if self.balance:
dataset = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
print (image_batch)

和映射函数:
def class_mapping_function(self, feature, label):
"""
returns a a function to be used with dataset.map() to return class numeric ID
The function is mapping a nested structure of tensors (having shapes and types defined by dataset.output_shapes
and dataset.output_types) to a scalar tf.int32 tensor. Values should be in [0, num_classes).
"""
# For simplicity, trying to return the label itself as I assume its numeric...

return tf.cast(label, tf.int32) # <-- I guess this is the bug

迭代器不像单次迭代器那样接收张量形状。

例如。
使用 One Shot 迭代器运行,迭代器得到正确的形状:
Tensor("train_input_fn/IteratorGetNext:0", shape=(?, 100, 100, 3), dtype=float32, device=/device:CPU:0)

但是当使用可初始化迭代器时,它缺少张量形状信息:
Tensor("train_input_fn/IteratorGetNext:0", shape=(?,), dtype=int32, device=/device:CPU:0)

任何帮助将不胜感激!

[ 编辑#2 ]- 在@mrry 评论之后,它似乎是另一个数据集]
也许这里真正的问题不是迭代器的初始化序列,而是 tf.contrib.data.rejection_resample 使用的映射函数,它返回 tf.int32。但是后来我想知道应该如何定义映射函数?例如,要将数据集形状保持为 (?,100,100,3)...

[ 编辑#3 ]:来自rejection_resample的实现
class_values_ds = dataset.map(class_func)

因此,class_func 将获取一个数据集并返回一个 tf.int32 的数据集是有意义的。

最佳答案

在@mrry 响应之后,我可以想出一个关于如何将数据集 API 与 tf.contrib.data.rejection_resample 一起使用的解决方案(使用 TF1.3)。

目标

给定具有某种分布的特征/标签数据集,让输入管道将分布 reshape 为特定的目标分布。

数值示例

假设我们正在构建一个网络来将某些特征分类为 10 个类别之一。
并假设我们只有 100 个带有随机标签分布的特征。
30 个特征标记为 1 类,5 个特征标记为 2 类
等等。
在训练期间,我们不希望类 1 优于类 2,因此我们希望每个 mini-batch 为所有类保持均匀分布。

解决方案

使用 tf.contrib.data.rejection_resample 将允许为我们的输入管道设置特定的分布。

在文档中它说 tf.contrib.data.rejection_resample 将采取

(1) 数据集 - 这是您要平衡的数据集

(2) class_func - 这是一个仅从原始数据集生成新数字标签数据集的函数

(3) target_dist - 一个向量中的类数的大小,以具体化所需的新分布。

(4) 更多可选值 - 暂时跳过

正如文档所说,它返回一个`Dataset.

事实证明,输入数据集的形状与输出数据集的形状不同。因此,返回的数据集(在 TF1.3 中实现)应由用户过滤,如下所示:

    balanced_dataset = tf.contrib.data.rejection_resample(input_dataset,
self.class_mapping_function,
self.target_distribution)

# Return to the same Dataset shape as was the original input
balanced_dataset = balanced_dataset.map(lambda _, data: (data))

关于迭代器类型的一个注释。
正如@mrry 解释的 here ,当在管道中使用有状态对象时,应该使用可初始化的迭代器而不是 one-hot。请注意,当使用可初始化迭代器时,您应该将 init_op 添加到 TABLE_INITIALIZERS 中,否则您将收到此错误:“GetNext() 失败,因为迭代器尚未初始化。”

代码示例:
# Creating the iterator, that allows to access elements from the dataset
if self.use_balancing:
# For balancing function, we use stateful variables in the sense that they hold current dataset distribution
# and calculate next distribution according to incoming examples.
# For dataset pipeline that have state, one_shot iterator will not work, and we are forced to use
# initializable iterator
# This should be relaxed in the future.
# https://stackoverflow.com/questions/44374083/tensorflow-cannot-capture-a-stateful-node-by-value-in-tf-contrib-data-api
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)

else:
iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()

行得通吗?

是的。
这是在输入管道标签上收集直方图后来自 Tensorboard 的 2 张图像。
原始输入标签是均匀分布的。
场景 A:试图实现以下 10 类分布:
[0.1, 0.4 ,0.05,0.05,0.05,0.05,0.05,0.05,0.1,0.1]

结果:

enter image description here

场景 B:试图实现以下 10 类分布:
[0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05, 0.4 ,0.1]

结果:

enter image description here

关于tensorflow - 数据集 API、迭代器和 tf.contrib.data.rejection_resample,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47039760/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com