gpt4 book ai didi

tensorflow - 如何在tensorflow中将输入更改为OP

转载 作者:行者123 更新时间:2023-12-04 01:38:16 24 4
gpt4 key购买 nike

我需要在构造后将输入更改为 tesor。以下面的简化示例:x (constant=42.0)s (x^2)x_new (constant=4.0)

我想将 s 的输入从 x 更改为 x_new。执行此操作后,我期望 s.eval() == 16.0

x = tf.constant(42.0, name='x')
s = tf.square(x, name='s')
x_new = tf.constant(4.0, name='x_new')

tf.get_default_graph().as_graph_def()

Out[6]:
node {
name: "x"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 42.0
}
}
}
}
node {
name: "s"
op: "Square"
input: "x"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "x_new"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 4.0
}
}
}
}
versions {
producer: 24
}

我尝试过使用 tf.contrib.graph_editor.reroute_inputs,但我终究无法弄清楚如何处理它返回的子图。

我也试过使用 tf.import_graph_def ,正如这个 git issue ( https://github.com/tensorflow/tensorflow/issues/1758 ) 中模糊描述的那样,但无论我尝试了多少种方法,我都没有得到s 将输入从 x 更改为 x_new

有人知道如何使用这两种方法中的任何一种来完成这个简单的示例吗?

最佳答案

所以这是我在尝试使用预训练网络时经常遇到的问题。正如您在问题中提到的,一种方法是“import_graph_def”,如 Github issue 中所述。 . Google 一直将我指向此处,但该问题缺少一个干净的示例,因此我将在此处发布一个最小的解决方案。

import tensorflow as tf

with tf.compat.v1.Session() as sess:
x = tf.constant(42.0, name="x")
s = tf.square(x, name="s")
print(sess.run(s))
scope = "test"
x_new = tf.constant(4.0, name="{}/x".format(scope))
tf.import_graph_def(tf.get_default_graph().as_graph_def(), name=scope, input_map={'x': x_new})
print(sess.run("{}/s:0".format(scope)))

请注意,如果您不提供范围,则根据 docs 默认为“导入” .

相反,如果您需要保留一个图表然后对其进行编辑(或从其他人那里加载一个保留的图表),您可以保存该图表并重新加载它(基于此 answer )

import tensorflow as tf

scope = "test"
graph_filename = "test.pb"

with tf.compat.v1.Session() as sess:
x = tf.constant(42.0, name="x")
s = tf.square(x, name="s")
print(sess.run(s))
with tf.gfile.GFile(graph_filename, 'wb') as outfile:
outfile.write(tf.get_default_graph().as_graph_def().SerializeToString())

with tf.compat.v1.Session() as sess:
x_new = tf.constant(4.0, name="{}/x".format(scope))
with tf.gfile.GFile(graph_filename, 'rb') as infile:
graph_def = tf.GraphDef()
bytes_read = graph_def.ParseFromString(infile.read())
tf.import_graph_def(graph_def, name=scope, input_map={'x': x_new})
print(sess.run("{}/s:0".format(scope)))

@Moshe 在评论中询问了有关使用 import_meta_graph 的问题,因此我将添加一个简单示例。如果您想对 import_meta_graph 进行替换(在我看来,这是最好的,因为您跳过所有导入/导出)然后使用方法 1。否则使用方法 2(导入/导出技巧)以编程方式添加输入。

导出示例(请注意,对于这样一个简单的示例,我们至少需要创建一个变量以避免 Saver 抛出异常):

    import tensorflow as tf
scope = "test"
model_name = 'my-model'
with tf.compat.v1.Session() as sess:
#Create a toy network
x = tf.constant(42.0, name="x")
s = tf.square(x, name="s")
print(sess.run(s))
#Save the network
w = tf.Variable(2.0, name="w") #Need at least one variable to make saver happy
sess.run(tf.global_variables_initializer()) #Need to initialize variables before saving
saver = tf.compat.v1.train.Saver()
tf.compat.v1.add_to_collection('s', s) #Select the network pieces to save
saver.save(sess, model_name)

导入示例(取消注释标有方法 2 的行以创建动态输入):

    import tensorflow as tf
scope = "test"
model_name = 'my-model'
with tf.compat.v1.Session() as sess:
x_new = tf.placeholder(tf.float32, name="{}/x".format(scope))
saver = tf.compat.v1.train.import_meta_graph('{}.meta'.format(model_name), import_scope=scope, input_map={'x': x_new}) #Method 1
# saver = tf.compat.v1.train.import_meta_graph('{}.meta'.format(model_name)) #Method 2
saver.restore(sess, model_name)
# tf.import_graph_def(tf.get_default_graph().as_graph_def(), name=scope, input_map={'x': x_new}) #Method 2
s = tf.compat.v1.get_collection('s')[0]
for i in range(4):
print(sess.run("{}/s:0".format(scope), {"{}/x:0".format(scope): i}))

关于tensorflow - 如何在tensorflow中将输入更改为OP,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47256160/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com