gpt4 book ai didi

python - 如何循环打印模型的所有 tf​​.Tensors?

转载 作者:太空宇宙 更新时间:2023-11-03 14:46:56 24 4
gpt4 key购买 nike

我的模型中有 float32 类型的所有张量的列表:

import os                                                                                                     
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy
numpy.set_printoptions(threshold=numpy.nan)


with tf.Session() as sess:
model_filename = 'MY_pb_file.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_= tf.import_graph_def(graph_def,name='')
from pprint import pprint
pprint([out for op in tf.get_default_graph().get_operations() if op.type != 'Placeholder' for out in op.values() if out.dtype == tf.float32])

这给了我所有的列表:

<tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/add:0' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/Rsqrt:0' shape=(16,) dtype=float32>,
<tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/mul:0' shape=(16,) dtype=float32>,
<tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/mul_1:0' shape=(?, 64, 64, 16) dtype=float32>,
<tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/mul_2:0' shape=(16,) dtype=float32>,
<tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/sub:0' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/add_1:0' shape=(?, 64, 64, 16) dtype=float32>,
...

此时我可以使用sess.run('NAME')来查看它的值:

>>> sess.run('MobilenetV1/MobilenetV1/Conv2d_1_pointwise/BatchNorm/batchnorm/mul:0')
array([ 0.51656026, 29.6620369 , 0.48722425, 7.73186255,
-9.51173401, 0.60846734, 0.21111809, 0.23865609,
23.85105324, 1.04134226, 28.59620476, 35.79195023,
0.34110394, 0.5557093 , 10.39805031, 10.99952412], dtype=float32)

但是,我想在循环中打印所有 tf​​.Tensor 值。我怎样才能做到这一点?

显然,有些需要定义字典:

sess.run('MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6:0')

例如:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'input' with dtype float and shape [?,128,128,3]
[[Node: input = Placeholder[dtype=DT_FLOAT, shape=[?,128,128,3], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

最佳答案

当您请求张量的值时,Tensorflow 会计算图中的该值,因为张量值通常不会在对 sess.run 的不同调用中保留(这就是变量的用途)。计算您请求的张量值所需的操作似乎需要来自输入占位符之一(在错误语句中名为 input)的输入,因此您必须通过sess.run 中的 feed 字典。

根据您的评论,请考虑以下示例:

import tensorflow as tf

a = tf.constant(4)
b = tf.constant(3)
c = tf.placeholder(tf.int32, [], 'c')

d = a + b
e = a + c

请求张量d工作正常:

with tf.Session() as sess:
print(sess.run(d)) # prints 7

但是,请求 e 会引发与您报告的相同错误:

with tf.Session() as sess:
print(sess.run(e))

打印内容

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'c' with dtype int32
[[Node: c = Placeholder[dtype=DT_INT32, shape=[], _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

发生这种情况是因为为了计算e,我们必须计算a + c,并且如果我们不向c提供值,这不可能。因此,例如,这有效:

with tf.Session() as sess:
print(sess.run(e, feed_dict={c: 1})) # prints 5

评估d工作得很好,因为评估d所需的计算路径不涉及占位符。因此,要解决您的问题,您应该在调用 sess.run('MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6:0')< 时为名为 'input' 的占位符提供一个值.

关于python - 如何循环打印模型的所有 tf​​.Tensors?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46165553/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com