专栏名称: 飞桨PaddlePaddle
源于产业实践的开源深度学习平台
目录
相关文章推荐
大厂日爆  ·  腾讯组织架构调整,IEG迎来新变化 ·  昨天  
大厂日爆  ·  腾讯组织架构调整,IEG迎来新变化 ·  昨天  
财联社AI daily  ·  AI会玩宝可梦了!Claude打赢道馆馆主 ·  2 天前  
财联社AI daily  ·  AI会玩宝可梦了!Claude打赢道馆馆主 ·  2 天前  
51好读  ›  专栏  ›  飞桨PaddlePaddle

飞桨首创 FlashMask :加速大模型灵活注意力掩码计算,长序列训练的利器

飞桨PaddlePaddle  · 公众号  ·  · 2024-10-29 20:08

正文

在 Transformer 类大模型训练任务中,注意力掩码(Attention Mask)一方面带来了大量的冗余计算,另一方面因其 巨大的存储占用导致难以实现长序列场景的高效训练(其中 为序列长度)。虽然业界已有 FlashAttention 等针对特定注意力掩码的计算加速方法,但其支持的注意力掩码模式有限,难以满足大模型训练任务对灵活注意力掩码的需求。为了解决上述问题,飞桨独创 FlashMask 技术,提出了列式稀疏的注意力掩码表示方法,支持灵活多样的注意力掩码模式,使得存储复杂度从 降低至 ,并在此基础上实现了高效的算子 Kernel,极致加速大模型训练效率,尤其是长序列场景下的训练效率。
我们在NVIDIA A100 (80G) GPU上对 FlashMask 在大语言模型微调和对齐训练中的表现进行了评估,包括 SFT、LoRA、DPO和 RM 。与现有的 FlashAttention 稠密掩码方法相比,FlashMask 在端到端训练速度上实现了显著提升,速度提高幅度在 1.65 倍到 3.22 倍之间。此外,我们还评估了其 Kernel 层次上的性能。FlashMask 在理论最大浮点运算次数上达到了37.8%到62.3%,在 Kernel 每秒浮点运算次数(TFLOPs/s)方面,其性能超过FlexAttention,提升幅度为 12.1% 到 60.7% 。

arXiv 论文:

https://arxiv.org/abs/2410.01359 

PaddlePaddle 官方文档:

https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/nn/functional/flashmask_attention_en.html 

PaddleNLP 开源集成:

https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/docs/flashmask.md 

星河社区快速体验:

https://aistudio.baidu.com/projectdetail/8459413 
大模型训练的挑战
随着人工智能技术的迅猛发展,以 Transformer 为代表的大模型在自然语言处理、计算机视觉和多模态应用中展现出了非凡的能力。在这些大模型中,注意力(Attention)机制是一个关键环节。为了在大模型训练任务中确定哪些 Query-Key token 之间需要进行有效的 Attention 计算,业界通常使用注意力掩码(Attention Mask)。然而,目前的注意力掩码通常采用二维稠密矩阵表示,这导致了一些问题。一方面,这种表示方法引入了大量冗余计算, 因为许多无效 token 的 Attention 仍需计算; 另一方面,这种掩码的空间复杂度为 (其中 为序列长度) ,在长序列的训练场景中会造成巨大的存储压力,难以进行高效训练。为了解决这些问题,业界已经提出了一些方案,如 Memory Efficient Attention (MEA) [1] 和 FlashAttention [2] 。然而,这些方案支持的注意力掩码类型较为有限。正如图 1 所示,FlashAttention 只能支持如纯因果掩码(Causal)、滑动窗口掩码(Sliding Window)、因果文档掩码(Causal Document Mask)和文档掩码(Document Mask)等几种固定形式的掩码。然而,实际训练任务中使用的注意力掩码形式往往丰富多变,当前技术难以满足大模型不同训练任务对注意力掩码灵活性的要求。
图1 常见的注意力掩码类型
FlashMask 的创新:列式稀疏掩码表示方法与高效计算

1. 关键洞察

