gpt4 book ai didi

machine-learning - Transformer 模型中自注意力的计算复杂性

转载 作者:行者123 更新时间:2023-12-03 09:40:20 31 4
gpt4 key购买 nike

我最近经历了Transformer来自 Google Research 的论文描述了自注意力层如何完全取代传统的基于 RNN 的序列编码层以进行机器翻译。在论文的表 1 中,作者比较了不同序列编码层的计算复杂度,并声明(稍后)当序列长度 n 时,自注意力层比 RNN 层更快。小于向量表示的维数 d .
然而,如果我对计算的理解是正确的,自注意力层的复杂性似乎比声称的要低。让 X成为自注意力层的输入。然后,X将有形状 (n, d)因为有 n每个维度的词向量(对应于行)d .计算自注意力的输出需要以下步骤(为了简单起见,考虑单头自注意力):

  • 线性变换 X 的行计算查询 Q , 键 K ,和值 V矩阵,每个矩阵都有形状 (n, d) .这是通过后乘 X 实现的具有 3 个学习过的形状矩阵 (d, d) ,计算复杂度为 O(n d^2) .
  • 计算层输出,在论文的公式 1 中指定为 SoftMax(Q Kt / sqrt(d)) V ,其中 softmax 是在每一行上计算的。计算Q Kt有复杂性 O(n^2 d) ,然后将结果与 V 相乘有复杂性 O(n^2 d)以及。

  • 因此,该层的总复杂度为 O(n^2 d + n d^2) ,这比传统的 RNN 层差。在考虑适当的中间表示维度( dkdv )并最终乘以头数 h 后,我也获得了多头注意力的相同结果。 .
    为什么作者在报告总计算复杂度时忽略了计算查询、键和值矩阵的成本?
    我知道提议的层在 n 上是完全可并行化的。位置,但我相信表 1 并没有考虑到这一点。

    最佳答案

    首先,您的复杂性计算是正确的。那么,困惑的根源是什么?
    当原Attention paper首次引入,不需要计算Q , VK矩阵,因为这些值直接取自 RNN 的隐藏状态,因此注意力层的复杂性 O(n^2·d) .
    现在,了解什么Table 1请记住大多数人阅读论文的方式:他们阅读标题、摘要,然后查看图表和表格。只有当结果有趣时,他们才会更彻底地阅读论文。所以,Attention is all you need的主要思想论文是在 seq2seq 设置中用注意力机制完全替换 RNN 层,因为 RNN 训练真的很慢。如果你看 Table 1在这种情况下,您会看到它比较了 RNN、CNN 和注意力,并突出了本文的动机:使用注意力应该比 RNN 和 CNN 更有利。它应该在3个方面具有优势:计算步数恒定,运算量恒定通常的 Google 设置的计算复杂度较低,其中 n ~= 100d ~= 1000 .但正如任何想法一样,它击中了现实的硬墙。实际上,为了让这个伟大的想法发挥作用,他们必须添加位置编码,重新制定注意力并为其添加多个头。结果是 Transformer 架构,其计算复杂度为 O(n^2·d + n·d^2)仍然比 RNN 快得多(就挂钟时间而言),并产生更好的结果。
    所以你的问题的答案是作者在 Table 1 中提到的注意力层。是严格的注意力机制。这不是 Transformer 的复杂性。他们非常清楚他们模型的复杂性(我引用):

    Separable convolutions [6], however, decrease the complexityconsiderably, to O(k·n·d + n·d^2). Even with k = n, however, thecomplexity of a separable convolution is equal to the combination of aself-attention layer and a point-wise feed-forward layer, the approachwe take in our model.

    关于machine-learning - Transformer 模型中自注意力的计算复杂性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65703260/

    31 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com