gpt4 book ai didi

python - PyTorch 中 nn.Linear 的类定义是什么?

转载 作者:太空宇宙 更新时间:2023-11-04 07:14:02 24 4
gpt4 key购买 nike

下面代码中的self.hidden是什么?

import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(784, 256)
self.output = nn.Linear(256, 10)

def forward(self, x):
x = F.sigmoid(self.hidden(x))
x = F.softmax(self.output(x), dim=1)
return x

self.hiddennn.Linear 并且它可以将张量 x 作为参数。

最佳答案

What is the class definition of nn.Linear in pytorch?

来自 documentation :


类 torch.nn.Linear(in_features, out_features, bias=True)

对传入数据应用线性变换:y = x*W^T + b

参数:

  • in_features – 每个输入样本的大小(即 x 的大小)
  • out_features – 每个输出样本的大小(即 y 的大小)
  • bias – 如果设置为 False,该层将不会学习附加偏差。默认值:真

请注意,权重 W 的形状为 (out_features, in_features),偏差 b 的形状为 (out_features)。它们是随机初始化的,以后可以更改(例如,在神经网络的训练过程中,它们会通过某种优化算法进行更新)。

在您的神经网络中,self.hidden = nn.Linear(784, 256) 定义了一个隐藏(意味着它位于输入和输出之间层),完全连接的线性层,它接受形状为(batch_size,784)的输入x,其中批量大小是输入的数量(每个大小为 784) 立即传递给网络(作为单个张量),并通过线性方程 y = x*W^T + b 将其转换为张量 y 形状 (batch_size, 256)。它由 sigmoid 函数进一步转换,x = F.sigmoid(self.hidden(x))(它不是 nn.Linear 的一部分,而是一个额外的步骤)。

让我们看一个具体的例子:

import torch
import torch.nn as nn

x = torch.tensor([[1.0, -1.0],
[0.0, 1.0],
[0.0, 0.0]])

in_features = x.shape[1] # = 2
out_features = 2

m = nn.Linear(in_features, out_features)

其中 x 包含三个输入(即批量大小为 3)、x[0]x[1]x[3],每个大小为 2,输出的形状为 (batch size, out_features) = (3, 2)

参数(权重和偏差)的值是:

>>> m.weight
tensor([[-0.4500, 0.5856],
[-0.1807, -0.4963]])

>>> m.bias
tensor([ 0.2223, -0.6114])

(因为它们是随机初始化的,很可能你会得到与上面不同的值)

输出是:

>>> y = m(x)
tensor([[-0.8133, -0.2959],
[ 0.8079, -1.1077],
[ 0.2223, -0.6114]])

并且(在幕后)它被计算为:

y = x.matmul(m.weight.t()) + m.bias  # y = x*W^T + b

y[i,j] == x[i,0] * m.weight[j,0] + x[i,1] * m.weight[j,1] + m.bias[j]

其中i在区间[0, batch_size)中,j[0, out_features)中。

关于python - PyTorch 中 nn.Linear 的类定义是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54916135/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com