gpt4 book ai didi

python - 如何在 Tensorflow SavedModel 中列出所有使用过的操作?

转载 作者:行者123 更新时间:2023-12-03 15:48:10 25 4
gpt4 key购买 nike

如果我使用 tensorflow.saved_model.save 保存我的模型SavedModel 格式的函数,之后如何检索此模型中使用了哪些 Tensorflow Ops。由于模型可以恢复,所以这些操作都存储在图中,我的猜测是在saved_model.pb中文件。如果我加载这个 protobuf(所以不是整个模型),protobuf 的库部分会列出这些,但目前没有记录并标记为实验功能。在 Tensorflow 1.x 中创建的模型将没有这部分。

那么从 SavedModel 格式的模型中检索已使用操作列表(如 MatchingFilesWriteFile )的快速可靠方法是什么?

现在我可以卡住整个东西,比如 tensorflowjs-converter 做。因为他们还检查支持的操作。当 LSTM 在模型中时,这当前不起作用,请参阅 here .有没有更好的方法来做到这一点,因为 Ops 肯定在那里?

示例模型:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
input_scalar = tf.reshape(file_name, [])
output = tf.io.read_file(input_scalar)
return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

预期输出所有操作,在这种情况下至少包含:
  • ReadFile如所述 here
  • ...
  • 最佳答案

    saved_model.pbSavedModel protobuf 消息,然后您可以直接从那里获得操作。假设我们创建一个模型如下:

    import tensorflow as tf

    class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

    file_reader = FileReader()
    tf.saved_model.save(file_reader, 'tmp')

    我们现在可以找到该模型使用的操作,如下所示:

    from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

    saved_model = SavedModel()
    with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
    model_op_names = set()
    # Iterate over every metagraph in case there is more than one
    for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
    # Add operations in each function
    model_op_names.update(node.op for node in func.node_def)
    # Convert to list, sorted if you want
    model_op_names = sorted(model_op_names)
    print(*model_op_names, sep='\n')
    # Const
    # Identity
    # MergeV2Checkpoints
    # NoOp
    # Pack
    # PartitionedCall
    # Placeholder
    # ReadFile
    # Reshape
    # RestoreV2
    # SaveV2
    # ShardedFilename
    # StatefulPartitionedCall
    # StringJoin

    关于python - 如何在 Tensorflow SavedModel 中列出所有使用过的操作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60154650/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com