
python - Using two different models in TensorFlow

Reposted. Author: 太空宇宙. Updated: 2023-11-04 08:39:47

I am trying to use two different MobileNet models. Here is the code I use to initialize them:

import os
import timeit

import tensorflow as tf


def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        # Imported into the default graph with no name prefix
        tf.import_graph_def(age_graph_def, name='')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        # Also imported into the same default graph with no name prefix
        tf.import_graph_def(gender_graph_def, name='')

    print('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

Since these are two different models, how do I use them to make predictions?

Update

initialSetup()

age_session = tf.Session(graph=age_graph_def)
gender_session = tf.Session(graph=gender_graph_def)

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0')

    print('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()

Error

Traceback (most recent call last):
  File "C:/Users/Desktop/untitled/testimg/testimg/combo.py", line 48, in <module>
    age_session = tf.Session(graph=age_graph_def)
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1292, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 529, in __init__
    raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
TypeError: graph must be a tf.Graph, but got
Exception ignored in: >
Traceback (most recent call last):
  File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 587, in __del__
    if self._session is not None:
AttributeError: 'Session' object has no attribute '_session'

Best answer

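When you use multiple models in the same graph, use name scopes to give the individual tensors predictable names. For example, you could rewrite initialSetup() as follows: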

import os
import timeit

import tensorflow as tf


def initialSetup():
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    start_time = timeit.default_timer()

    # This takes 2-5 seconds to run
    # Unpersists graph from file
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f:
        age_graph_def = tf.GraphDef()
        age_graph_def.ParseFromString(f.read())
        # Every node of the age model gets the prefix "age_model/"
        tf.import_graph_def(age_graph_def, name='age_model')

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f:
        gender_graph_def = tf.GraphDef()
        gender_graph_def.ParseFromString(f.read())
        # Every node of the gender model gets the prefix "gender_model/"
        tf.import_graph_def(gender_graph_def, name='gender_model')

    print('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time))

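Now the names of all nodes from age_graph_def will be prefixed with "age_model/", and the names of all nodes from gender_graph_def will be prefixed with "gender_model/". Both models are part of the same default graph, so you can access either one through a single tf.Session created without a graph argument.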
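To see the prefixing in isolation, here is a minimal self-contained sketch. It assumes TensorFlow 1.14 or later (it uses the `tf.compat.v1` API, so it also runs under TensorFlow 2.x) and replaces the two `.pb` files with a hypothetical helper, `make_toy_graph_def`, that builds tiny stand-in models, each with an `input` placeholder and a `final_result` output. These toy models are illustrations only, not the MobileNet models from the question.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, available in TF 1.14+ and 2.x


def make_toy_graph_def(scale):
    """Hypothetical stand-in for a frozen model such as output_graph.pb.

    Builds a graph with an 'input' placeholder and a 'final_result' node
    that multiplies the input by `scale`, then returns its GraphDef.
    """
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="input")
        tf.multiply(x, scale, name="final_result")
    return g.as_graph_def()


age_graph_def = make_toy_graph_def(2.0)
gender_graph_def = make_toy_graph_def(10.0)

# Import both models into ONE graph; name= prefixes every imported node,
# so the identically named nodes of the two models no longer collide.
combined = tf.Graph()
with combined.as_default():
    tf1.import_graph_def(age_graph_def, name="age_model")
    tf1.import_graph_def(gender_graph_def, name="gender_model")

# A single session serves both models; the prefix selects which one runs.
with tf1.Session(graph=combined) as sess:
    feed = np.array([1.0, 2.0], dtype=np.float32)
    age_out = sess.run("age_model/final_result:0",
                       feed_dict={"age_model/input:0": feed})
    gender_out = sess.run("gender_model/final_result:0",
                          feed_dict={"gender_model/input:0": feed})

print(age_out, gender_out)
```

The only thing distinguishing the two models inside the shared graph is the `age_model/` or `gender_model/` prefix on the tensor names, which is why the fetch and feed names above carry it.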

initialSetup()

with tf.Session() as sess:
    start_time = timeit.default_timer()

    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0')

    # Alternatively, to get a tensor from the gender model:
    # tensor = sess.graph.get_tensor_by_name('gender_model/...')

    print('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time))

    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()
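The Update in the question tried a different route, one session per model, and failed because `tf.Session(graph=...)` expects a `tf.Graph` object while `age_graph_def` is a `GraphDef` protocol buffer. If you prefer separate sessions, a minimal sketch of that alternative imports each `GraphDef` into its own `tf.Graph` first. It assumes TensorFlow 1.14 or later (it uses the `tf.compat.v1` API, so it also runs under TensorFlow 2.x); `make_toy_graph_def` and `load_graph` are hypothetical helpers, and the toy models stand in for the real `.pb` files.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, available in TF 1.14+ and 2.x


def make_toy_graph_def(scale):
    """Hypothetical stand-in for a frozen model such as output_graph.pb."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="input")
        tf.multiply(x, scale, name="final_result")
    return g.as_graph_def()


def load_graph(graph_def):
    """Import a GraphDef into a fresh tf.Graph and return that Graph."""
    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names; no prefix is needed
        # because each model lives in its own graph.
        tf1.import_graph_def(graph_def, name="")
    return graph


age_graph = load_graph(make_toy_graph_def(2.0))
gender_graph = load_graph(make_toy_graph_def(10.0))

# Pass the tf.Graph objects, not the GraphDefs, to the sessions.
age_session = tf1.Session(graph=age_graph)
gender_session = tf1.Session(graph=gender_graph)

try:
    feed = np.array([3.0], dtype=np.float32)
    age_out = age_session.run("final_result:0", feed_dict={"input:0": feed})
    gender_out = gender_session.run("final_result:0", feed_dict={"input:0": feed})
finally:
    # Each session holds its own resources, so close both explicitly.
    age_session.close()
    gender_session.close()

print(age_out, gender_out)
```

With one graph per model the original, unprefixed names such as `final_result:0` keep working; the trade-off is managing (and closing) two sessions instead of one.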

Regarding "python - Using two different models in TensorFlow", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/45747769/
