总结: 随着我们增加内存压缩次数的次数,Infini-attention 的性能会变得越来越差。据我们所知,
ring attention
、
YaRN
和
rope scaling
这三种方法仍是将预训练模型拓展更长上下文的最佳方式。
ring attention
https://x.com/Haojun_Zhao14/status/1815419356408336738
YaRN
https://arxiv.org/abs/2309.00071
rope scaling
https://arxiv.org/abs/2309.16039
引言:
语言模型的上下文长度也是除模型性能之外的重要属性之一。自 in-context learning (上下文学习) 出现以来,添加相关信息到模型的输入中日渐重要。因此,上下文长度迅速从段落 (BERT/GPT-1 的 512 个 tokens) 扩展到页面 (GPT-2 和 GPT-3 分别为 1024/2048 个 tokens), 再到书籍 (Claude 的 128k tokens), 甚至书籍集合 (Gemini 的 1-10M tokens)。然而,将 standard attention(标准注意力) 扩展到如此长度仍然面临挑战。
关于 Ring Attention (一种注意力机制) 的简单介绍: 据我们所知,Ring Attention 最初是由加州大学伯克利分校的研究人员在 2024 年提到的
Ring Attention with Blockwise Transformers for Near-Infinite Context
。这种工程技术通过以分块方式执行 self-attention 和 feedforward network 计算,并将序列维度分配到多个设备上,减轻了内存限制,实现并发计算和通信。
Ring Attention with Blockwise Transformers for Near-Infinite Context
https://arxiv.org/abs/2310.01889
即使使用 Ring Attention,要在 1 百万 token 的上下文长度上训练一个
Llama 3 8B
模型,batch size 为 1 时,仍然需要 512 个 GPU。正如 scaling laws (扩展定律) 提到
Scaling Laws for Neural Language Models
的那样,模型大小与其下游任务性能之间存在强相关性,这意味着模型越大越好 (当然,两种模型都应该被训练得很好)。因此,我们不仅需要 1 百万 token 的上下文长度,还希望在最大的模型上实现这一长度 (例如,Llama 3 8B 405B)。而目前只有少数几家公司拥有实现这一目标的资源。
Llama 3 8B
https://arxiv.org/abs/2407.21783
Scaling Laws for Neural Language Models
https://arxiv.org/abs/2001.08361
回顾自注意力的内存复杂度
在标准注意力机制 (非 flash-attention) 中,每个标记都会关注序列中的所有其他标记,从而形成一个大小为 [seq_len, seq_len] 的注意力矩阵。对于每对标记,我们都需要计算一个注意力分数。随着序列长度 (seq_len) 的增加,内存和计算需求呈二次方增长:注意力矩阵的内存复杂度为 O(seq_len^2)。例如,序列长度增加 10 倍会导致内存需求增加 100 倍。
即使是像 Flash Attention 这样的内存高效注意力方法,其内存需求仍会随上下文长度线性增长,并受限于单个 GPU 的内存容量。这导致在当今的 GPU 上,典型的最大上下文长度远低于 1M 个标记。
受此启发,我们探索了一种替代标准注意力的方法:无限注意力 (infini-attention)。这篇论文由来自 Google 的研究人员于 2024 年 4 月发布
Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention
。与计算每个词之间的注意力分数不同,无限注意力将序列划分为多个片段,将早期片段压缩到固定缓冲区,并允许下一个片段从早期片段中检索记忆,同时将注意力分数限制在当前片段内的词语之间。其关键优势在于固定的缓冲区大小为总内存使用设置了上限。它还在一个片段内使用相同的查询来访问该片段和压缩记忆中的信息,这使我们能够以低成本为预训练模型扩展上下文长度。理论上,我们可以实现无限长的上下文,因为它只为所有早期片段的记忆保留一个缓冲区。然而,实际上压缩限制了能有效存储的信息量,因此问题在于:这种压缩的记忆有多大的可用性 ?
Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention
https://arxiv.org/abs/2404.07143
虽然在理论上理解新方法相对容易,但实际使其运作往往是另一回事,而这个过程很少公开分享。出于这个原因,我们决定分享我们在复现无限注意力论文过程中的实验和记录,包括在调试过程中 (我们 90% 的时间都在调试一个收敛问题) 激励我们的因素,以及让这些方法正常工作可能有多困难。
随着 Llama 3 8B (上下文长度限制为 8k 个标记) 的发布,我们试图将这个长度扩展到 100 万个标记,而不会导致内存需求二次增长。在这篇博客中,我们将首先解释无限注意力的工作原理。然后,我们将介绍我们的复现原则,并描述我们最初的小规模实验。我们讨论了面临的挑战,如何解决这些挑战,并以我们的发现总结和其他探索的想法作为结束。如果你有兴趣测试我们训练的
检查点
, 你可以在
以下仓库
中找到它 (请注意,我们目前按原样提供代码)。
检查点
https://hf.co/nanotron/llama3-8b-infini-attention
仓库链接
https://github.com/huggingface/nanotron/tree/xrsrke/infini_attention_this_actually_works
第 1 节: 复现原则
我们发现以下规则在实现新方法时很有帮助,并将其用作我们大量工作的指导原则:
原则 1:
从能提供良好信号的最小模型规模开始,一旦获得良好信号就扩大实验规模。
原则 2:
始终训练一个可靠的基准模型来衡量进展。
原则 3:
为了确定某项修改是否提高了性能,训练两个除了被测试的修改之外完全相同的模型。
牢记这些原则,让我们深入了解 Infini-attention 的实际工作原理。理解其机制对于我们推进实验至关重要。
第 2 节: Infini-attention 的工作原理
步骤 1: 将输入序列分割成较小的、固定大小的块,称为 “ 片段 “。
步骤 2: 在每个片段内计算标准的因果点积注意力。
步骤 3: 使用当前片段的查询向量从压缩内存中提取相关信息。检索过程的数学定义如下:
: 查询矩阵,其中
是查询的数量,
是每个查询的维度。
: 非线性激活函数,具体为逐元素指数线性单元 (ELU) 加 1。
import torch.nn.functional as Ffrom torch import einsumfrom einops import rearrangedef _retrieve_from_memory (query_states, prev_memory, prev_normalization) : ... sigma_query_states = F.elu(query_states) + 1 retrieved_memory = einsum( sigma_query_states, prev_memory, "batch_size n_heads seq_len d_k, batch_size n_heads d_k d_v -> batch_size n_heads seq_len d_v" , ) denominator = einsum( sigma_query_states, prev_normalization, "batch_size n_heads seq_len d_head, batch_size n_heads d_head -> batch_size n_heads seq_len" , ) denominator = rearrange( denominator, "batch_size n_heads seq_len -> batch_size n_heads seq_len 1" , ) # NOTE: because normalization is the sum of all the keys, so each word should have the same normalization retrieved_memory = retrieved_memory / denominator return retrieved_memory
步骤 4: 将局部上下文 (来自当前片段) 与长期上下文 (从压缩内存中检索) 结合,生成最终输出。这样,注意力输出可以同时考虑短期和长期上下文。
: 一个可学习的标量参数,用于控制长期内存内容
和局部上下文之间的权衡。
步骤 5: 通过添加当前片段的键值状态来更新压缩内存,这使我们能够随时间累积上下文。