gpt4 book ai didi

Tensorflow:如何在Estimator中使用生成器中的数据集

转载 作者:行者123 更新时间:2023-12-04 10:34:21 26 4
gpt4 key购买 nike

试图建立简单的模型只是为了弄清楚如何处理tf.data.Dataset.from_generator。我不明白如何设置output_shapes参数。我尝试了几种组合,包括未指定它,但由于张量的形状不匹配,仍然会收到一些错误。这个想法只是用SIZE = 10产生两个numpy数组,并对它们进行线性回归。这是代码:

SIZE = 10


def _generator():
feats = np.random.normal(0, 1, SIZE)
labels = np.random.normal(0, 1, SIZE)
yield feats, labels


def input_func_gen():
shapes = (SIZE, SIZE)
dataset = tf.data.Dataset.from_generator(generator=_generator,
output_types=(tf.float32, tf.float32),
output_shapes=shapes)
dataset = dataset.batch(10)
dataset = dataset.repeat(20)
iterator = dataset.make_one_shot_iterator()
features_tensors, labels = iterator.get_next()
features = {'x': features_tensors}
return features, labels


def train():
x_col = tf.feature_column.numeric_column(key='x', )
es = tf.estimator.LinearRegressor(feature_columns=[x_col])
es = es.train(input_fn=input_func_gen)

另一个问题是是否可以使用此功能为 tf.feature_column.crossed_column的要素列提供数据?总体目标是在批处理培训中使用 Dataset.from_generator功能,在这种情况下,如果数据不适合内存,则将数据从数据库加载到大块中。所有意见和示例都受到高度赞赏。

谢谢!

最佳答案

output_shapes 的可选tf.data.Dataset.from_generator()参数允许您指定生成器产生的值的形状。在其类型上有两个约束定义了应如何指定它:

  • output_shapes参数是一个“嵌套结构”(例如,元组,元组的元组,元组的字典等),必须与生成器产生的值的结构相匹配。

    在您的程序中,_generator()包含语句yield feats, labels。因此,“嵌套结构”是两个元素的元组(每个数组一个)。
  • output_shapes结构的每个组件都应与相应张量的形状匹配。数组的形状始终是尺寸的元组。 (tf.Tensor的形状更为笼统:请参见this Stack Overflow question进行讨论。)让我们看一下feats的实际形状:
    >>> SIZE = 10
    >>> feats = np.random.normal(0, 1, SIZE)
    >>> print feats.shape
    (10,)

  • 因此, output_shapes参数应为2元素元组,其中每个元素均为 (SIZE,):
    shapes = ((SIZE,), (SIZE,))
    dataset = tf.data.Dataset.from_generator(generator=_generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=shapes)

    最后,您将需要为 tf.feature_column.numeric_column() tf.estimator.LinearRegressor() API提供更多有关形状的信息:
    x_col = tf.feature_column.numeric_column(key='x', shape=(SIZE,))
    es = tf.estimator.LinearRegressor(feature_columns=[x_col],
    label_dimension=10)

    关于Tensorflow:如何在Estimator中使用生成器中的数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48769142/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com