gpt4 book ai didi

pytorch - 并行模拟 torch.nn.Sequential 容器

转载 作者:行者123 更新时间:2023-12-03 08:48:00 27 4
gpt4 key购买 nike

只是想知道,为什么我在 torch.nn 中找不到主题? nn.Sequential 非常方便,它允许在一个地方定义网络,清晰可见,但仅限于非常简单的网络!通过并行模拟(并且对剩余连接的“身份”节点几乎没有帮助),它形成了构建任何前馈网络组合方式的完整方法。我错过了什么吗?

最佳答案

好吧,也许它不应该出现在标准模块集合中,只是因为它的定义非常简单:

class ParallelModule(nn.Sequential):
def __init__(self, *args):
super(ParallelModule, self).__init__( *args )

def forward(self, input):
output = []
for module in self:
output.append( module(input) )
return torch.cat( output, dim=1 )

从“顺序”继承“并行”在意识形态上是不好的,但效果很好。现在可以使用以下代码定义如图所示的网络: Network image by torchviz:

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Conv2d( 1, 32, 3, padding=1 ), nn.ReLU(),
nn.Conv2d( 32, 64, 3, padding=1 ), nn.ReLU(),
nn.MaxPool2d( 3, stride=2 ), nn.Dropout2d( 0.25 ),

ParallelModule(
nn.Conv2d( 64, 64, 1 ),
nn.Sequential(
nn.Conv2d( 64, 64, 1 ), nn.ReLU(),
ParallelModule(
nn.Conv2d( 64, 32, (3,1), padding=(1,0) ),
nn.Conv2d( 64, 32, (1,3), padding=(0,1) ),
),
),
nn.Sequential(
nn.Conv2d( 64, 64, 1 ), nn.ReLU(),
nn.Conv2d( 64, 64, 3, padding=1 ), nn.ReLU(),
ParallelModule(
nn.Conv2d( 64, 32, (3,1), padding=(1,0) ),
nn.Conv2d( 64, 32, (1,3), padding=(0,1) ),
),
),
nn.Sequential(
#PrinterModule(),
nn.AvgPool2d( 3, stride=1, padding=1 ),
nn.Conv2d( 64, 64, 1 ),
),
),
nn.ReLU(),
nn.Conv2d( 256, 64, 1 ), nn.ReLU(),

nn.Conv2d( 64, 128, 3, padding=1 ), nn.ReLU(),
nn.MaxPool2d( 3, stride=2 ), nn.Dropout2d( 0.5 ),
nn.Flatten(),
nn.Linear( 4608, 128 ), nn.ReLU(),
nn.Linear( 128, 10 ), nn.LogSoftmax( dim=1 ),
)

def forward(self, x):
return self.net.forward( x )

关于pytorch - 并行模拟 torch.nn.Sequential 容器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60586559/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com