gpt4 book ai didi

python - Tensorflow - 停止恢复网络参数

转载 作者:行者123 更新时间:2023-11-30 09:19:23 26 4
gpt4 key购买 nike

我正在尝试从 tensorflow 网络进行多个顺序预测,但即使对于 CPU,性能似乎也很差(2 层 8x8 卷积网络的每次预测约为 500 毫秒)。我怀疑问题的一部分在于它似乎每次都重新加载网络参数。下面代码中对 classifier.predict 的每次调用都会产生以下输出行 - 因此我会看到数百次。

信息:tensorflow:从/tmp/model_data/model.ckpt-102001 恢复参数

如何重用已加载的检查点?

(我无法在这里进行批量预测,因为网络的输出是游戏中的一步棋,然后需要将其应用于当前状态,然后再输入新的游戏状态。)

这是进行预测的循环。

def rollout(classifier, state):
while not state.terminated:
predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": state.as_nn_input()}, shuffle=False)
prediction = next(classifier.predict(input_fn=predict_input_fn))
index = np.random.choice(NUM_ACTIONS, p=prediction["probabilities"]) # Select a move according to the network's output probabilities
state.apply_move(index)

classifier 是一个 tf.estimator.Estimator 创建的...

classifier = tf.estimator.Estimator(
model_fn=cnn_model_fn, model_dir=os.path.join(tempfile.gettempdir(), 'model_data'))

最佳答案

Estimator API是一个高级 API。

The tf.estimator framework makes it easy to construct and train machine learning models via its high-level Estimator API. Estimator offers classes you can instantiate to quickly configure common model types such as regressors and classifiers.

Estimator API 抽象了 TensorFlow 的大量复杂性,但在此过程中失去了一些通用性。阅读代码后,很明显,如果不每次都重新加载模型,就无法运行多个顺序预测。低级 TensorFlow API 允许这种行为。但是...

Keras是支持此用例的高级框架。简单define the model然后调用predict反复。

def rollout(model, state):
while not state.terminated:
predictions = model.predict(state.as_nn_input())
for _, prediction in enumerate(predictions):
index = np.random.choice(bt.ACTIONS, p=prediction)
state.apply_mode(index)

不科学的基准测试表明,这大约快了 100 倍。

关于python - Tensorflow - 停止恢复网络参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45804879/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com