点击上方
“
小白学视觉
”,选择加"
星标
"或“
置顶
”
重磅干货,第一时间送达
来源丨https://zhuanlan.zhihu.com/p/445016136
通过修改SelfAttention的执行逻辑,可以节省大量的激活值显存开销。
这篇文章的消除方法来自于2021年12月10日谷歌放到arxiv上的文章self attention does not need O(n^2) memory. 该方法巧妙地使用了小学学到的加法分配率,将self attention中的固定激活值降到了O(1)的程度。[1]
Self Attention 固定激活值显存分析
Hugging face Transformers中,SelfAttention 内核实现
表格中只列举了会实测中产生激活值的操作,其中B为Batch_size,L为sequence_length,H为hidden_size,m为SelfAttention中head的数量。
则总和
。
观察:
-
当
固定时, 即模型结构是固定的时候, 我们发现激活值是和
线性相关的。
-
当
变化时, 我们发现会存在一个常数项
, 我称这个常数激活值开销为固定激活值。这个主要是在Query和Key矩阵做乘法, 以及后续的一些操作中生成的。即在
等操作中出现。
SelfAttention 固定激活值显存优化
1. Prerequisites
1.1 Softmax 计算过程
对于向量
表示
中的第
个元素, 那么这个元素的softmax值为:
1.2 SelfAttention计算过程
为了简化计算,我们先忽略掉Scale和Dropout,因为它们都是单操作数的op,这个忽略不会给我们的分析带来影响。考虑最后输出矩阵第i行,第j列的结果,在原始的实现中,他的计算过程为:
, QK的矩阵乘法, 产生Tensor
, shape为
维度的Softmax, 产生Tensor
, shape为
. Softmax和Value的矩阵乘, 产生最终输出结果, shape为
.
写成伪代码则为:
"""
inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]
matrix A[L][L]=0, S[L][L]=0, O[L][H/m]=0 # 初始化为0矩阵, A,S为中间激活值矩阵
"""
# QK Matmul
for i in range(L):
for j in range(L):
for l in range(H/m):
A[i][j] += Q[i][l]*Q[l][j]
# Softmax, dim=-1
for i in range(L):
temp = 0
for j in range(L):
S[i][j] = math.exp(A[i][j])
temp += S[i][j]
S[i]/=temp
# OV Matmul
for i in range(L):
for j in range(H/m):
for l in range(L):
O[i][j] += S[i][l]*Q[l][j]
return O
2. 显存优化
Google采用了一个非常简单的方法来节省Attention核中的大量的显存开销,具体计算过程为:
, QK的矩阵乘法, 但是不单独执行, 直接代入下一个式子。
, 这里没有除以求和值, 而是把除法挪到了下面。
可以发现, 和原来的算法的差别在于把
的计算放到了后面。采用这种方法的好处是, 我 们可以分开计算
和
了。
我们用临时变量
和