gpt4 book ai didi

tensorflow - 何时在 tensorflow 中使用 model.predict(x) 与 model(x)

转载 作者:行者123 更新时间:2023-12-04 14:16:02 25 4
gpt4 key购买 nike

我有一个使用 tf.keras.models.load_model 加载的 keras.models.Model

现在有两个选项可以使用这个模型。我可以调用 model.predict(x) 或调用 model(x).numpy()。这两个选项都给我相同的结果,但 model.predict(x) 的运行时间要长 10 倍以上。

source code中的评论状态:

Computation is done in batches. This method is designed for performance in large scale inputs. For small amount of inputs that fit in one batch, directly using __call__ is recommended for faster execution, e.g., model(x), or model(x, training=False)

我已经用包含 1 的 x 进行了测试; 1,000,000;和 10,000,000 行,model(x) 仍然表现更好。

输入需要多大才能被归类为大规模输入,并且 model.predict(x) 才能表现得更好?

最佳答案

现有的堆栈溢出答案可能对您有用:https://stackoverflow.com/a/58385156/5666087 .我在 tensorflow/tensorflow#33340 上找到了它.该答案建议将 experimental_run_tf_function=False 传递到 model.compile 调用中以恢复到模型执行的 TF 1.x 版本。您也可以完全省略 model.compile 调用(对于预测而言它不是必需的)。

How large does the input need to be to be classified as a large scale input, and for the model.predict(x) to perform better?

这是您可以测试的内容。正如文档所述,如果您的数据适合一个批处理,model(x) 可能会比 model.predict(x) 更快。 model.predict(x) 优于 model(x) 的一件事是能够预测多个批处理。如果您想使用 model(x) 对多个批处理进行预测,则必须自己编写循环。 model.predict还提供其他功能,例如回调。

仅供引用,源代码中的文档已添加到提交 42f469be0f3e8c36624f0b01c571e7ed15f75faf 中, 由于 tensorflow/tensorflow#33340 .

model.predict(x) 的主要行为已实现 here .它不仅仅包含模型的前向传递。这可能是一些速度差异的原因。

I've tested with x containing 1; 1,000,000; and 10,000,000 rows and model(x) still performs better.

这 10,000,000 行是否适合一批......?

关于tensorflow - 何时在 tensorflow 中使用 model.predict(x) 与 model(x),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60159714/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com