FlashMask 的核心发现是,在大模型常见的注意力掩码模式中,Query-Key token 的掩码模式具有一定的连续性。具体而言,对于每一个 Key token,无效注意力计算的 Query token 是相邻排列的。也就是说,在图 1 中二维掩码矩阵中,Query token 作用在每一列的 Key token 的灰色部分沿列方向连续分布。基于这一洞察,FlashMask 巧妙地将二维稠密掩码矩阵转换为一维的行索引区间,从而实现更为紧凑的表示形式,并显著降低了存储需求。我们可以公式化表示为:
其中 为 Key 的序列长度, 为二维的稠密掩码矩阵的第 列, 为连续的行索引区间,表示 的连续 Query token 是被 mask 掉,置为无效 Attention 计算。

2. 注意力掩码的列式稀疏掩码表示方法

为了高效处理因果和双向注意力场景中的复杂掩码模式,FlashMask 提出了一种新颖的列式稀疏表示方法。以对角线为区分,它使用四个一维向量来表示掩码:
  • 下三角起始行索引(Lower Triangular Start,简称
  • 下三角结束行索引(Lower Triangular End,简称
  • 上三角起始行索引(Upper Triangular Start,简称
  • 上三角结束行索引(Upper Triangular End,简称
其中下三角被 mask 掉的行索引区间使用 表示,上三角被 mask 掉的行索引区间使用 表示。
如图 2 所示,我们展示了16个 Query token 和16个 Key token 做 Attention 计算时较为复杂的二维稠密因果注意力的掩码矩阵,灰色单元格是 mask 区域。
图2 较为复杂的二维稠密因果注意力的掩码矩阵示意图
可以通过 两个向量进行表达,如下所示:
以第 0 列为例,开始 mask 的行为 13,结束 mask 的行为 15(开区间),表示位置为 13 和 14 的 Query token 不与位置为 0 的 Key token 做有效 Attention 计算。
为了高效处理因果和双向注意力场景中的复杂掩码模式,FlashMask 提出了一种新颖的列式稀疏表示方法。以对角线为区分,它使用四个一维向量来表示掩码:
  • 下三角起始行索引(Lower Triangular Start,简称
  • 下三角结束行索引(Lower Triangular End,简称
  • 上三角起始行索引(Upper Triangular Start,简称
  • 上三角结束行索引(Upper Triangular End,简称 ) 其中下三角被 mask 掉的行索引区间使用 表示,上三角被 mask 掉的行索引区间使用 [𝑈𝑇𝑆, 𝑈𝑇𝐸) 表示。
如图2所示,我们展示了16个Query token 和16个Key token 做 Attention 计算时较为复杂的二维稠密因果注意力的掩码矩阵,灰色单元格是 mask 区域。
更多的例子参考图 3 ,FlashMask 使用列式稀疏掩码表示方法,表达了图 1 中所有的注意力掩码模式。其中 - 的空缺表示在不同的场景下有不同的默认值, 中的默认值是 0,表示 mask 区域默认从第 0 行开始, 中的默认值是 Query 的序列长度,表示 mask 区域默认结束于最后一行。
图3 使用 FlashMask 的列式稀疏掩码表示方法表示图1的注意力掩码模式

3. 扩展 FlashAttention 支持复杂掩码

FlashMask 将列式掩码表示方法集成到 FlashAttention-2 算法中,增强了其对注意力掩码的支持能力。在 FlashAttention Kernel 的分块计算基础上,FlashMask 利用上述的 等掩码向量,来判断当前分块的掩码类型:
  • 完全掩码块:此类块的所有元素均被掩码,计算时可直接跳过。
  • 部分掩码块:此类块仅部分元素被掩码,因此需要对该块进行逐元素的掩码处理。
  • 未掩码块:此类块中的所有元素均未被掩码,可以简化计算过程,无需额外的掩码操作。
通过这种分类处理,FlashMask 显著提升了计算效率:完全掩码块的计算被直接跳过,未掩码块的计算得到简化,仅对部分掩码块执行必要的掩码操作,如图 4 所示。
图4 FlashMask 计算过程示意图
算法1详细描述了 FlashMask 扩展 FlashAttention-2 的前向计算过程,其中浅蓝色阴影部分表示 FlashMask 新增的计算步骤






请到「今天看啥」查看全文