gpt4 book ai didi

python - 在复杂模型上使用 Pytorch 进行修剪

转载 作者:行者123 更新时间:2023-12-05 07:04:11 24 4
gpt4 key购买 nike

所以我正在尝试使用 torch.nn.utils.prune.global_unstructured .

我在一个简单的模型上完成了它并且成功了。 model.cov2 或其他层都可以。我正在尝试在(嵌套的)模型上执行此操作?我得到的错误是:

AttributeError: 'CNN' object has no attribute 'conv1'

和其他错误。我尝试了所有方法来访问这个深度 cov1,但我做不到。

您可以在下面找到模型代码:

class CNN(nn.Module):
def __init__(self):
"""CNN Builder."""
super(CNN, self).__init__()

self.conv_layer = nn.Sequential(

# Conv Layer block 1
nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),

# Conv Layer block 2
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Dropout2d(p=0.05),

# Conv Layer block 3
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)


self.fc_layer = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(4096, 1024),
nn.ReLU(inplace=True),
nn.Linear(1024, 512),
nn.ReLU(inplace=True),
nn.Dropout(p=0.1),
nn.Linear(512, 100)
)


def forward(self, x):
"""Perform forward."""
# conv layers
x = self.conv_layer(x)
# flatten
x = x.view(x.size(0), -1)
# fc layer
x = self.fc_layer(x)
return x

如何在此模型上应用剪枝?

最佳答案

您的模块不是名称“conv1”或“conv2”,您可以使用 named_modules 生成器查看名称。从上面,你有一个'conv_stem',它可以被索引为 model.conv_stem[0] 来访问。您可以遍历模块来创建一个像这样的字典:

parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'), )

并将其传入。查看更多信息:https://colab.research.google.com/github/pytorch/tutorials/blob/gh-pages/_downloads/f40ae04715cdb214ecba048c12f8dddf/pruning_tutorial.ipynb#scrollTo=UVFjM079F0Oi

关于python - 在复杂模型上使用 Pytorch 进行修剪,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63001581/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com