gpt4 book ai didi

python - 如何从检查点使用 tf.estimator.Estimator 进行预测?

转载 作者:太空狗 更新时间:2023-10-29 20:57:43 24 4
gpt4 key购买 nike

我刚刚用 tensorflow 训练了一个 CNN 来识别太阳黑子。我的模型与 this 几乎相同.问题是我无法在任何地方找到关于如何使用训练阶段生成的检查点进行预测的明确解释。

尝试使用标准恢复方法:

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
saver.restore(sess,'./model/model.ckpt')

但后来我不知道如何运行它。
尝试使用 tf.estimator.Estimator.predict()像这样:

# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir="./model")

# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=50)

# predict with the model and print results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": pred_data},
shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)

但它所做的是吐出 <generator object Estimator.predict at 0x10dda6bf8> .如果我使用相同的代码但使用 tf.estimator.Estimator.evaluate()它就像一个魅力(重新加载模型,执行评估并将其发送到 TensorBoard)。

我知道有很多类似的问题,但我找不到真正适合我的方法。

最佳答案

sunspot_classifier.predict(input_fn=pred_input_fn) 返回生成器。所以 pred_results 是生成器对象。要从中获取值(value),您需要通过 next(pred_results)

对其进行迭代

解决方法是打印(下一个(pred_results))

关于python - 如何从检查点使用 tf.estimator.Estimator 进行预测?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47911104/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com