gpt4 book ai didi

python - tensorflow 数据集 tf.estimator.inputs.numpy_input_fn

转载 作者:太空宇宙 更新时间:2023-11-04 07:30:24 25 4
gpt4 key购买 nike

我正在编写代码,用于在 tensorflow 中从光盘读取图像和标签,然后尝试调用 tf.estimator.inputs.numpy_input_fn。我怎样才能传递整个数据集而不是单个图像。我的代码如下所示:

filenames = tf.constant(filenames)
labels = tf.constant(labels)

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset_batched = dataset.batch(10)
iterator = dataset_batched.make_one_shot_iterator()
features, labels = iterator.get_next()

with tf.Session() as sess:

print(dataset_batched)
print(np.shape(sess.run(features)))
print(np.shape(sess.run(labels)))

mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(sess.run(features))},
y=np.array(sess.run(labels)),
batch_size=1,
num_epochs=None,
shuffle=False)
mnist_classifier.train(input_fn=train_input_fn, steps=1)

我的问题是如何在此处传递数据集 x={"x": np.array(sess.run(features))}

最佳答案

此处不需要/使用numpy_input_fn。您应该将顶部的代码包装到一个返回 iterator.get_next() 的函数(例如,my_input_fn)中,然后传递 input_fn=my_input_fn 进入 train 调用。这会将完整的数据集以 10 个为一组传递给训练代码。

numpy_input_fn 适用于当您已经在数组中拥有可用的完整数据集并且想要一种快速的方法来执行批处理/改组/重复等操作时。

关于python - tensorflow 数据集 tf.estimator.inputs.numpy_input_fn,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48956404/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com