gpt4 book ai didi

python - 如何加速 Keras model.predict?

转载 作者:行者123 更新时间:2023-12-01 23:17:18 24 4
gpt4 key购买 nike

我训练了一个 LSTM 模型并尝试对所有测试观察结果进行预测。但是,keras model.predict 需要永远计算所有预测。有没有办法加快这个过程?假设每个预测有两个特征 (x1 & x2)。每个特征的长度(x1 & x2)都是33。如[32,1,17,......,0]。我需要做出 1M 的预测。我的代码是

predictions = np.argmax(make.predict([x1, x2]), axis = -1)  

有什么办法可以加快速度吗?非常感谢

最佳答案

实际上,Keras 模型是执行、训练、再训练、微调和总结以及模型明智更改的主要架构,在进行预测和部署时,我们需要使用keras 模型的卡住推理图

我们还应该使用 TensorRT 进行卡住图优化 , OpenVINO以及许多其他模型优化技术。

我在这里添加了代码片段以从 Keras 模型转换图形

卡住图的链接

Convert Keras Model to TensorFlow frozen graph

Save, Load and Inference From TensorFlow 2.x Frozen Graph

看似冗长,实则方法

#卡住图表

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()


# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="./frozen_models",
name="frozen_graph.pb",
as_text=False)

#加载卡住图

   # Load frozen graph using TensorFlow 1.x functions
with tf.io.gfile.GFile("./frozen_models/frozen_graph.pb", "rb") as f:
graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(f.read())

# Wrap frozen graph to ConcreteFunctions
frozen_func = wrap_frozen_graph(graph_def=graph_def,
inputs=["x:0"],
outputs=["Identity:0"],
print_graph=True)

关于python - 如何加速 Keras model.predict?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68656751/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com