
python - Saving and loading VGG16 with transfer learning in PyTorch

Reposted. Author: 行者123. Updated: 2023-12-01 09:03:30

I save a VGG16 model trained with transfer learning using the following statement:

torch.save(model.state_dict(), 'checkpoint.pth')

and reload it with these statements:

state_dict = torch.load('checkpoint.pth')
model.load_state_dict(state_dict)
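To see why the receiving model must have a matching architecture, it helps to look at what a `state_dict` contains: only a mapping from parameter names to tensors, with no record of the layer classes. The sketch below is illustrative only. It uses a small hypothetical `TinyNet` (not part of the original question) and an in-memory buffer in place of `checkpoint.pth`.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model, used only for illustration."""

    def __init__(self, hidden=8):
        super().__init__()
        self.fc1 = nn.Linear(4, hidden)
        self.fc2 = nn.Linear(hidden, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()

# Save only the parameters; an in-memory buffer stands in for 'checkpoint.pth'.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# The checkpoint is a name -> tensor mapping; it does not describe the layers.
state_dict = torch.load(buffer)
saved_keys = list(state_dict.keys())

# Loading works when the new instance has the same layer names and shapes.
same_arch = TinyNet()
same_arch.load_state_dict(state_dict)

# A model whose layers have different shapes rejects the same checkpoint.
try:
    TinyNet(hidden=16).load_state_dict(state_dict)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True
```

Because the checkpoint holds tensors only, the code that builds the model has to run before `load_state_dict` can be called.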

Reloading works as long as I first recreate the VGG16 model and give it the same setup as before, using the following code:

import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models

model = models.vgg16(pretrained=True)
model.cuda()
# Freeze the pretrained parameters
for param in model.parameters():
    param.requires_grad = False

class Network(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):

        # input_size: integer, size of the input
        # output_size: integer, size of the output layer
        # hidden_layers: list of integers, the sizes of the hidden layers
        # drop_p: float between 0 and 1, dropout probability

        super().__init__()
        # Add the first layer, input to a hidden layer
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])

        # Add a variable number of more hidden layers
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
        self.output = nn.Linear(hidden_layers[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        ''' Forward pass through the network, returns the output logits '''

        # Forward through each layer in `hidden_layers`, with ReLU activation and dropout
        for linear in self.hidden_layers:
            x = F.relu(linear(x))
            x = self.dropout(x)

        x = self.output(x)
        return F.log_softmax(x, dim=1)

classifier = Network(25088, 102, [4096], drop_p=0.5)
model.classifier = classifier

How can I avoid this? How can I reload the model without having to reload VGG16 and redefine the classifier?
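One common way to reduce the repeated setup is to store the classifier's constructor arguments next to the weights, so the reload step can rebuild the head from the checkpoint itself. The following is a minimal sketch under assumptions: `Head`, `save_checkpoint`, `load_checkpoint`, and the dictionary keys are hypothetical names, and a small linear head stands in for the real classifier. The class definition still has to be importable at load time.

```python
import io

import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical small classifier head standing in for `Network`."""

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.hidden = nn.Linear(input_size, hidden_size)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        return self.output(torch.relu(self.hidden(x)))


def save_checkpoint(head, head_kwargs, file):
    # Store the constructor arguments together with the weights.
    torch.save({"head_kwargs": head_kwargs, "state_dict": head.state_dict()}, file)


def load_checkpoint(file):
    checkpoint = torch.load(file)
    # Rebuild the head from the saved arguments, then restore its weights.
    head = Head(**checkpoint["head_kwargs"])
    head.load_state_dict(checkpoint["state_dict"])
    return head


kwargs = {"input_size": 16, "output_size": 3, "hidden_size": 8}
original = Head(**kwargs)

buffer = io.BytesIO()  # stands in for a file such as 'checkpoint.pth'
save_checkpoint(original, kwargs, buffer)
buffer.seek(0)

restored = load_checkpoint(buffer)
```

With this layout, the loading script no longer needs to repeat the layer sizes by hand. It only needs the class definition and the checkpoint file.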

Best answer

Why not redefine a VGG16-like model directly? See vgg.py (in torchvision) for details.

from torch import nn

class VGG_New(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG_New, self).__init__()
        self.features = features
        # replace this classifier with your own code
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # flatten the feature maps before the fully connected layers
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

Then load the weights for the feature layers only:

# vgg_weight is assumed to be the path to the pretrained VGG16 weights,
# and vgg_new an instance of VGG_New
pretrained_dict = torch.load(vgg_weight)
model_dict = vgg_new.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# or filter by key name, dropping the classifier weights
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k.find('classifier') == -1}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
vgg_new.load_state_dict(model_dict)
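The filtering step can be tried end to end on a toy model. In this sketch, the hypothetical `ToyVGG` class reuses VGG's `features` / `classifier` naming. `source` plays the role of the pretrained network and `target` the redefined one with a different number of classes. The extra shape check is an addition for illustration and is not part of the original answer.

```python
import torch
import torch.nn as nn


class ToyVGG(nn.Module):
    """Hypothetical toy network using VGG-style submodule names."""

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.ReLU(True),
        )
        self.classifier = nn.Sequential(nn.Linear(4 * 8 * 8, num_classes))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)


source = ToyVGG(num_classes=10)  # plays the role of the pretrained network
target = ToyVGG(num_classes=3)   # redefined model with a new classifier

pretrained_dict = source.state_dict()
model_dict = target.state_dict()

# Keep only entries whose name and shape both match the target model.
# Checking the name alone is not enough here: 'classifier.0.weight' exists
# in both models but has a different shape.
filtered = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# Overwrite the matching entries, then load the merged state dict.
model_dict.update(filtered)
target.load_state_dict(model_dict)

loaded_keys = sorted(filtered)
```

After loading, `target.features` carries the weights from `source`, while `target.classifier` keeps its freshly initialized values.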

Regarding python - Saving and loading VGG16 with transfer learning in PyTorch, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/52268048/
