gpt4 book ai didi

python - 保存为 tensorflow 图的 Keras 顺序模型缺少火车操作?

转载 作者:行者123 更新时间:2023-12-04 15:37:33 24 4
gpt4 key购买 nike

我尝试在 keras 和 tensorflow 中制作简单的模型,然后将它们保存到 pb 文件中。运行以下命令时,我注意到 tensorflow 示例有一个训练操作,但 keras 示例没有。 问题:有没有一种方法可以在从 keras 模型创建的 tensorflow 图中找到训练操作,或者确保添加了一个?

tf.get_default_graph().get_operations()

tensorflow 示例

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')

y_ = tf.identity(tf.layers.dense(x, 1), name='output')

loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')

init = tf.global_variables_initializer()
saver_def = tf.train.Saver().as_saver_def()
tf.get_default_graph().get_operations()

Keras 示例

import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential()
model.add(layers.GRU(20, input_shape=(10, 1), return_sequences=True, name='input'))
model.add(layers.Dense(1, activation='linear'))
model.compile(loss="mse", optimizer="adam")
model.summary()

init = tf.global_variables_initializer()
saver_def = tf.train.Saver().as_saver_def()
tf.get_default_graph().get_operations()

编辑

感谢 Daniel Möller,在运行健身训练后添加到图表中。但是,它的命名不如似乎总是使用“train”的 tensorflow 模型好听。我发现我的 keras 模型的名称“training/group_deps”已保存到 tensorflow 图中。

如果训练名称和目标名称可以像输入和输出一样容易找到,那就太好了,可以通过以下方式找到:

model.input.name
model.output.name

但是我的问题好像解决了,但是每次都需要去挖掘图表文件。因此,如果有人知道更简单的方法,我们将不胜感激。目标是使用 tensorflows C API 运行网络。

编辑2

我在 tesorflow 中找到了 summarize_graph 工具。但是最初构建它的尝试因 Windows 上的边框而失败。目前其他事情是当务之急,所以我没有做进一步的研究。

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs

最佳答案

为了社区的利益提及以下问题的解决方案(目前已实现)。

Keras Sequential model saved as tensorflow graph is missing train operation?

运行命令,model.fit 在使用命令保存模型之前,

saver_def = tf.train.Saver().as_saver_def()

Keras 序列模型 的图形中包含训练操作。

关于python - 保存为 tensorflow 图的 Keras 顺序模型缺少火车操作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59259645/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com