gpt4 book ai didi

python - 将元数据添加到 tensorflow 卡住图 pb

转载 作者:行者123 更新时间:2023-12-03 16:51:20 25 4
gpt4 key购买 nike

为了分享我们训练好的 tensorflow 网络,我们将图卡住为 .pb文件。我们还创建了一个 xml 文件,其中包含一些元数据,例如输入张量和输出张量、要应用的预处理类型、训练数据信息等。然后通过加载图形和评估张量等使用 Java 或 C# 提供模型。

为了使共享更容易,我想将此 xml 数据包含在 .pb 中的某处。文件。有没有办法做到这一点?一个想法是将它作为 tf.Constant,但我不知道如何将它连接到普通图。

请注意,这是使用 freeze_graph.py .新的 SavedModel 格式是否更合适?

最佳答案

首先,是的,您应该使用新的 SavedModel 格式,因为它会得到 TF 团队的支持,并且也适用于 Keras。您可以向模型添加一个额外的端点,它返回一个带有 XML 数据字符串的常量张量(如您所述)。
这很好,因为它是密封的——底层的 savemodel 格式并不重要,因为您的元数据保存在计算图本身中。
看这个问题的答案:Saving a TF2 keras model with custom signature defs .对于 Keras,该答案并不能 100% 为您提供帮助,因为它无法与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在 tf.Module 中.幸运的是,使用 tf.keras.Model如果添加 tf.function 装饰器,则在 TF2 中也能正常工作:

class MyModel(tf.keras.Model):

def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)

def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)

@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
然后您可以按如下方式保存和加载您的模型:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
最后使用 model_loaded.get_metadata()检索您的常量元数据张量。

关于python - 将元数据添加到 tensorflow 卡住图 pb,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54642590/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com