Pytorch官方Blog:FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention
理论上,Attention is All You Need。然而在实践中,我们还需要像FlashAttention这样的优化的注意力机制实现。
尽管这些融合的注意力实现在性能上有了显著提升,并使得长序列上下文成为可能,但这种效率的提升是以牺牲灵活性为代价的。你不能再通过简单地编写几个PyTorch运算符来尝试新的
注意力
变体,而通常需要重新写一个新的自定义的Kernel,即使是使用triton等工具,也并不简单!这为机器学习研究人员创造了一种“Software Lottery”(这个词来源于谷歌的论文The Hardware Lottery,在机器学习领域中用来描述某些研究想法因适合现有的软硬件环境而成功,而非因为这些想法在本质上优于其他研究方向)——如果你的
注意力
变体不适用于现有的任何一个已经优化的Kernel,你就注定要面对缓慢的运行时间和CUDA内存不足的问题。
一些注意力变体的例子包括因果(Causal)、相对位置嵌入(Relative Positional Embeddings)、Alibi、滑动窗口注意力(Sliding Window Attention)、前缀语言模型(PrefixLM)、文档掩码(Document Masking)、样本打包(Sample Packing)、不规则张量(Jagged Tensors)、软封顶(Tanh Soft-Capping)、分页注意力(PagedAttention)等。更糟糕的是,人们通常想要这些变体的组合!比如滑动窗口注意力+文档掩码+因果关系+上下文并行处理?或者分页注意力+滑动窗口+软封顶。
下图左边代表当今的状态——一些masking + biases + setting的组合已经实现了现有的内核。但各种选项导致设置的数量呈指数级增长,因此总体上我们得到的是相当零散的支持。更糟糕的是,研究人员提出的新的注意力变体将得不到任何支持。
为了一劳永逸地解决这个问题,Pytorch引入了FlexAttention,一种新的PyTorch API:
1. 提供了一个灵活的API,允许在几行典型的PyTorch代码中实现许多注意力变体(包括迄今为止在博客文章中提到的所有变体)。
2. 通过torch.compile将其降低为一个融合的FlashAttention内核,生成的FlashAttention内核不会生成任何额外的内存,并且性能与手写的内核竞争。
3. 利用PyTorch的自动微分机制,自动生成反向传播。
4. 最后,我们还可以利用注意力掩模中的稀疏性,从而在标准注意力实现上取得显著的改进。