专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
GiantPandaCV  ·  免费 | 抢先试用此芯Armv9 AI ... ·  3 天前  
GiantPandaCV  ·  美团基于SGLang提供INT8无损满血版D ... ·  4 天前  
51好读  ›  专栏  ›  GiantPandaCV

FlashAttention-3 发布!比FlashAttention-2 快 1.5-2.0 倍

GiantPandaCV  · 公众号  · 3D  · 2024-07-12 14:34

主要观点总结

本文介绍了FlashAttention-3的优化技术,该技术旨在提高GPU上注意力机制的效率,特别是在Hopper GPU上。通过利用新的硬件特性如WGMMA、TMA和FP8,FlashAttention-3实现了更高的性能和更低的内存使用。文章还解释了为何需要重叠GEMM和softmax操作,并介绍了几种重叠技术。此外,还提到了针对LLMs中异常值的不相关处理技术,以减少量化误差。最后,文章展示了一些FlashAttention-3的结果,并将其与其他实现进行了比较。

关键观点总结

关键观点1: FlashAttention-3的优化技术

FlashAttention-3通过利用Hopper GPU的新硬件特性(WGMMA、TMA和FP8)来提高GPU上注意力机制的效率。这些技术提高了GPU的利用率,加速了大型语言模型(LLMs)的训练和推理。

关键观点2: GEMM和softmax操作的重叠

为了提高效率,需要重叠GEMM和softmax操作。现代加速器上的非矩阵乘法操作(如特殊函数)比矩阵乘法操作慢得多,因此重叠这些操作可以显著提高性能。FlashAttention-3使用了多种重叠技术,包括warp调度器的手动调度。

关键观点3: 不相关处理技术减少量化误差

LLM激活可能存在比其他特征大得多的异常值,这会导致量化误差。FlashAttention-3利用不相关处理技术来减少量化误差,这是一种在量化文献中使用的技术。通过模拟异常值,该技术可以减少量化误差。

关键观点4: FlashAttention-3的结果和比较

文章展示了一些FlashAttention-3的结果,并将其与FlashAttention-2以及Triton和cuDNN的实现进行了比较。FlashAttention-3在FP16和FP8上都实现了显著的性能提升。


正文

英文原文( Tri Dao ): https:// tridao.me/blog/2024/flash3/
中文翻译(手抓饼熊):https://zhuanlan.zhihu.com/p/708409249
总结 FlashAttention-3:
  1. 对H系列架构更好的优化(新的指令特性使用);

  2. Gemm和Softmax计算重叠;

  3. FP8支持;

Attention作为无处不在的Transformer架构的核心层,注意力是大型语言模型和长上下文应用的瓶颈。FlashAttention(以及FlashAttention-2)开创了一种通过最小化内存读/写来加速GPU上的注意力的方法,现在大多数库都在用它来加速Transformer的训练和推理。这导致了过去两年中LLM上下文长度的大幅增加,从2-4K(GPT-3,OPT)到128K(GPT-4),甚至1M(Llama 3)。然而,尽管取得成功, FlashAttention尚未充分利用现代硬件的新功能,FlashAttention-2在H100 GPU上仅实现了理论最大FLOP的35%利用率 。在这篇博文中,我们描述了三种加速Hopper GPU上注意力的主要技术: 利用Tensor Cores和TMA的异步性(1)通过warp-specialization重叠整体计算和数据移动,(2)交错块状matmul和softmax操作,以及(3)利用硬件支持FP8低精度的不一致处理。

我们很高兴发布FlashAttention-3,它采用了这些技术。与FP16相比,它的速度比FlashAttention-2快1.5-2.0倍,达到了740 TFLOPS,即H100理论最大FLOPS的75%利用率。使用FP8,FlashAttention-3接近1.2 PFLOPS,比基准FP8注意力小2.6倍的错误。

FlashAttention-3的改进将带来:

1. 更高效的GPU利用率 :新技术使用了H100 GPU最大性能的75%,而不是之前的35%。这导致训练和运行大型语言模型(LLMs)比以前版本快得多(1.5-2倍)。

2. 更低精度下更好的性能 :FlashAttention-3可以使用更低精度的数字(FP8)而保持准确性。这使得处理速度更快,潜在地减少内存使用,这可能会为运行大规模AI操作的客户带来成本节约和效率提高。

3. 在LLMs中使用更长的上下文 :通过加速注意力机制,FlashAttention-3使得AI模型能够更有效地处理更长的文本片段。这可以实现能够理解和生成更长、更复杂内容而不会减慢速度的应用程序。

FlashAttention-3可在以下链接获取:https://tridao.me/publications/flash3/flash3.pdf

FlashAttention Recap

FlashAttention是一种重新排序注意力计算并利用平铺和重计算来显着加快速度并将内存使用量从二次降至与序列长度线性相关的算法。我们使用平铺将输入块从HBM(GPU内存)加载到SRAM(快速缓存),针对该块执行注意力计算,并在HBM中更新输出。通过不将大型中间注意力矩阵写入HBM,我们减少了内存读写量,从而实现2-4倍的时钟速度加快。

这里我们展示了FlashAttention前向传递的图表:通过平铺和softmax重新缩放,我们通过块操作,避免了从HBM读写,同时获得了正确的输出,没有近似值。




