gpt4 book ai didi

python - 根据新数据点更新预先训练的深度学习模型

转载 作者:行者123 更新时间:2023-12-01 08:38:05 25 4
gpt4 key购买 nike

考虑 ImageNet 上的图像分类示例,如何使用新数据点更新预训练模型。我已经加载了预训练的模型。我有一个新数据点,与之前训练模型的原始数据的分布完全不同。因此,我想借助新数据点来更新/微调模型。该如何去做呢?有人可以帮我做吗?我使用pytorch 0.4.0进行实现,在GPU Tesla K40C上运行。

最佳答案

如果您不想更改分类器的输出(即类的数量),那么您可以简单地使用新的示例图像继续训练模型,假设它们被 reshape 为与预训练模型相同的形状接受。

另一方面,如果您想更改预训练模型中的类数,则可以用新层替换最后一个全连接层,并在新样本上仅训练该特定层。以下是此案例的示例代码,来自 PyTorch's autograd mechanics notes :

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)

# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

关于python - 根据新数据点更新预先训练的深度学习模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53624766/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com