gpt4 book ai didi

python - 如何在 Keras 中使用 LSTM 层恢复已保存的模型

转载 作者:太空宇宙 更新时间:2023-11-04 04:35:35 25 4
gpt4 key购买 nike

我正在关注 tutorial使用 LSTM 生成英文文本并使用莎士比亚的作品作为训练文件。这是我正在引用的模型-

model = Sequential()
model.add(LSTM(HIDDEN_DIM, input_shape=(None, VOCAB_SIZE), return_sequences=True))
model.add(Dropout(0.2))
for i in range(LAYER_NUM - 1):
model.add(LSTM(HIDDEN_DIM, return_sequences=True))
model.add(TimeDistributed(Dense(VOCAB_SIZE)))
model.add(Activation('softmax'))
model.compile(loss="categorical_crossentropy", optimizer="rmsprop")

经过 30 个时期的训练后,我使用 model.save('model.h5') 保存模型。至此,模型已经学会了基本格式,也学会了一些单词。但是,当我尝试使用 load_model('model.h5') 在新程序中加载模型并尝试生成一些文本时,它最终会预测完全随机的字母和符号。这让我认为模型权重没有正确恢复,因为我在仅存储模型权重时遇到了同样的问题。那么有没有其他方法可以存储和恢复具有 LSTM 层的训练模型?

作为引用,为了生成文本,该函数随机生成一个字符并将其输入模型以预测下一个字符。这是函数-

def generate_text(model, length):
ix = [np.random.randint(VOCAB_SIZE)]
y_char = [ix_to_char[ix[-1]]]
X = np.zeros((1, length, VOCAB_SIZE))
for i in range(length):
X[0, i, :][ix[-1]] = 1
print(ix_to_char[ix[-1]], end="")
ix = np.argmax(model.predict(X[:, :i+1, :])[0], 1)
y_char.append(ix_to_char[ix[-1]])
return ('').join(y_char)

编辑

训练代码片段-

for nbepoch in range(1, 11):
print('Epoch ', nbepoch)
model.fit(X, y, batch_size=64, verbose=1, epochs=1)
if nbepoch % 10 == 0:
model.model.save('checkpoint_{}_epoch_{}.h5'.format(512, nbepoch))
generate_text(model, 50)
print('\n\n\n')

其中 generate_text() 只是一个预测新字符的函数,从随机生成的字符开始。每训练 10 个 epoch 后,整个模型将保存为 .h5 文件。

加载模型的代码-

print('Loading Model')

model = load_model('checkpoint_512_epoch_10.h5')

print('Model loaded')

generate_text(model, 400)

就预测而言,文本生成通常是在训练时结构化的,模型会学习一些单词。然而,当加载保存的模型时,文本生成是完全随机的,就好像权重是随机重新初始化的。

最佳答案

经过一些挖掘,我终于发现我在字符和单热向量之间创建字典映射的方式是问题所在。我使用 char = list(set(data)) 函数获取文件中所有字符的列表,然后将字符的索引指定为该字符的“代码编号”。然而,显然 list(set(data)) 函数并不总是输出相同的列表,相反,对于 python 的每个“ session ”,顺序是随机的。所以我的字典映射过去常常在保存和加载模型之间发生变化,因为这发生在不同的脚本中。使用 char = sorted(list(set(data))) 可以消除这个问题。

关于python - 如何在 Keras 中使用 LSTM 层恢复已保存的模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51809132/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com