gpt4 book ai didi

python-3.x - 使用 PyTorch 实现一个简单的 ResNet block

转载 作者:行者123 更新时间:2023-12-04 00:58:01 31 4
gpt4 key购买 nike

我正在尝试实现以下 ResNet block ,该 ResNet 由具有两个卷积层和一个跳过连接的 block 组成。由于某种原因,它不会将跳过连接的输出(如果应用)或输入添加到卷积层的输出。

ResNet block 具有:

  • 两个卷积层:
  • 3x3 内核
  • 无偏见
  • 两侧填充一个像素
  • 每个卷积层后的 2d 批量归一化
  • 跳过连接:
  • 如果分辨率和 channel 数没有改变,只需复制输入。
  • 如果分辨率或 channel 数发生变化,则跳过连接应该有一个卷积层:
  • 1x1 无偏差卷积
  • 跨步更改分辨率(可选)
  • 不同数量的输入 channel 和输出 channel (可选)
  • 1x1 卷积层之后是 2d 批量归一化。
  • ReLU 非线性应用在第一个卷积层之后和 block 的末尾。

  • 我的代码:
    class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
    """
    Args:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    stride (int): Controls the stride.
    """
    super(Block, self).__init__()

    self.skip = nn.Sequential()

    if stride != 1 or in_channels != out_channels:
    self.skip = nn.Sequential(
    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
    nn.BatchNorm2d(out_channels))
    else:
    self.skip = None

    self.block = nn.Sequential(
    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.ReLU(),
    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
    nn.BatchNorm2d(out_channels))

    def forward(self, x):
    out = self.block(x)

    if self.skip is not None:
    out = self.skip(x)
    else:
    out = x

    out += x

    out = F.relu(out)
    return out

    最佳答案

    问题在于 out 的重用多变的。通常,你会像 this 这样实现:

    def forward(self, x):
    identity = x
    out = self.block(x)

    if self.skip is not None:
    identity = self.skip(x)

    out += identity
    out = F.relu(out)

    return out

    如果你喜欢“单线”:

    def forward(self, x):
    out = self.block(x)
    out += (x if self.skip is None else self.skip(x))
    out = F.relu(out)
    return out

    如果你真的喜欢单线(拜托,那太多了,不要选择这个选项:))

    def forward(self, x):
    return F.relu(self.block(x) + (x if self.skip is None else self.skip(x)))

    关于python-3.x - 使用 PyTorch 实现一个简单的 ResNet block ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60817390/

    31 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com