Transformer模型已在语言和视觉领域取得成功。
然而,将其扩展到长序列(例如长文档或高分辨率图像)成本高昂,因为自注意力机制的时间和内存复杂度与输入序列长度呈二次方关系。
在本文中,我们提出了一种高效的自注意力机制——长短Transformer (Transformer-LS),用于对语言和视觉任务中的长序列进行建模,其时间复杂度为线性。
它结合了一种新颖的具有动态投影的长程注意力机制来建模远程关联,以及一种短期注意力机制来捕获细粒度的局部关联。
我们提出了一种双重归一化策略来解决这两种注意力机制之间的尺度不匹配问题。
Transformer-LS 可以应用于自回归模型和双向模型,而不会增加额外的复杂性。
我们的方法在语言和视觉领域的多个任务上都优于现有技术模型,包括远程竞技场基准测试、自回归语言建模和ImageNet分类。
For instance, Transformer-LS achieves 0.97 test BPC on enwik8 using half the number of parameters than previous method, while being faster and is able to handle 3
×
as long sequences compared to its full-attention version on the same hardware.
在 ImageNet 上,它可以获得最先进的结果(例如,仅在
224
×
224
ImageNet-1K 上训练的中等大小的 55.8M 模型可以获得 Top-1 准确率 84.1%),同时在高分辨率图像上更具可扩展性。
源代码和模型已发布在 https://github.com/NVIDIA/transformer-ls。
图1:
单个注意力头的长短期注意力。
其中,序列长度
n
=
8
,隐藏维度
d
=
3
,局部窗口段大小
w
=
2
,以及动态投影的秩
r
=
3
。
在图中,
K
(
V
)
表示键
K
或值
V
。在左图中,我们将
K
或
V
∈
ℝ
n
×
d
虚拟地复制成
n
行,并突出显示所有
n
查询
Q
短期注意力范围内(表示为
K
~
(
V
~
)
)的键和值。
在中间图中,所有查询都关注长期注意力中相同的投影键
K
¯
和值
V
¯
。
在右图中,
K
~
(
V
~
)
和
K
¯
(
V
¯
)
首先用两组 LayerNorm 进行归一化,查询同时关注其注意力范围内的归一化
K
~
(
V
~
)
和
K
¯
(
V
¯
)
。
其中
Q
,
K
,
V
∈
ℝ
n
×
d
是查询、键和值嵌入,
W
O
∈
ℝ
d
×
d
是输出的投影矩阵,第
i
个头
H
i
∈
ℝ
n
×
d
k
是缩放点积注意力,
d
k
=
d
/
h
是每个头的嵌入维度,
其中
W
i
Q
,
W
i
K
,
W
i
V
∈
ℝ
d
×
d
k
是学习到的投影矩阵,
A
i
∈
ℝ
n
×
n
表示每个注意力头的完整注意力矩阵。
计算和存储
A
i
的复杂度为
O
(
n
2
)
,当
n
很大时,这可能是难以承受的。
为简便起见,我们下面的讨论基于一维输入序列的情况。
给定预定的顺序,将其扩展到二维图像数据是很简单的。
3.2
通过分段滑动窗口实现短期注意力
我们使用简单而有效的滑动窗口注意力来捕获细粒度的局部相关性,其中每个查询都关注固定大小邻域内的附近符元。
类似的技术也已在
[14, 16, 11]
中采用。
具体来说,为了提高效率,我们将输入序列划分为长度为
w
的不相交段。
一个段内的所有符元都关注其所属段内的所有符元,以及其所属段左右两侧
w
/
2
个连续的符元(必要时进行零填充),从而导致对总共
2
w
个键值对的注意力跨度。
请参见附录中的图
5
。
对于第
i
个头部中位置
t
处的每个查询
Q
t
,我们将它窗口内的
2
w
键值对表示为
K
~
t
,
V
~
t
∈
ℝ
2
w
×
d
。
使用PyTorch实现时,这种分段滑动窗口注意力比每个符元滑动窗口注意力更快,其中每个符元都关注自身及其左右
w
个符元,并且其内存消耗随序列长度线性缩放;更多细节请参见
[14]
和我们的图
3
。
从这些观察结果出发,我们将第
i
个头部的动态低秩投影参数化为
P
i
=
f
(
K
)
∈
ℝ
n
×
r
,其中
r
≪
n
是低秩大小,而
P
i
取决于输入序列的所有键
K
∈
ℝ
n
×
d
。
它将
(
n
×
d
k
)
维键嵌入
K
W
i
K
和值嵌入
V
W
i
V
投影到更短的
(
r
×
d
k
)
维键
K
¯
i
和值
V
¯
i
嵌入。
与Linformer
[17]
不同,低秩投影矩阵是动态的,它取决于输入序列,旨在更灵活且更能适应例如插入、删除、释义以及其他改变序列长度的操作。
请参见表
2
中的示例。
注意,查询嵌入
Q
W
i
Q
∈
ℝ
n
×
d
k
保持相同的长度,我们让每个查询都关注
K
¯
i
和
V
¯
i
。
通过这种方式,完整的
(
n
×
n
)
注意力矩阵可以分解为两个矩阵的乘积,这两个矩阵具有
r
列或行。
具体来说,我们将动态投影矩阵
P
i
∈
ℝ
n
×
r
和低秩注意力的键值嵌入
K
¯
i
,
V
¯
i
∈
ℝ
r
×
d
k
定义为
其中
W
i
P
∈
ℝ
d
×
r
是可学习的参数,
1
softmax对所有
n
符元的第一维上的投影权重进行归一化,这在我们的实验中稳定了训练。
注意,在所有我们考虑的实验中
K
=
V
,所以如果
P
i
依赖于
V
,它将保持不变。公式
3
的计算复杂度为
O
(
r
n
)
。
为了了解完整的注意力是如何被低秩矩阵的乘积所取代的,我们将每个长程注意力的头部
H
i
∈
ℝ
n
×
d
k
计算为:
因此,完整的注意力现在被两个低秩矩阵
A
¯
i
∈
ℝ
n
×
r
和
P
i
⊺
∈
ℝ
r
×
n
的隐式乘积所取代,计算复杂度降低到
O
(
r
n
)
。
注意,查询在所有符元上的有效注意力权重之和仍然为1。
我们的全局注意力允许每个查询关注同一自注意力层内的所有符元嵌入。
相反,稀疏注意力机制
[14, 16]
需要堆叠多层来构建这种相关性。
应用于自回归模型:
在自回归模型中,每个符元只能关注之前的符元,因此长程注意力对于不同的符元应该具有不同的范围。
实现我们全局注意力的一个直接方法是循环更新每个查询的
K
¯
i
,
V
¯
i
,但这需要由于softmax的非线性而为每个符元重新计算公式(
3
)中的投影,这导致
O
(
r
n
2
)
的计算复杂度。
为了保持线性复杂度,对于自回归模型,我们首先将输入序列划分成长度为
l
的等长段,并应用我们的动态投影从每个段中提取
K
¯
i
,
V
¯
i
。
每个符元只能关注
K
¯
i
,
V
¯
i
不包含其未来符元的片段。
形式上,设
Q
t
为位置
t
处的查询,
K
(
l
−
1
)
s
:
l
s
,
V
(
l
−
1
)
s
:
l
s
为来自第
s
个片段的键值对,以及
s
t
=
⌊
t
/
l
⌋
。
对于自回归模型,我们通过关注
K
i
,
t
,
V
i
,
t
来计算
Q
t
的远程注意力,定义为
为了聚合局部和远程注意力,我们没有为不同的头采用不同的注意力机制
[12, 14, 34]
,而是让第
i
个头的每个查询都关注来自局部窗口和全局低秩投影的键和值的并集,因此它可以学习选择来自两者中的重要信息。
在我们对自回归语言模型的初步试验中,我们发现这种聚合策略比分离头部效果更好。
具体来说,对于第
i
个头,我们将全局低秩投影的键和值表示为
K
¯
i
,
V
¯
i
∈
ℝ
r
×
d
k
,并将局部键和值表示为
K
~
t
,
V
~
t
∈
ℝ
2
w
×
d
,它们位于查询
Q
t
位置
t
的局部窗口内。
然后,位置
t
处的第
i
个注意力
H
i
,
t
为
其中
[
⋅
;
⋅
]
表示沿第一维连接矩阵。
此外,我们发现
K
~
t
W
i
K
和
K
¯
i
的初始范数之间存在尺度不匹配,这使得在语言和视觉任务的初始化阶段,注意力偏向局部窗口。
我们引入一种归一化策略(DualLN)来对齐范数并提高聚合的有效性。
双层归一化(DualLN):
对于具有层归一化(LN)的Transformer(参见
[44]
的图示),
K
i
,
V
i
嵌入是LN层的输出,因此它们在初始化时均值为零,方差为一。
均值为零的向量的
ℓ
2
范数与其方差成比例。
我们注意到,加权平均值将降低此类均值为零向量的方差,从而降低其范数。
结果,公式(
3
)中加权平均值
K
¯
i
,
V
¯
i
的低秩注意力嵌入向量的范数将小于来自滑动窗口注意力的常规键和值嵌入(参见图
2
左图示)。
这种尺度不匹配会导致两个副作用。
首先,局部秩分量的内积
Q
t
W
i
Q
K
¯
i
⊺
的幅度往往小于局部窗口的幅度,因此长程注意力的注意力分数系统性地较小。
其次,即使低秩和局部窗口分配相同的注意力分数,低秩注意力的键值对
K
¯
i
,
V
¯
i
对
H
i
方向的影响也会自然减小,因为
V
¯
i
的范数较小。
这两种效应都会导致低秩分量上的梯度较小,并阻碍模型学习有效利用长程相关性。
其中
LN
(
⋅
)
L
,
LN
(
⋅
)
G
分别表示局部和全局注意力的层归一化。
在实践中,为了保持局部注意力和动态投影之间的一致性,我们使用
LN
L
(
K
)
,
LN
L
(
V
)
而不是
K
,
V
来计算公式
3
中的
K
¯
i
,
V
¯
i
。
如图
2
右图所示,使用双层归一化(DualLN)训练的Transformer-LS模型的验证损失始终低于未使用双层归一化(DualLN)的模型。