gpt4 book ai didi

python - 通过顺序容器构建的 PyTorch 扁平化层

转载 作者:太空狗 更新时间:2023-10-29 20:39:19 25 4
gpt4 key购买 nike

我正在尝试通过 PyTorch 的顺序容器构建一个 cnn,我的问题是我不知道如何展平该层。

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', make_it_flatten)

我应该在“make_it_flatten”中放什么?我试图展平主要但它不起作用,主要不存在调用 View

main = main.view(-1, 16*3*3)

最佳答案

这可能不是您要找的东西,但您可以简单地创建自己的 nn.Module 来展平任何输入,然后您可以将其添加到 nn.Sequential( ) 对象:

class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size()[0], -1)

x.size()[0] 将选择 batch dim,-1 将计算所有剩余的 dims 以适应元素的数量,从而展平任何张量/变量。

并在 nn.Sequential 中使用它:

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
main.add_module('flatten', Flatten())

关于python - 通过顺序容器构建的 PyTorch 扁平化层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45584907/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com