gpt4 book ai didi

tensorflow - 使用 tf.import_graph_def 附加新输入管道时如何避免图形重复?

转载 作者:行者123 更新时间:2023-12-04 15:49:39 26 4
gpt4 key购买 nike

我正在尝试为我在 tensorflow 中做的模型设置两个不同的管道。为此,我从 here 中获取了答案。和 here ,但每次我运行它并保存图形以在 tensorboard 中显示它,或打印图形中所有可用的节点时,它都显示原始模型已被复制,而不是将新输入附加到相应的节点。

这是一个最小的例子:

import tensorflow as tf

# Creates toy dataset with tf.data API
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset = dataset.batch(32)

# Input placeholder
x = tf.placeholder(tf.float32,shape=[None,10],name='x')

# Main model
with tf.variable_scope('model'):
y = tf.add(tf.constant(2.),x,name='y')
z = tf.add(tf.constant(2.),y,name='z')

# Session
sess = tf.Session()

# Iterator that will be the new input pipeline for training
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()

graph_def = tf.get_default_graph().as_graph_def()

# If uncommented, it creates an error
#tf.reset_default_graph()

# Create the input to the node y
x_ds = tf.import_graph_def(graph_def=graph_def,
input_map={'x:0':next_elem})

# Write to disk the graph
tf.summary.FileWriter('./',sess.graph)

# Print all the nodes names
for node in sess.graph_def.node:
print(node.name)

我希望只有一个 y 和 z 节点。然而,当显示图形的所有名称或使用 tensorboard 检查它时,有两个结构,原始结构和其他结构在“导入”命名空间中,数据集输入到 y。知道如何解决这个问题吗?或者这是预期的行为?

最佳答案

在阅读了其他一些问题后,我找到了问题的答案。 Here是关于如何连接来自不同图形的节点的精彩解释。

这里的关键是手动定义将创建每个操作的图形。下面以代码为例。

import numpy as np
import tensorflow as tf

### Main model with a placeholder as input

# Create a graph
g_1 = tf.Graph()

# Define everything inside it
with g_1.as_default():
# Input placeholder
x = tf.placeholder(tf.float64,shape=[None,2],name='x')
with tf.variable_scope('model'):
y = tf.add(tf.constant(2.,dtype=tf.float64),x,name='y')
z = tf.add(tf.constant(2.,dtype=tf.float64),y,name='z')

gdef_1 = g_1.as_graph_def()


### Change the input pipeline

# Create another graph
g_2 = tf.Graph()

# Define everything inside it
with g_2.as_default():
# Create a toy tf.dataset
dataset = tf.data.Dataset.from_tensor_slices(np.array([[1.,2],[3,4],[5,6]]))
dataset = dataset.batch(1)

# Iterator that will be the new input pipeline for training
iterator = dataset.make_initializable_iterator()
next_elem = iterator.get_next()
# Create an identical operation as next_elemebt with name so it can be
# manipulated later
next_elem = tf.identity(next_elem,name='next_elem')

# Create the new pipeline. Use next_elem as input instead of x
z, = tf.import_graph_def(gdef_1,
input_map={'x:0':next_elem},
return_elements=['model/z:0'],
name='') # Set name to '' so it conserves the same scope as the original

# Create session linked to g_1
sess_1 = tf.Session(graph=g_1)

# Create session linked to g_2
sess_2 = tf.Session(graph=g_2)

# Initialize the iterator
sess_2.run(iterator.initializer)

# Write the graph to disk
tf.summary.FileWriter('./',sess_2.graph)

# Testing placeholders
out = sess_1.run([y],feed_dict={x:np.array([[1.,2.]],dtype=np.float64)})
print(out)

# Testing tf.data
out = sess_2.run([z])
print(out)

现在,一切都应该在不同的图表中。

关于tensorflow - 使用 tf.import_graph_def 附加新输入管道时如何避免图形重复?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54374588/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com