基本信息和摘要
论文题目
MEDUSA: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads
Arxiv: https://arxiv.org/pdf/2401.10774
作者
Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D. Lee, Deming Chen, Tri Dao
作者研究单位
-
-
-
University of Illinois Urbana-Champaign
-
Carnegie Mellon University
-
University of Connecticut
摘要
本文提出了MEDUSA,一种通过增加额外解码头来
预测多个后续token
的高效方法,以加速大型语言模型(LLMs)的推理过程。MEDUSA利用基于
树的注意力机制
,构建多个候选续集,并在每个解码步骤中同时进行验证。与传统的自回归解码相比,MEDUSA仅引入了
极小的单步延迟开销
,同时显著减少了所需的解码步骤数量。作者为MEDUSA设计了两种微调程序,以满足不同用例的需求:MEDUSA-1和MEDUSA-2。此外,作者还提出了几种扩展,包括自蒸馏 (
self-distillation
) 和典型接受方案(
typical
acceptance scheme
),以提高MEDUSA的实用性。实验结果表明,MEDUSA-1在不影响生成质量的情况下,可以实现超过2.2倍的加速,而MEDUSA-2进一步提高了加速到2.3-3.6倍。
介绍
主要贡献
-
提出了MEDUSA,一个通过增加
多个解码头
来加速LLMs推理的框架。
-
利用
基于树的注意力机制
,MEDUSA能够并行构建和验证多个候选续集。
-
设计了两种微调策略(MEDUSA-1和MEDUSA-2),以适应不同的使用场景和需求。
-
提出了自蒸馏方法和典型接受方案,以改善在没有训练数据或模型经过特定训练(如RLHF)后的应用情况。
-
通过一系列实验验证了MEDUSA在不同模型规模和训练设置下的有效性,实现了显著的推理速度提升。
方法
提出方法的原因
大型语言模型(LLMs)的推理过程通常受限于自回归解码过程中的串行性质,导致操作受加速器内存带宽限制。为了解决这个问题,需要一种能够提高解码过程中的算术强度并减少解码步骤数量的方法。
具体方法描述
-
Figure 1
:MEDUSA的概览图,在LLM最后隐藏状态之上引入多个头,以并行预测多个后续token。
2. 树注意力机制
Tree Attention 是MEDUSA框架的核心组件之一,它允许模型通过树状结构的注意力机制并行处理多个候选续集。这种机制与传统的因果注意力(causal attention)不同,它专门设计用于处理由MEDUSA头生成的多个候选预测。
在Tree Attention中,每个MEDUSA头生成的top预测被用作构建候选续集的基础。通过计算每个头的top预测的笛卡尔积,形成树状结构,其中每个分支代表一个候选续集。为了确保每个token只访问其前驱token,设计了一种注意力掩码(attention mask),它只允许从当前token向前追溯到其前驱token的注意力流动。
具体来说,对于第
个头,其top-
预测形成了候选的基础,其中
是一个超参数。这些候选是通过确定每个头的top-
预测的笛卡尔积来建立的。
例如,如图所示,如果
且
,那么第一个头的每个预测可以由第二个头的任何预测来延续,从而形成具有
个候选的树结构。
在树状结构中,只有同一续集中的token被视为历史数据。这种注意力掩码的实现确保了即使在处理多个候选续集时,也能保持计算的效率,而无需扩大批量大小。累积的新token数量可以表示为
。
此外,作者还探讨了优化树结构构建的方法,通过使用校准数据集来估计不同头的top预测的准确性,并据此构建期望接受长度最大的树结构。这种方法不仅提高了加速率,而且通过选择与当前树结构连接并且具有最高准确性的节点,进一步优化了树状结构。
在实验部分,作者展示了稀疏树结构与密集树结构在加速率和速度上的表现,证明了稀疏树结构在保持稳定加速率方面的优势,尽管在更高的附加长度下速度有所下降。这些发现强调了在维持推理速度和提高加速率之间需要做出权衡。
训练策略
-
MEDUSA-1
:冻结主模型,仅对MEDUSA头进行微调,使用交叉熵损失函数。
-
MEDUSA-2
:MEDUSA头与主模型一同微调,使用组合损失、不同的学习率和头部预热策略。
扩展