gpt4 book ai didi

python - 估计器预测无限循环

转载 作者:太空狗 更新时间:2023-10-30 00:18:41 26 4
gpt4 key购买 nike

我不明白如何使用 TensorFlow Estimator API 进行单一预测 - 我的代码导致无限循环,不断预测相同的输入。

根据documentation ,预测应该在 input_fn 引发 StopIteration 异常时停止:

input_fn: Input function returning features which is a dictionary of string feature name to Tensor or SparseTensor. If it returns a tuple, first item is extracted as features. Prediction continues until input_fn raises an end-of-input exception (OutOfRangeError or StopIteration).

这是我代码中的相关部分:

classifier = tf.estimator.Estimator(model_fn=image_classifier, model_dir=output_dir,
config=training_config, params=hparams)

def make_predict_input_fn(filename):
queue = [ filename ]
def _input_fn():
if len(queue) == 0:
raise StopIteration
image = model.read_and_preprocess(queue.pop())
return {'image': image}
return _input_fn

predictions = classifier.predict(make_predict_input_fn('garden-rose-red-pink-56866.jpeg'))
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["class"]))

我错过了什么?

最佳答案

那是因为 input_fn() 需要是一个生成器。将您的函数更改为(yield 而不是 return):

def make_predict_input_fn(filename):
queue = [ filename ]
def _input_fn():
if len(queue) == 0:
raise StopIteration
image = model.read_and_preprocess(queue.pop())
yield {'image': image}
return _input_fn

关于python - 估计器预测无限循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47856852/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com