gpt4 book ai didi

python - 理解 torch.nn.Flatten

转载 作者:行者123 更新时间:2023-12-05 03:42:00 31 4
gpt4 key购买 nike

我知道 Flatten 会删除除一个维度之外的所有维度。比如我理解flatten() :

> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])

> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

但是,我没有得到 Flatten,尤其是我没有从 the doc 中得到这段代码的含义。 :

>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>> nn.Conv2d(1, 32, 5, 1, 1),
>>> nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])

我觉得输出的大小应该是[160],因为32*5=160

Q1. 那么为什么它输出大小 [32,288]

Q2. 我也不明白文档中给出的 shape 信息的含义:

enter image description here

Q3.还有参数的含义:

enter image description here

最佳答案

这是默认行为的差异。 torch.flatten 默认情况下展平所有尺寸,而 torch.nn.Flatten 默认情况下,从第二个维度(索引 1)开始展平所有维度。

您可以在 start_dim 的默认值中看到此行为和 end_dim争论。 start_dim参数表示要展平的第一个维度(零索引),end_dim参数表示要展平的最后一个维度。所以,当start_dim=1 , 这是 torch.nn.Flatten 的默认值,第一个维度(索引 0)未展平,但在 start_dim=0 时包含在内, 这是 torch.flatten 的默认值.

造成这种差异的原因可能是因为 torch.nn.Flatten旨在与 torch.nn.Sequential 一起使用,其中通常对一批输入执行一系列操作,其中每个输入都独立于其他输入进行处理。例如,如果您有一批图像并调用 torch.nn.Flatten ,典型的用例是分别展平每张图像,而不是展平整个批处理。

如果您确实想使用 torch.nn.Flatten 展平所有尺寸, 您可以简单地将对象创建为 torch.nn.Flatten(start_dim=0) .

最后,文档中的形状信息仅涵盖张量的形状将如何受到影响,说明第一个(索引 0)维度保持原样。所以,如果你有一个形状为 (N, *dims) 的输入张量, 其中*dims是任意维度序列,输出张量的形状为 (N, product of *dims) ,因为除了批量维度之外的所有维度都被展平了。例如,形状为 (3,10,10) 的输入输出形状为 (3, 10 x 10) = (3, 100) .

关于python - 理解 torch.nn.Flatten,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67460123/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com