gpt4 book ai didi

python - 类型错误 : unsupported callable using Dataset with estimator input_fn

转载 作者:太空狗 更新时间:2023-10-30 02:25:54 35 4
gpt4 key购买 nike

我正在尝试将 Iris 教程 ( https://www.tensorflow.org/get_started/estimator ) 转换为从 .png 文件而不是 .csv 文件中读取训练数据。它使用 numpy_input_fn 工作,但当我从 Dataset 制作它时不起作用。我认为 input_fn() 返回了错误的类型,但并不真正理解它应该是什么以及如何做到这一点。错误是:

  File "iris_minimal.py", line 27, in <module>
model_fn().train(input_fn(), steps=1)
...
raise TypeError('unsupported callable') from ex
TypeError: unsupported callable

TensorFlow 版本为 1.3。完整代码:

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

NUM_CLASSES = 3

def model_fn():
feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]
return tf.estimator.DNNClassifier([10, 20, 10], feature_columns, "tmp/iris_model", NUM_CLASSES)

def input_parser(img_path, label):
one_hot = tf.one_hot(label, NUM_CLASSES)
file_contents = tf.read_file(img_path)
image_decoded = tf.image.decode_png(file_contents, channels=1)
image_decoded = tf.image.resize_images(image_decoded, [2, 2])
image_decoded = tf.reshape(image_decoded, [4])
return image_decoded, one_hot

def input_fn():
filenames = tf.constant(['images/image_1.png', 'images/image_2.png'])
labels = tf.constant([0,1])
data = Dataset.from_tensor_slices((filenames, labels))
data = data.map(input_parser)
iterator = data.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels

model_fn().train(input_fn(), steps=1)

最佳答案

我注意到您的代码片段中有几个错误:

  • train方法接受输入函数,所以它应该是input_fn,而不是input_fn()
  • 这些特征应该是一本字典,例如{'x': 特征}
  • DNNClassifier使用 SparseSoftmaxCrossEntropyWithLogits 损失函数。 稀疏 意味着它采用序数类表示,而不是单热表示,因此您的转换是不必要的(this question 解释了 tf 中交叉熵损失之间的区别)。

试试下面的代码:

import tensorflow as tf
from tensorflow.contrib.data import Dataset

NUM_CLASSES = 3

def model_fn():
feature_columns = [tf.feature_column.numeric_column("x", shape=[4], dtype=tf.float32)]
return tf.estimator.DNNClassifier([10, 20, 10], feature_columns, "tmp/iris_model", NUM_CLASSES)

def input_parser(img_path, label):
file_contents = tf.read_file(img_path)
image_decoded = tf.image.decode_png(file_contents, channels=1)
image_decoded = tf.image.resize_images(image_decoded, [2, 2])
image_decoded = tf.reshape(image_decoded, [4])
label = tf.reshape(label, [1])
return image_decoded, label

def input_fn():
filenames = tf.constant(['input1.jpg', 'input2.jpg'])
labels = tf.constant([0,1], dtype=tf.int32)
data = Dataset.from_tensor_slices((filenames, labels))
data = data.map(input_parser)
data = data.batch(1)
iterator = data.make_one_shot_iterator()
features, labels = iterator.get_next()
return {'x': features}, labels

model_fn().train(input_fn, steps=1)

关于python - 类型错误 : unsupported callable using Dataset with estimator input_fn,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47120637/

35 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com