New hardware features on Hopper GPUs - WGMMA, TMA, FP8

尽管FlashAttention-2在Ampere(A100) GPU上可以实现高达70%的理论最大FLOPS,但它尚未利用Hopper GPU上的新功能来最大化性能。我们在这里描述一些新的Hopper特定功能,以及它们的重要性。

  1. WGMMA(Warpgroup Matrix Multiply-Accumulate)。 这一新功能利用了Hopper上的新张量核心,具有比Ampere上的旧的mma.sync指令更高的吞吐量。


2. TMA(Tensor Memory Accelerator)。 这是一个专门的硬件单元,加速全局内存和共享内存之间的数据传输,处理所有索引计算和越界预测。这释放了寄存器,这是增加瓦片大小和效率的宝贵资源。


3. 低精度FP8。 这将张量核心吞吐量翻倍(例如,FP16为989 TFLOPS,FP8为1978 TFLOPS),但通过使用更少的位来表示浮点数来进行精度折衷。


FlashAttention-3利用Hopper的所有这些新功能,使用了NVIDIA的CUTLASS库中强大的抽象。

通过重写FlashAttention以利用这些新功能,我们已经显着加快了速度(例如,从FlashAttention-2 FP16前向传递的350 TFLOPS到约540-570 TFLOPS)。然而,Hopper上新指令(WGMMA和TMA)的异步特性开启了额外的算法机会,可以重叠操作,从而提取更大的性能。在本博客文章中,我们将解释两种特定于注意力的技术。分别使用TMA和WGMMA执行生产者和消费者warp的warp专用性的通用技术,在GEMM的上下文中已经广泛讨论,并在这里起到相同的作用。

Asynchrony: Overlapping GEMM and Softmax

Why overlap?

注意力机制中的GEMM(Q和K之间的矩阵乘法以及注意力概率P和V之间的矩阵乘法)和softmax是两个主要操作。为什么我们需要将它们重叠?毕竟大部分的FLOPS都在GEMM中吧?只要GEMM快速(例如,使用WGMMA指令计算),GPU不应该很快吗?

问题在于现代加速器上的非矩阵乘法操作比矩阵乘法操作慢得多。特殊函数(例如softmax中的指数函数)的吞吐量甚至比浮点乘加操作还要低;它们是由多功能单元计算的,这是一个与浮点乘加或矩阵乘加不同的单元。以H100 GPU SXM5为例,其FP16矩阵乘法性能为989 TFLOPS,但特殊函数性能仅为3.9 TFLOPS(吞吐量低256倍)!对于头维度为128的情况,矩阵乘法FLOPS比指数函数FLOPS多512倍,这意味着指数函数可能需要比矩阵乘法多50%的时间。对于FP8来说情况甚至更糟,矩阵乘法FLOPS速度是指数函数的两倍。理想情况下,我们希望矩阵乘法和softmax能够并行操作。当张量核心忙于矩阵乘法时,多功能单元应该计算指数函数!

Inter-warpgroup overlapping with pingpong scheduling

重叠GEMM和softmax的第一个且最简单的方法就是什么都不做!warp调度器已经尝试调度warp,以便如果某些warp被阻塞(例如,等待GEMM结果),其他warp可以运行。也就是说,warp调度器已经为我们做了一些这样的重叠工作,而且是免费的。

然而,我们可以通过手动调度一些操作来进一步改进。举个例子,如果我们有2个warp组(标记为1和2 - 每个warp组是4个warp的组合),我们可以使用同步屏障(bar.sync),这样warp组1首先执行其GEMM操作(例如,一次迭代中的GEMM1和下一次迭代中的GEMM0),然后warp组2进行处理(例如softmax)。

这将使我们能够在其他warpgroup的GEMMs的阴影中执行softmax。当然,这个数字只是一个夸张;在实践中,调度并不是这么干净的。尽管如此,来回调度可以将FP16注意力前向传递的性能从约570 TFLOPS提高到620 TFLOPS(头维度128,序列长度8K)。

Intra-warpgroup overlapping of GEMM and Softmax

即使在一个warpgroup内部,softmax的部分可以在该warpgroup的GEMMs运行时运行。这在这幅图中有所体现,同一颜色表示同一次迭代。

这种流水线技术将FP16注意力前向的吞吐量从大约620 TFLOPS提高到大约640-660 TFLOPS,但代价是更高的寄存器压力。我们需要更多的寄存器来保存GEMMs的累加器和softmax的输入/输出。总体而言,我们发现这种技术提供了有利的折衷方案。

Low-precision: reduce quantization error with incoherent processing

LLM激活可能存在比其他特征大得多的异常值。这些异常值使得量化变得困难,产生了更大的量化误差。我们利用不相关处理,这是量化文献中使用的一种技术(例如来自QuIP),它将查询和键与随机正交矩阵相乘,以“展开”异常值并减少量化误差。具体来说,我们使用哈达玛变换(带有随机符号),可以在O(d log d)的时间内每个注意头中执行,而不是O(d^2)的时间,其中d是头维度。由于哈达玛变换受到内存带宽的限制,它可以与之前的操作(如旋转嵌入,也受内存带宽限制)“免费”融合。







请到「今天看啥」查看全文