专栏名称: 小白学视觉
本公众号主要介绍机器视觉基础知识和新闻,以及在学习机器视觉时遇到的各种纠结和坑的心路历程。
目录
相关文章推荐
超级数学建模  ·  不是吧!瓷器也会开花? ·  2 天前  
超级数学建模  ·  3000一罐的贵妇面霜,真好用! ·  3 天前  
超级数学建模  ·  懂中式美学的人,真不简单! ·  3 天前  
超级数学建模  ·  他是DeepSeek关键人才!差点留在美国… ... ·  3 天前  
51好读  ›  专栏  ›  小白学视觉

PyTorch实现 Self Attention

小白学视觉  · 公众号  ·  · 2024-09-22 10:05

正文

点击上方 小白学视觉 ”,选择加" 星标 "或“ 置顶

重磅干货,第一时间送达

仅作学术分享,不代表本公众号立场,侵权联系删除
转载于: 作者丨Connolly@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/445016136
编辑丨极市平台

通过修改SelfAttention的执行逻辑,可以节省大量的激活值显存开销。

这篇文章的消除方法来自于2021年12月10日谷歌放到arxiv上的文章self attention does not need O(n^2) memory. 该方法巧妙地使用了小学学到的加法分配率,将self attention中的固定激活值降到了O(1)的程度。[1]

Self Attention 固定激活值显存分析

Hugging face Transformers中,SelfAttention 内核实现

表格中只列举了会实测中产生激活值的操作,其中B为Batch_size,L为sequence_length,H为hidden_size,m为SelfAttention中head的数量。

则总和

观察:

  1. 固定时, 即模型结构是固定的时候, 我们发现激活值是和 线性相关的。
  2. 变化时, 我们发现会存在一个常数项 , 我称这个常数激活值开销为固定激活值。这个主要是在Query和Key矩阵做乘法, 以及后续的一些操作中生成的。即在 等操作中出现。

SelfAttention 固定激活值显存优化

1. Prerequisites

1.1 Softmax 计算过程

对于向量 表示 中的第 个元素, 那么这个元素的softmax值为:


1.2 SelfAttention计算过程

为了简化计算,我们先忽略掉Scale和Dropout,因为它们都是单操作数的op,这个忽略不会给我们的分析带来影响。考虑最后输出矩阵第i行,第j列的结果,在原始的实现中,他的计算过程为:

, QK的矩阵乘法, 产生Tensor , shape为

维度的Softmax, 产生Tensor , shape为

. Softmax和Value的矩阵乘, 产生最终输出结果, shape为 .

写成伪代码则为:

"""
inputs: Q[L][H/m], K[L][H/m], V[L][H/m]
outputs: O[L][H/m]

matrix A[L][L]=0, S[L][L]=0, O[L][H/m]=0 # 初始化为0矩阵, A,S为中间激活值矩阵
"""


# QK Matmul
for i in range(L):
    for j in range(L):
        for l in range(H/m):
            A[i][j] += Q[i][l]*Q[l][j]

# Softmax, dim=-1
for i in range(L):
    temp = 0
    for j in range(L):
        S[i][j] = math.exp(A[i][j])
        temp += S[i][j]
    S[i]/=temp

# OV Matmul
for i in range(L):
    for j in range(H/m):
        for l in range(L):
            O[i][j] += S[i][l]*Q[l][j]

return O

2. 显存优化

Google采用了一个非常简单的方法来节省Attention核中的大量的显存开销,具体计算过程为:

, QK的矩阵乘法, 但是不单独执行, 直接代入下一个式子。

, 这里没有除以求和值, 而是把除法挪到了下面。



可以发现, 和原来的算法的差别在于把 的计算放到了后面。采用这种方法的好处是, 我 们可以分开计算 了。

我们用临时变量







请到「今天看啥」查看全文