gpt4 book ai didi

python - 无法恢复类 TextVectorization 的图层 - 文本分类

转载 作者:行者123 更新时间:2023-12-04 08:30:21 26 4
gpt4 key购买 nike

系统信息谷歌实验室

当我运行官方 tensorflow 基本文本分类提供的示例时,在模型保存之前一切正常,但是当我加载模型时却出现此错误。

RuntimeError: Unable to restore a layer of class TextVectorization. Layers of class TextVectorization require that the class be provided to the model loading code, either by registering the class using @keras.utils.register_keras_serializable on the class def and including that file in your program, or by passing the class in a keras.utils.CustomObjectScope that wraps this load call.

预期行为:模型应成功加载并处理原始输入

https://colab.research.google.com/gist/amahendrakar/8b65a688dc87ce9ca07ffb0ce50b84c7/44199.ipynb#scrollTo=fEjmSrKIqiiM

示例链接:https://tensorflow.google.cn/tutorials/keras/text_classification

最佳答案

当我实现(和自定义)“基本文本分类”中的代码时,我也遇到了这个错误消息(RuntimeError:无法恢复类 TextVectorization 的层。[...]) "教程。

我没有在笔记本中运行代码,而是使用了两个脚本,一个用于构建、训练和保存模型,另一个用于加载模型并进行预测。 (因此,错误似乎并不局限于 Google Colab)。

这是我必须做的(参见 https://github.com/tensorflow/tensorflow/issues/45231):

首先,我在函数定义之前的第一个脚本中添加了这一行,然后再次构建、训练和保存模型:

@tf.keras.utils.register_keras_serializable()
def custom_standardization(input_data):
[...]

# Save model as SavedModel
export_model.save(model_path, save_format='tf')

其次,我还必须在第二个脚本中添加相同的行和整个函数定义,以确保在我重新启动 (!) ipython(我当前运行脚本的位置)并且只运行第二个脚本时它能正常工作:

@tf.keras.utils.register_keras_serializable()
def custom_standardization(input_data):
lowercase = tf.strings.lower(input_data)
stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
return tf.strings.regex_replace(stripped_html,
'[%s]' % re.escape(string.punctuation),
'')
[...]
# Load model
reloaded_model = tf.keras.models.load_model(model_path)
# Make predictions
predictions = reloaded_model.predict(examples)

注意:如果我在运行第一个脚本后没有重启 ipython 就运行第二个脚本,我会得到这个错误:

ValueError: Custom>custom_standardization has already been registered [...]

或者,您可以在构建模型时只使用矢量层中的默认标准化方法:

vectorize_layer = TextVectorization(
standardize="lower_and_strip_punctuation",
max_tokens=max_features,
output_mode='int',
output_sequence_length=sequence_length)

关于python - 无法恢复类 TextVectorization 的图层 - 文本分类,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65050132/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com