专栏名称: 机器学习算法与Python学习
作为沟通学习的平台,发布机器学习与数据挖掘、深度学习、Python实战的前沿与动态,欢迎机器学习爱好者的加入,希望帮助你在AI领域更好的发展,期待与你相遇!
目录
相关文章推荐
51好读  ›  专栏  ›  机器学习算法与Python学习

新PyTorch API:几行代码轻松搞定多种注意力变体

机器学习算法与Python学习  · 公众号  ·  · 2024-08-11 11:11

正文

机器之心报道

用 FlexAttention 尝试一种新的注意力模式。

理论上,注意力机制就是你所需要的一切。 然而在实际操作中,我们还需要优化像 FlashAttention 这样的注意力机制的实现。

尽管这些融合的注意力机制大大提高了性能,且支持长上下文,但这种效率的提升也伴随着灵活性的丧失。对于机器学习研究人员来说,这就像是一种「软件彩票」—— 如果你的注意力变体不适合现有的优化内核,你将面临运行缓慢和 CUDA 内存不足的困境。
一些注意力变体包括因果注意力、相对位置嵌入、Alibi、滑动窗口注意力、PrefixLM、文档掩码、不规则张量、PagedAttention 等。 更糟糕的是,人们通常希望将这些变体组合在一起!比如滑动窗口注意力 + 文档掩码 + 因果注意力 + 上下文并行,又比如 PagedAttention + 滑动窗口的组合。
下图左侧代表了当今的现状 —— 一些掩码 + 偏置 + 设置的组合已经有现成的内核实现。然而,各种选项的添加会导致设置呈指数级增长。更糟糕的是,这种方式不会支持新的注意力变体。
为了彻底地解决这个超立方体问题,PyTorch 团队引入了 FlexAttention,一个新的 PyTorch API。
  • FlexAttention 是一个灵活的 API,允许用户使用几行惯用的 PyTorch 代码就能实现多个注意力变体。
  • 团队人员通过 torch.compile 将其降低到一个融合的 FlashAttention 内核中 ,生成了一个不会占用额外内存且性能可与手写内核相媲美的 FlashAttention 内核。
  • 利用 PyTorch 的自动求导机制自动生成反向传播。
  • 最后,PyTorch 团队还可以利用注意力掩码中的稀疏性,从而显著改善标准注意力实现。

FlexAttention
经典的注意力方程式如下:

代码形式:
FlexAttention 形式如下,其通过接受用户定义的函数 score_mod 来解决上述问题。
代码形式:
此函数允许用户在 softmax 之前修改注意力分数。研究人员发现,该函数最终足以满足大多数用户对注意力变体的需求。
具体而言,score_mod 如下:
要应用此函数,可以将其实现为:
for b in range (batch_size):
    for h in range (num_heads):
        for q_idx in range (sequence_length):
            for kv_idx in range (sequence_length):
                modified_scores [b, h, q_idx, kv_idx]
 = score_mod (scores [b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)






请到「今天看啥」查看全文