gpt4 book ai didi

python - 在 TensorFlow 中导入 GraphDef 后设置 py_func op

转载 作者:行者123 更新时间:2023-11-28 21:42:23 24 4
gpt4 key购买 nike

我有一个保存的图形定义,它是用 tf.train.import_meta_graph 导入的。该图包含不可序列化的 py_func 操作。我可以在不从头开始构建图形的情况下定义 python 函数并将其分配给此操作吗?

最佳答案

这是可能的,但可能有点脆弱。特别是,需要按照它们在原始图中定义的相同顺序重新定义 pyfunc(以便它们在 FuncRegistry 中具有相同的标识符)。

一个例子。我们可以定义一个包含 py_func 的图:

import tensorflow as tf

def my_py_func(x):
return 13. * x + 2.

def train_model():
with tf.Graph().as_default():
some_input = tf.constant([[1., 2., 3., 4.],
[5., 6., 7., 8.]])
after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
name="my_py_func")
coefficient = tf.get_variable(
"coefficient",
shape=[])
bias = tf.get_variable(
"bias",
shape=[])
loss = tf.reduce_sum((coefficient * some_input + bias - after_py_func) ** 2)
global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.group(tf.train.AdamOptimizer(0.1).minimize(loss),
tf.assign_add(global_step, 1))
# Make it easy to retreive things we care about when the metagraph is reloaded.
tf.add_to_collection('useful_ops', bias)
tf.add_to_collection('useful_ops', coefficient)
tf.add_to_collection('useful_ops', loss)
tf.add_to_collection('useful_ops', train_op)
tf.add_to_collection('useful_ops', global_step)
tf.add_to_collection('useful_ops', some_input)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as session:
session.run(init_op)
for i in range(5000):
(_, evaled_loss, evaled_coefficient, evaled_bias,
evaled_global_step) = session.run(
[train_op, loss, coefficient, bias, global_step])
if i % 1000 == 0:
print(evaled_global_step, evaled_loss, evaled_coefficient,
evaled_bias)
saver.save(session, "./trained_pyfunc_model", global_step=global_step)

这会进行一些基本训练(匹配在 py_func 中找到的线性函数):

1 37350.4 -0.0934748 0.193026
1001 19.2717 12.3749 5.40368
2001 0.108373 12.9532 2.2548
3001 8.28227e-06 12.9996 2.00222
4001 3.77258e-09 13.0 2.00004

然后,如果我们在新的 Python session 中尝试加载元图而不重新定义 pyfunc,则会出现错误:

def load_model():
with tf.Graph().as_default():
saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
#after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
# name="my_py_func")
with tf.Session() as session:
saver.restore(session, "./trained_pyfunc_model-5000")
(_, evaled_loss, evaled_coefficient, evaled_bias,
evaled_global_step) = session.run(
[train_op, loss, coefficient, bias, global_step])
print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)

UnknownError (see above for traceback): KeyError: 'pyfunc_0'

但是,只要 py_func 以相同的顺序定义并具有相同的实现,我们应该没问题:

def load_model():
with tf.Graph().as_default():
saver = tf.train.import_meta_graph("./trained_pyfunc_model-5000.meta")
bias, coefficient, loss, train_op, global_step, some_input = tf.get_collection('useful_ops')
after_py_func = tf.py_func(my_py_func, [some_input], Tout=tf.float32,
name="my_py_func")
with tf.Session() as session:
saver.restore(session, "./trained_pyfunc_model-5000")
(_, evaled_loss, evaled_coefficient, evaled_bias,
evaled_global_step) = session.run(
[train_op, loss, coefficient, bias, global_step])
print("Restored: ", evaled_global_step, evaled_loss, evaled_coefficient, evaled_bias)

这让我们可以继续训练,或者我们想对恢复的模型做任何其他事情:

Restored:  5001 1.77897e-09 13.0 2.00003

请注意,有状态的 py_funcs 将更难处理:TensorFlow 不会保存任何可能与其关联的 Python 变量!

关于python - 在 TensorFlow 中导入 GraphDef 后设置 py_func op,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43644506/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com