gpt4 book ai didi

pytorch - KeyError : 'unexpected key "module.编码器.embedding.weight“在state_dict”中

转载 作者:行者123 更新时间:2023-12-01 22:26:20 24 4
gpt4 key购买 nike

我在尝试加载已保存的模型时收到以下错误。

KeyError:“state_dict 中出现意外的键“module.encoder.embedding.weight””

这是我用来加载已保存模型的函数。

def load_model_states(model, tag):
"""Load a previously saved model states."""
filename = os.path.join(args.save_path, tag)
with open(filename, 'rb') as f:
model.load_state_dict(torch.load(f))

该模型是一个序列到序列网络,其初始化函数(构造函数)如下所示。

def __init__(self, dictionary, embedding_index, max_sent_length, args):
""""Constructor of the class."""
super(Sequence2Sequence, self).__init__()
self.dictionary = dictionary
self.embedding_index = embedding_index
self.config = args
self.encoder = Encoder(len(self.dictionary), self.config)
self.decoder = AttentionDecoder(len(self.dictionary), max_sent_length, self.config)
self.criterion = nn.NLLLoss() # Negative log-likelihood loss

# Initializing the weight parameters for the embedding layer in the encoder.
self.encoder.init_embedding_weights(self.dictionary, self.embedding_index, self.config.emsize)

当我打印模型(序列到序列网络)时,我得到以下结果。

Sequence2Sequence (
(encoder): Encoder (
(drop): Dropout (p = 0.25)
(embedding): Embedding(43723, 300)
(rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
)
(decoder): AttentionDecoder (
(embedding): Embedding(43723, 300)
(attn): Linear (600 -> 12)
(attn_combine): Linear (600 -> 300)
(drop): Dropout (p = 0.25)
(out): Linear (300 -> 43723)
(rnn): LSTM(300, 300, batch_first=True, dropout=0.25)
)
(criterion): NLLLoss (
)
)

因此,module.encoder.embedding是一个嵌入层,module.encoder.embedding.weight表示相关的权重矩阵。那么,为什么它说 - state_dict 中出现意外的键“module.encoder.embedding.weight”

最佳答案

我解决了这个问题。实际上,我使用 nn.DataParallel 保存模型,它将模型存储在模块中,然后我尝试在不使用 DataParallel 的情况下加载它。因此,要么我需要在网络中临时添加一个 nn.DataParallel 以便加载,要么我可以加载权重文件,创建一个不带模块前缀的新有序字典,然后将其加载回来。

第二个解决方法如下所示。

# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)

引用:https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686

关于pytorch - KeyError : 'unexpected key "module.编码器.embedding.weight“在state_dict”中,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44230907/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com