在问答(SQuAD1.1[1] 2016, TriviaQA[2] 2017), 总结(CNN/DailyMail[3] 2017), 语言建模(WikiText-103[4]), 文本提取(Tiny-Shakespeare[5] 合成数据集)在除语言建模的4/5个任务中取得最佳得分 [6]
刚刚在奥地利维也纳落幕的ICML2024, 一篇针对 LLM Inference 的 workload 的文章,通过 研究 attention score 的稀疏性质,让人眼睛一亮,其落款机构正好来自一家知名的,长期在存储、计算的跷跷板上跳舞的近存计算芯片公司:Graphcore.
在最领先的 attention实现中,dot product (gemv) 成为了限制计算的关键。应用部分,LLM推理被分为prefill 和 generation两个阶段,其中generation阶段由于需要反复地从片外存储加载KV Cache成为了制约推理性能的关键点。
1. LLM性能限制器
LLM, Llama2-7b在 A100 (40GB)性能测试显示为访存瓶颈[6]
记系统算力rA(scalar operation/sec), 吞吐rM(scalar data transfer/sec),算力密度I(x)。
从图中我们可以得到测试落点满足不等式 rA > I(x) * rM ,也就是 I(x) , 因此我们推断当前的workload属于内存受限类型。
memory bound,可以理解为对硬件达到峰值,其访存需要大于硬件提供的吞吐极限,因此执行速度受到访存限制。
为此Graphcore在构建了 GQA (group query attention)的访存模型:
将transfomer layer记做 N 参数,芯片数据处理批 (batch size) B; 由于 GQA 多头(heads)注意力 被 g 查询向量共享, 因此共有BC个元素的 KV Cache需要被加载到片上存储。
在推理中,增加B可以提高吞吐,当B足够大时 (一般情况下N >> B, N/B 会显著降低,但不会趋于0),算力密度由模型参数和KV Cache大小的比值 N/C 决定。
每层模型参数由经典的Llama2(不含bias)模型进行简化 , attention 包含: 3(down_proj) + 1(weighted sum over heads) + 4(upper_proj) + 4(down_proj) 组参数:
同时,KV cache计:
带入 (C1) 表达式,并令控制变量:
得到 计算密度,数据处理批:
表明控制变量较低时(增大 g, S=4k, dm=128=4k/32),提高batch size 有助于提高计算密度。
在 g=8, dm=2*S 达到llama2 最佳计算密度 104 较常规MHA提升了15倍!Sheng, Y. et al 2023 [7]实验表明, attention score (s) 具有稀疏性:
s = softmax(q · K⊤ / √ dh )
top-32 token 拿到了多数layer 80%以上注意力加权分数
因此一个很自然的想法是,就是通过预测这种稀疏性,来降低Q, K, V的加载量。为了量化所构建稀疏模型,Graphcore 通过定量分析 group query attention 的 Roofline 模型,分别在模型评估和性能评估上研究该方法的有效性。
其中性能评估包括了llama 7b标准配置下,算子层面的微测试(micro benchmark,40 ms/query -> 5.82 ms/query) 以及 端到端模型的延迟测试,显示算子层面约有6倍的提升:
B=64, Head=32, dh=128, r=32(计算qk score s, hidden size 的采样), k=128(计算 attention s, q top-k tokens), GPU A100(40GB)
2. Bandwidth Efficient Attention (SparQ)
此次文章在ICML发表,却一反常态,将相关工作放在结尾,十分有趣。
在LLM推理加速中,减少KV-Cache load/store 成为一个热门的话题。这类工作可以分为两类。
第一种是借鉴蒸馏思路对KV-Cache压缩。在2个月前 NVIDIA发表的 动态内存压缩(Dynamic Memory Compression:Retrofitting LLMs Accelerate Inference)属于这一类:通过在finetune阶段获得的一个选择器(decision variable),推理阶段沿着序列维度来决定是否需要合并(线性累加)KV cache;论文着力点在于如何finetune阶段获得一个决策掩码矩阵;并通过和GQA, H20[8]等缓存驱逐(cache eviction)模型比较下游任务的表现和压缩比,来判断方法的有效性。
第二种就是通过, attention score 的稀疏性进行来进行压缩。FlexGen[8] 计算score的稀疏性,并因此只加载部分 V 的值,但还需要加载全部的K和Q。因此该文章最大的价值,就是量化一个稀疏预测方法,只需加载topK的K, V 序列以及部分 Q, K的维度在保持精度情况下来加速模型推理。
该文章对 Llama-2-7b, Llama-2-13b, Llama-3-8b, Mistral-7b, Gemma-7b等模型展开端到端研究,并最终验证该想法在以上参考模型和对应5个下游任务的有效性。
llama2 7b 压缩比/精度 实验
从上述实验🧪看出 Bandwith Efficient Attention 在较高的放存带宽下,几乎不掉点,且无需微调。
因此,由于上述模型,和任务数据集的广泛使用,使得文章的结论和方法在实际生产应用中都具有相当的理论指导和实践意义。
该Attention包含三个重要步骤:
第一步:计算 qk matrix/vector product 时候,对q, k 沿hidden_size维度采样绝对值较大的q,只load对应位置的q slice, k slice
第二步:注意我们的generation会持续多轮,因此prefill结束后,我们加载完全长度的 K, V(在不断生成中), 并通过topk (k <=32)计算出 attention score (s) 对应位置,因此只加在对应位置的 KV cache:
第三步:通过下面running mean value进行矫正:
Ms 是我们的topK mask,实际运算包含一个偏上的gather操作,保证数据连续分布, V^bar 是我们的 V 沿序列维度的平均值。
非融合代码描述如下:
Bandwidth Efficent Attention (SparQ) generation预测示意代码
论文中通过 Triton 来进行codeGen。
论文尚未在 HIP/CUDA 下验证 gather操作(selected Q, KV cache),在多一个 store 的情况下加载部分 KV Cache 是否可以打平,并加速。
但相比 pytorch scaled dot product (Meta's flash attn v2 + memory efficient attn) 在 GPU 有了近3倍提升。同时论文提到在 IPU 上,当完全可以通过SRAM加载时,减少 K, V 参数加载可以达到 300 倍的加速效果。
3. 结论
通过大量实验,Bandwith Efficient Attention (SparQ) 发现了存在 attention 的 sparsity 结构。针对推理 generation 需要反复加载 KV Cache问题,提出通过预测稀疏方式来减少 KV Cache 加载。
虽然没有在最领先的带 KV Cache 的 flash attention 的实现中验证,但其在已有数据上的研究,显示了Attention的sparse结构。这意味着prefill结束后,新生成的 KVCache 不必全部加载。
通过triton 生成的kernel 在 A100 上较 pytorch scaled dot product (Meta's flash attn v2 + memory efficient attn) 有3倍提升,在 IPU硬件上有近300倍提升(fully fitted in SRAM)。
因此,如何结合现有的Attention实现方案,优化 prediction (topk, gather op) 可能会成为NV/AMD GPU上推理效率提升的关键方法之一。
参考文献
[1]Rajpurkar, P., Zhang, J., Lopyrev, K., and Liang, P. SQuAD:
100,000+ questions for machine comprehension of text.
arXiv preprint arXiv:1606.05250, 2016.
[2]Joshi, M., Choi, E., Weld, D. S., and Zettlemoyer, L.
TriviaQA: A large scale distantly supervised challenge
dataset for reading comprehension. arXiv preprint
arXiv:1705.03551, 2017.
[3]See, A., Liu, P. J., and Manning, C. D. Get to the point:
Summarization with pointer-generator networks. arXiv
preprint arXiv:1704.04368, 2017.
[4]Merity, S., Xiong, C., Bradbury, J., and Socher, R.
Pointer sentinel mixture models. arXiv preprint
arXiv:1609.07843, 2016.
[5]Karpathy, A. The unreasonable effectiveness of recurrent
neural networks. (Online: accessed 27 January 2024),
2015. URL https://github.com/karpathy/
char-rnn.
[6]SparQ Attention: Bandwidth-Efficient LLM Inference, Luka Ribar et al, Graphcore Research, https://arxiv.org/pdf/2312.04985, 2024, accessed on 1st Aug 2024
[7]Sheng, Y., Zheng, L., Yuan, B., Li, Z., Ryabinin, M., Chen,
B., Liang, P., Re, C., Stoica, I., and Zhang, C. FlexGen: ´
high-throughput generative inference of large language
models with a single GPU. In International Conference
on Machine Learning, pp. 31094–31116. PMLR, 2023.
[8]Zhang, Z., Sheng, Y., Zhou, T., Chen, T., Zheng, L., Cai,
R., Song, Z., Tian, Y., Re, C., Barrett, C., et al. H ´ 2O:
Heavy-hitter oracle for efficient generative inference of
large language models. arXiv preprint arXiv:2306.14048,
2023.