gpt4 book ai didi

python - PyTorch 中带有偏置的嵌入层和线性层之间有什么区别

转载 作者:行者123 更新时间:2023-12-04 12:35:31 29 4
gpt4 key购买 nike

我正在阅读“使用 fastai 和 PyTorch 为编码人员进行深度学习”一书。对于 Embedding 模块的作用,我仍然有些困惑。它似乎是一个简短而简单的网络,但我似乎无法理解 Embedding 与 Linear 没有偏见的不同之处。我知道它做了一些更快的点积计算版本,其中一个矩阵是一个单热编码矩阵,另一个是嵌入矩阵。这样做实际上是为了选择一条数据?请指出我错在哪里。这是书中展示的简单网络之一。

class DotProduct(Module):
def __init__(self, n_users, n_movies, n_factors):
self.user_factors = Embedding(n_users, n_factors)
self.movie_factors = Embedding(n_movies, n_factors)

def forward(self, x):
users = self.user_factors(x[:,0])
movies = self.movie_factors(x[:,1])
return (users * movies).sum(dim=1)

最佳答案

嵌入

[...] what Embedding does differently than Linear without a bias.


基本上一切。 torch.nn.Embedding 是一个查找表;本质上与 torch.Tensor 的工作原理相同但有一些曲折(例如在指定索引处使用稀疏嵌入或默认值的可能性)。
例如:
import torch

embedding = torch.nn.Embedding(3, 4)

print(embedding.weight)

print(embedding(torch.tensor([1])))
会输出:
Parameter containing:
tensor([[ 0.1420, -0.1886, 0.6524, 0.3079],
[ 0.2620, 0.4661, 0.7936, -1.6946],
[ 0.0931, 0.3512, 0.3210, -0.5828]], requires_grad=True)
tensor([[ 0.2620, 0.4661, 0.7936, -1.6946]], grad_fn=<EmbeddingBackward>)
所以我们基本上采用了嵌入的第一行。它的作用不止于此。
它在哪里使用?
通常当我们想要为每一行编码一些含义(如 word2vec)时(例如,语义上接近的单词在欧几里得空间中接近)并可能训练它们
线性
torch.nn.Linear (没有偏见)也是 torch.Tensor (重量) 但是 它对它(和输入)进行操作,本质上是:
output = input.matmul(weight.t())
每次调用图层时(请参阅 source codefunctional definition of this layer )。
代码片段
代码片段中的层基本上是这样的:
  • __init__ 中创建两个查找表
  • 使用形状 (batch_size, 2) 的输入调用该层:
  • 第一列包含用户嵌入的索引
  • 第二列包含电影嵌入的索引

  • 这些嵌入相乘并求和返回 (batch_size,) (所以它与 nn.Linear 不同,后者会返回 (batch_size, out_features) 并做点积而不是像这里这样的逐元素乘法和求和)

  • 这可能用于训练一些类似推荐系统的表示(用户和电影)。
    其他的东西

    I know it does some faster computational version of a dot productwhere one of the matrices is a one-hot encoded matrix and the other isthe embedding matrix.


    不,它没有。 torch.nn.Embedding 可以是一个热编码并且也可能是稀疏的,但取决于算法(以及这些算法是否支持稀疏性),会有加速或没有加速。

    关于python - PyTorch 中带有偏置的嵌入层和线性层之间有什么区别,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65445174/

    29 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com