
python - Freezing and exporting a TensorFlow model in TensorFlow 2.0


I am trying to migrate existing code written in TensorFlow 1.13 (using Estimators) to TensorFlow 2.0, but I am having trouble finding the equivalent API to freeze the graph and export it as a .pb file.

In TensorFlow 1.13, the Estimator class has a function export_savedmodel that takes a model path and a serving_input_receiver_fn. I am having trouble setting up the serving_input_receiver_fn because it seems to require placeholders. When migrating to TensorFlow 2.0, although the same API still exists, placeholders no longer work because eager execution is now enabled by default.

    def export(self):
        self.configure()
        a_shape = (None, None, None, self.IMG_CHANNELS)
        b_shape = tf.TensorShape((None, None, self.IMU_DATA_DIM))
        a = tf.compat.v1.placeholder(tf.float32, a_shape, name='a')
        b = tf.compat.v1.placeholder(tf.float32, b_shape, name='b')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'a': a,
            'b': b
        })
        return self.modelPath, input_fn

RuntimeError: tf.placeholder() is not compatible with eager execution.

So my question is: what is the correct way to freeze and export a model from existing checkpoint files so that it produces a .pb file?

Best Answer

Here is an example of tf.estimator.export.build_raw_serving_input_receiver_fn(). It can be pasted directly into a notebook running TF 2.x. Hope it helps.

import tensorflow as tf

checkpoint_dir = "/some/location/to/store/the_model"

input_column = tf.feature_column.numeric_column("x")
# Use a LinearClassifier, but this would also work with a custom Estimator
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

# Create a fake dataset with only one feature 'x' and an associated label
def input_fn():
    return tf.data.Dataset.from_tensor_slices(
        ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)

estimator.train(input_fn)

# Note that we must not call raw_input_fn ourselves: calling it would raise
# "tf.placeholder() is not compatible with eager execution."
# Instead, pass raw_input_fn directly to estimator.export_saved_model().

feature_to_tensor = {
    # Pass a dummy tensor: this is only used to infer the shape and dtype of
    # the placeholder that build_raw_serving_input_receiver_fn() will create.
    # Adjust it to match the shape of 'x'.
    'x': tf.constant(0.),
}
raw_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
    feature_to_tensor, default_batch_size=None)
export_dir = estimator.export_saved_model(checkpoint_dir, raw_input_fn).decode()
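If you additionally need a single frozen .pb with the variables folded in as constants (the TF 1.x freeze_graph style the question asks about), one option is to load the exported SavedModel and run its serving signature through convert_variables_to_constants_v2. The sketch below is not part of the original answer: it assumes TF 2.x, reuses export_dir from above, the output directory and file name are just placeholders, and note that convert_variables_to_constants_v2 is imported from a non-public TensorFlow module.

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2)

# Reload the SavedModel exported above and grab its 'predict' signature
loaded = tf.saved_model.load(export_dir)
concrete_fn = loaded.signatures["predict"]

# Fold the trained variables into the graph as constants ("freezing")
frozen_fn = convert_variables_to_constants_v2(concrete_fn)

# Serialize the frozen GraphDef to a single .pb file
tf.io.write_graph(frozen_fn.graph.as_graph_def(),
                  logdir="/tmp/frozen",    # hypothetical output directory
                  name="frozen_graph.pb",  # hypothetical file name
                  as_text=False)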

You can then inspect the exported model:

!saved_model_cli show --all --dir $export_dir

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['predict']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['all_class_ids'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 2)
        name: head/predictions/Tile:0
    outputs['all_classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 2)
        name: head/predictions/Tile_1:0
    outputs['class_ids'] tensor_info:
        dtype: DT_INT64
        shape: (-1, 1)
        name: head/predictions/ExpandDims:0
    outputs['classes'] tensor_info:
        dtype: DT_STRING
        shape: (-1, 1)
        name: head/predictions/str_classes:0
    outputs['logistic'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: head/predictions/logistic:0
    outputs['logits'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: linear/linear_model/linear/linear_model/linear/linear_model/weighted_sum:0
    outputs['probabilities'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: head/predictions/probabilities:0
  Method name is: tensorflow/serving/predict

The exported model can then be loaded and used for inference by another process:

import tensorflow as tf
imported = tf.saved_model.load(export_dir)
f = imported.signatures["predict"]
f(x=tf.constant([-2., 5., -3.]))

{'class_ids': <tf.Tensor: shape=(3, 1), dtype=int64, numpy=
 array([[1],
        [0],
        [1]], dtype=int64)>,
 'classes': <tf.Tensor: shape=(3, 1), dtype=string, numpy=
 array([[b'1'],
        [b'0'],
        [b'1']], dtype=object)>,
 'all_class_ids': <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
 array([[0, 1],
        [0, 1],
        [0, 1]])>,
 ...etc...
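If you went the frozen-graph route sketched earlier, the resulting .pb can be reloaded in TF 2.x by wrapping the GraphDef with tf.compat.v1.wrap_function. Again a sketch rather than part of the original answer: the path and the tensor names passed to prune() are assumptions (freezing can rename tensors), so check [n.name for n in graph_def.node] or frozen_fn.inputs/outputs before relying on them.

import tensorflow as tf

# Read back the serialized GraphDef (hypothetical path from the earlier sketch)
graph_def = tf.compat.v1.GraphDef()
with open("/tmp/frozen/frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

def wrap_frozen_graph(graph_def, inputs, outputs):
    # Import the GraphDef into a fresh graph and return a pruned callable
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
    wrapped = tf.compat.v1.wrap_function(_imports_graph_def, [])
    graph = wrapped.graph
    return wrapped.prune(
        tf.nest.map_structure(graph.as_graph_element, inputs),
        tf.nest.map_structure(graph.as_graph_element, outputs))

# Tensor names are assumptions based on the saved_model_cli output above
predict = wrap_frozen_graph(graph_def,
                            inputs="x:0",
                            outputs="head/predictions/probabilities:0")
predict(tf.constant([-2., 5., -3.]))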

Regarding python - Freezing and exporting a TensorFlow model in TensorFlow 2.0, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/57437270/
