gpt4 book ai didi

tensorflow2.0 - TensorFlow 2.0 : Save and load a model that contains a LSTM layer,,而负载通用失败并出现 ValueError

转载 作者:行者123 更新时间:2023-12-04 14:40:39 25 4
gpt4 key购买 nike

当我尝试保存和加载包含 LSTM 层的模型时,加载通用失败 ValueError: 找不到匹配的函数来调用从 SavedModel 加载 .

class RegNet(Model):
def __init__(self,
intermediate_dim=50,
state_dim=9,
name='RegNet',
**kwargs):
super(RegNet, self).__init__()
self.d1 = Dense(intermediate_dim, activation='relu')
self.d2 = Dense(state_dim, activation='relu')
self.h = LSTM(state_dim, activation='sigmoid', return_sequences=True)
self.o = Dense(state_dim, activation='softmax')

def call(self, x):
x = self.d1(x)
x = self.d2(x)
x = self.h(x)
y = self.o(x)
return y

regNet = RegNet()
...
# Export the model to a SavedModel
regNet.save(regNet_ckpt_dir, save_format='tf')
# Recreate the exact same model
tf.keras.models.load_model(regNet_ckpt_dir)

错误报告:
> ValueError: Could not find matching function to call loaded from the SavedModel. Got:
Positional arguments (2 total):
* Tensor("x:0", shape=(None, 1, 20), dtype=float32)
* Tensor("training:0", shape=(), dtype=bool)
Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
Positional arguments (2 total):
* TensorSpec(shape=(None, 1, 20), dtype=tf.float32, name='input_1')
* False
Keyword arguments: {}

Option 2:
Positional arguments (2 total):
* TensorSpec(shape=(None, 1, 20), dtype=tf.float32, name='x')
* False
Keyword arguments: {}

Option 3:
Positional arguments (2 total):
* TensorSpec(shape=(None, 1, 20), dtype=tf.float32, name='x')
* True
Keyword arguments: {}

Option 4:
Positional arguments (2 total):
* TensorSpec(shape=(None, 1, 20), dtype=tf.float32, name='input_1')
* True
Keyword arguments: {}

当我评论 LSTM 层时,加载命令将成功。问题出在哪儿?我们无法在 TensorFlow 2.0 中保存和加载带有 LSTM 层的模型?

最佳答案

万一其他人偶然发现这个,这个解决方案对我有用:

# Save model
tf.keras.models.save_model(model, "saved_model.hp5", save_format="h5")

# Load model
loaded_model = tf.keras.models.load_model("saved_model.hp5")

不确定为什么“model.save(filename)”语法不适用于 LSTM,但我遇到了同样的问题。

关于tensorflow2.0 - TensorFlow 2.0 : Save and load a model that contains a LSTM layer,,而负载通用失败并出现 ValueError,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58339137/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com