gpt4 book ai didi

machine-learning - 调整 CNN 输入大小的正确方法(例如 VGG)

转载 作者:行者123 更新时间:2023-12-04 14:25:15 25 4
gpt4 key购买 nike

我想在 128x128 大小的图像上训练 VGG。我不想将它们重新调整为 224x224 以节省 GPU 内存和训练时间。这样做的正确方法是什么?

最佳答案

最好的办法是保留卷积部分,替换掉全连接层。这样甚至可以为网络的卷积部分采用预训练权重。全连接层必须随机初始化。通过这种方式,可以微调具有较小输入大小的网络。

这里是一些pytorch代码

import torch
from torch.autograd import Variable
import torchvision
import torch.nn as nn

from torchvision.models.vgg import model_urls

VGG_TYPES = {'vgg11' : torchvision.models.vgg11,
'vgg11_bn' : torchvision.models.vgg11_bn,
'vgg13' : torchvision.models.vgg13,
'vgg13_bn' : torchvision.models.vgg13_bn,
'vgg16' : torchvision.models.vgg16,
'vgg16_bn' : torchvision.models.vgg16_bn,
'vgg19_bn' : torchvision.models.vgg19_bn,
'vgg19' : torchvision.models.vgg19}


class Custom_VGG(nn.Module):

def __init__(self,
ipt_size=(128, 128),
pretrained=True,
vgg_type='vgg19_bn',
num_classes=1000):
super(Custom_VGG, self).__init__()

# load convolutional part of vgg
assert vgg_type in VGG_TYPES, "Unknown vgg_type '{}'".format(vgg_type)
vgg_loader = VGG_TYPES[vgg_type]
vgg = vgg_loader(pretrained=pretrained)
self.features = vgg.features

# init fully connected part of vgg
test_ipt = Variable(torch.zeros(1,3,ipt_size[0],ipt_size[1]))
test_out = vgg.features(test_ipt)
self.n_features = test_out.size(1) * test_out.size(2) * test_out.size(3)
self.classifier = nn.Sequential(nn.Linear(self.n_features, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes)
)
self._init_classifier_weights()

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

def _init_classifier_weights(self):
for m in self.classifier:
if isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()

要创建一个 vgg 只需调用这个:

vgg = Custom_VGG(ipt_size=(128, 128), pretrained=True)

关于machine-learning - 调整 CNN 输入大小的正确方法(例如 VGG),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46963372/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com