gpt4 book ai didi

python - keras访问预训练模型的层参数以卡住

转载 作者:行者123 更新时间:2023-11-30 09:06:06 25 4
gpt4 key购买 nike

我保存了一个具有多层的 LSTM。现在,我想加载它并微调最后一个 LSTM 层。如何定位该层并更改其参数?

训练和保存的简单模型示例:

model = Sequential()
# first layer #neurons
model.add(LSTM(100, return_sequences=True, input_shape=(X.shape[1],
X.shape[2])))
model.add(LSTM(50, return_sequences=True))
model.add(LSTM(25))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')

我可以加载并重新训练它,但我找不到一种方法来定位特定层并卡住所有其他层。

最佳答案

一个简单的解决方案是命名每个层,即

model.add(LSTM(50, return_sequences=True, name='2nd_lstm'))

然后,加载模型后,您可以迭代各层并卡住与名称条件匹配的层:

for layer in model.layers:
if layer.name == '2nd_lstm':
layer.trainable = False

然后您需要重新编译您的模型以使更改生效,然后您可以像往常一样恢复训练。

关于python - keras访问预训练模型的层参数以卡住,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51428696/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com