gpt4 book ai didi

python - 使用估算器训练 Tensorflow 模型 (from_generator)

转载 作者:太空狗 更新时间:2023-10-30 02:15:33 28 4
gpt4 key购买 nike

我正在尝试使用生成器训练估算器,但我想为该估算器提供每次迭代的样本包。我显示代码:

def _generator():
for i in range(100):
feats = np.random.rand(4,2)
labels = np.random.rand(4,1)

yield feats, labels


def input_func_gen():
shapes = ((4,2),(4,1))
dataset = tf.data.Dataset.from_generator(generator=_generator,
output_types=(tf.float32, tf.float32),
output_shapes=shapes)
dataset = dataset.batch(4)
# dataset = dataset.repeat(20)
iterator = dataset.make_one_shot_iterator()
features_tensors, labels = iterator.get_next()
features = {'x': features_tensors}
return features, labels


x_col = tf.feature_column.numeric_column(key='x', shape=(4,2))
es = tf.estimator.LinearRegressor(feature_columns=[x_col],model_dir=tf_data)
es = es.train(input_fn=input_func_gen,steps = None)

当我运行这段代码时,它引发了这个错误:

    raise ValueError(err.message)
ValueError: Dimensions must be equal, but are 2 and 3 for 'linear/head/labels/assert_equal/Equal' (op: 'Equal') with input shapes: [2], [3].

我必须如何调用此结构?

谢谢!!!

最佳答案

批量大小由 Tensorflow 自动计算并添加到张量形状中,因此无需手动完成。您的生成器还应定义为输出单个样本。

假设形状的位置 0 中的 4 是批量大小,那么:

import tensorflow as tf
import numpy

def _generator():
for i in range(100):
feats = numpy.random.rand(2)
labels = numpy.random.rand(1)

yield feats, labels


def input_func_gen():
shapes = ((2),(1))
dataset = tf.data.Dataset.from_generator(generator=_generator,
output_types=(tf.float32, tf.float32),
output_shapes=shapes)
dataset = dataset.batch(4)
# dataset = dataset.repeat(20)
iterator = dataset.make_one_shot_iterator()
features_tensors, labels = iterator.get_next()
features = {'x': features_tensors}
return features, labels


x_col = tf.feature_column.numeric_column(key='x', shape=(2))
es = tf.estimator.LinearRegressor(feature_columns=[x_col])
es = es.train(input_fn=input_func_gen,steps = None)

关于python - 使用估算器训练 Tensorflow 模型 (from_generator),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49673602/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com