gpt4 book ai didi

python - 是否可以用现有图形中的常量替换占位符?

转载 作者:太空狗 更新时间:2023-10-29 21:25:34 25 4
gpt4 key购买 nike

我有一张经过训练的模型的卡住图,它有一个 tf.placeholder,我总是向其提供相同的值。

我想知道是否可以用 tf.constant 代替它。如果它是某种方式 - 任何例子将不胜感激!

编辑:这是代码的样子,以帮助可视化问题

我正在使用(由其他人)预训练的模型来运行推理。该模型在本地存储为扩展名为 .pb 的卡住图形文件。

代码如下所示:

# load graph
graph = load_graph('frozen.pb')
session = tf.Session(graph=graph)

# Get input and output tensors
images_placeholder = graph.get_tensor_by_name("input:0")
output = graph.get_tensor_by_name("output:0")
phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

feed_dict = {images_placeholder: images, phase_train_placeholder: False}

result = session.run(output, feed_dict=feed_dict)

问题是出于我的目的,我总是输入 phase_train_placeholder: False,所以我想知道是否可以消除该占位符并将其替换为 tf.constant(False, dtype=bool, shape=[])

最佳答案

所以我没有设法找到任何合适的方法,而是设法以一种 hacky 的方式做到了,方法是重建图形定义并替换我需要替换的节点。灵感来自 this代码。

这是代码( super hacky,使用风险自负):

INPUT_GRAPH_DEF_FILE = 'path/to/file'
OUTPUT_GRAPH_DEF_FILE = 'another/one'

# Get NodeDef of a constant tensor we want to put in place of
# the placeholder.
# (There is probably a better way to do this)
example_graph = tf.Graph()
with tf.Session(graph=example_graph):
c = tf.constant(False, dtype=bool, shape=[], name='phase_train')
for node in example_graph.as_graph_def().node:
if node.name == 'phase_train':
c_def = node

# load our graph
graph = load_graph(INPUT_GRAPH_DEF_FILE)
graph_def = graph.as_graph_def()

# Create new graph, and rebuild it from original one
# replacing phase train node def with constant
new_graph_def = graph_pb2.GraphDef()
for node in graph_def.node:
if node.name == 'phase_train':
new_graph_def.node.extend([c_def])
else:
new_graph_def.node.extend([copy.deepcopy(node)])

# save new graph
with tf.gfile.GFile(OUTPUT_GRAPH_DEF_FILE, "wb") as f:
f.write(new_graph_def.SerializeToString())

关于python - 是否可以用现有图形中的常量替换占位符?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43332342/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com