gpt4 book ai didi

python - 保存使用 keras 训练的 TF 模型,然后在 Go 中进行评估

转载 作者:IT王子 更新时间:2023-10-29 00:43:29 28 4
gpt4 key购买 nike

我正在尝试使用 keras 设置一个经典的 MNIST 挑战模型,然后保存 tensorflow 图并随后将其加载到 Go 中,然后用一些输入进行评估。我一直在关注this articlegithub 上提供完整代码. Nils 仅使用 tensorflow 来设置 comp.graph,但我想使用 keras。我设法像他一样保存模型

型号:

   model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
activation='relu',
input_shape=(28,28,1), name="inputNode"))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax', name="inferNode"))

运行正常,训练和评估然后保存如上:

builder = tf.saved_model.builder.SavedModelBuilder("mnistmodel_my")
# GOLANG note that we must tag our model so that we can retrieve it at inference-time
builder.add_meta_graph_and_variables(sess, ["serve"])
builder.save()

然后我尝试将其评估为:

result, runErr := model.Session.Run(
map[tf.Output]*tf.Tensor{
model.Graph.Operation("inputNode").Output(0): tensor,
},
[]tf.Output{
model.Graph.Operation("inferNode").Output(0),
},
nil,
)

在 Go 中,我遵循示例,但在评估时,我得到:

    panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.

goroutine 1 [running]:
github.com/tensorflow/tensorflow/tensorflow/go.Output.c(0x0, 0x0, 0x0, 0x0)
/Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:119 +0xbb
github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc4200723c8)
/Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:307 +0x22d
github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc420078060, 0xc42006e210, 0xc420047ef0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
/Users/air/go/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:85 +0x153
main.main()
/Users/air/PycharmProjects/GoTensor/custom.go:36 +0x341
exit status 2

因为它说 nil-Operation 我想我可能错误地标记了节点。但是我不知道我应该标记哪些其他节点?

非常感谢!!!

最佳答案

您的代码应该可以正常工作。关于零操作的原因,您是对的。

您只需找到“inputNode”的完整节点名称即可。

在 python 中,在模型定义之后,您可以遍历图形节点并查找完整名称,方法如下:

for n in sess.graph.as_graph_def().node:
if "inputNode" in n.name:
print(n.name)

一旦你得到了完整的名字,你就可以在你的 Go 程序中使用它了。

此外,我建议您使用更完整且易于使用的 tensorflow API 包装器:tfgo

关于python - 保存使用 keras 训练的 TF 模型,然后在 Go 中进行评估,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46366954/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com