gpt4 book ai didi

python - 如何替换已保存图形的输入,例如数据集迭代器的占位符?

转载 作者:太空狗 更新时间:2023-10-30 00:53:08 25 4
gpt4 key购买 nike

我有一个已保存的 Tensorflow 图,它​​通过带有 feed_dict 参数的 placeholder 消耗输入。

sess.run(my_tensor, feed_dict={input_image: image})

因为用 Dataset Iterator 提供数据是 more efficient ,我想加载保存的图形,将 input_image placeholder 替换为 Iterator 并运行。我怎样才能做到这一点?有更好的方法吗?非常感谢带有代码示例的答案。

最佳答案

您可以通过序列化您的图形并使用 tf.import_graph_def 重新导入它来实现这一点,它有一个 input_map 参数用于在所需位置插入输入。

要做到这一点,您至少需要知道您替换的输入的名称和您希望执行的输出的名称(在我的示例中分别为 xy ).

import tensorflow as tf

# restore graph (built from scratch here for the example)
x = tf.placeholder(tf.int64, shape=(), name='x')
y = tf.square(x, name='y')

# just for display -- you don't need to create a Session for serialization
with tf.Session() as sess:
print("with placeholder:")
for i in range(10):
print(sess.run(y, {x: i}))

# serialize the graph
graph_def = tf.get_default_graph().as_graph_def()

tf.reset_default_graph()

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# plug in new pipeline
[y] = tf.import_graph_def(graph_def, input_map={'x:0': batch}, return_elements=['y:0'])

# enjoy Dataset inputs!
with tf.Session() as sess:
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass

请注意,占位符节点仍然存在,因为我没有费心在这里解析 graph_def 以将其删除——您可以将其删除作为改进,尽管我认为保留它也可以在这里。

根据您恢复图形的方式,输入替换可能已经内置在加载程序中,这使事情变得更简单(无需返回到 GraphDef)。例如,如果您从 .meta 文件加载图形,则可以使用接受相同 input_map 参数的 tf.train.import_meta_graph

import tensorflow as tf

# build new pipeline
batch = tf.data.Dataset.range(10).make_one_shot_iterator().get_next()
# load your net and plug in new pipeline
# you need to know the name of the tensor where to plug-in your input
restorer = tf.train.import_meta_graph(graph_filepath, input_map={'x:0': batch})
y = tf.get_default_graph().get_tensor_by_name('y:0')

# enjoy Dataset inputs!
with tf.Session() as sess:
# not needed here, but in practice you would also need to restore weights
# restorer.restore(sess, weights_filepath)
print('with Dataset:')
try:
while True:
print(sess.run(y))
except tf.errors.OutOfRangeError:
pass

关于python - 如何替换已保存图形的输入,例如数据集迭代器的占位符?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50364377/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com