专栏名称: 自动驾驶之心
自动驾驶开发者社区,关注计算机视觉、多维感知融合、部署落地、定位规控、领域方案等,坚持为领域输出最前沿的技术方向!
目录
相关文章推荐
HRBar  ·  HR1号为操盘手训练 ·  2 天前  
人力资源心理学  ·  我在公司工作5年,每日勤勤恳恳,爱岗敬业;直 ... ·  2 天前  
HRTechChina  ·  【北京】2025人力资源科技年度论坛(HR+ ... ·  2 天前  
51好读  ›  专栏  ›  自动驾驶之心

CUDA-Free Inference for LLMs

自动驾驶之心  · 公众号  ·  · 2024-10-22 07:30

正文

作者 | BBuf  编辑 | 自动驾驶之心

原文链接:https://zhuanlan.zhihu.com/p/2130907920

点击下方 卡片 ,关注“ 自动驾驶之心 ”公众号

戳我-> 领取 自动驾驶近15个 方向 学习 路线

>> 点击进入→ 自动驾驶之心 CUDA编程 技术交流群

本文只做学术分享,如有侵权,联系删文

blog链接:https://pytorch.org/blog/cuda-free-inference-for-llms/

无CUDA的LLM推理

作者:Adnan Hoque, Less Wright, Raghu Ganti 和 Mudhakar Srivatsa

在这篇博客中,我们讨论了如何使用OpenAI的Triton语言实现流行的LLM模型(如Meta的Llama3-8B和IBM的Granite-8B Code)的FP16推理,其中 100 % 的计算都是使用Triton语言完成的。对于使用我们基于Triton kernel的模型进行单个token生成的时间,我们能够在Nvidia H100 GPU上达到相对于CUDA kernel主导工作流的 0.76-0.78 x性能,在Nvidia A100 GPU上达到 0.62-0.82 x性能。

为什么要探索使用100%的Triton?Triton为LLM在不同类型的GPU(如NVIDIA、AMD,以及未来的Intel和其他基于GPU的加速器)上运行提供了一条路径。它还为GPU编程提供了更高层次的Python抽象,使我们能够比使用特定供应商的API更快地编写高性能kernel。在本文的其余部分,我们将分享我们如何实现无CUDA的计算,对单个kernel进行微基准测试以进行比较,并讨论我们如何进一步改进未来的Triton kernel以缩小差距。

图1. 在NVIDIA H100和A100上,Llama3-8B和Granite-8B的Triton和CUDA变体的推理吞吐量基准测试 设置:批量大小 = 2,输入序列长度 = 512,输出序列长度 = 256

2.0 Transformer块的组成

我们从Transformer模型中发生的计算分解开始。下图显示了一个典型Transformer块的“kernels”。

图2. 按核心kernels划分的Transformer块

Llama3架构的核心操作总结如下:

这些操作中的每一个都是通过在GPU上执行一个(或多个)kernels来计算的。尽管这些kernels的具体细节在不同的transformer模型中可能有所不同,但核心操作保持不变。例如,IBM的Granite 8B Code模型在MLP层中使用了偏置,这与Llama3不同。这种变化确实需要对kernels进行修改。一个典型的模型是由这些transformer块堆叠在一起,并通过嵌入层连接起来的。

3.0 模型推理

典型的模型架构代码与一个由PyTorch启动的python model.py文件共享。在默认的PyTorch eager执行模式下,这些kernel都是使用CUDA执行的。为了实现Llama3-8B和Granite-8B端到端推理的100% Triton,我们需要编写和集成手写的Triton kernel,并利用torch.compile(生成Triton操作)。首先,我们用编译器生成的Triton kernel替换较小的操作,其次,我们用手写的Triton kernel替换更昂贵和复杂的计算(例如矩阵乘法和flash attention)。

Torch.compile自动为RMSNorm、RoPE、SiLU和Element Wise Multiplication生成Triton kernel。使用Nsight Systems等工具,我们可以观察这些生成的kernel;它们在矩阵乘法和注意力之间显示为微小的深绿色kernel。

图3 . Llama3-8B 使用 torch.compile 的跟踪,显示用于矩阵乘法和 flash attention 的 CUDA kernels

对于上述跟踪,我们注意到在 Llama3-8B 风格的模型中,构成 80% 端到端延迟的两个主要操作是矩阵乘法和注意力 kernels,并且这两个操作仍然是 CUDA kernels。因此,为了缩小剩余的差距,我们用手写的 Triton kernels 替换了矩阵乘法和注意力 kernels。

4.0 Triton SplitK GEMM Kernel

对于线性层中的矩阵乘法,我们编写了一个自定义的FP16 Triton GEMM(通用矩阵-矩阵乘法)kernel,该kernel利用了SplitK工作分解(https://pytorch.org/blog/accelerating-moe-model//#30-work-decomposition---splitk)。我们之前在其他博客中讨论过这种并行化方法,作为加速LLM推理解码部分的一种方式。

这里对上面博客中的 Work Decomposition - SplitK 一节也翻译一下

工作分解 - SplitK

我们之前已经证明,对于LLM推理中发现的矩阵问题大小,特别是在W4A16量化推理的背景下,通过应用SplitK工作分解(https://arxiv.org/abs/2402.00025),GEMM内核可以加速。因此,我们通过在vLLM MoE kernel(https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py)中实现SplitK,开始了我们的MoE加速研究,这相对于数据并行方法产生了大约18-20%的加速。

这一结果表明,SplitK优化可以作为在推理设置中改进/开发Triton kernel的更公式化方法的一部分。为了建立对这些不同工作分解的直觉,让我们考虑一个简单的例子,即两个4x4矩阵的乘法和SplitK=2。

在下图中显示的数据并行GEMM kernel中,输出矩阵的单个块的计算将由1个线程块TB0处理。

Figure 2. Data Parallel GEMM

相比之下,在SplitK kernel中,计算输出矩阵中单个块所需的工作被“分割”或共享给两个线程块TB0和TB1。这提供了更好的负载均衡和增加的并行性。

Figure 3. SplitK GEMM

关键思想是我们将并行性从MN增加到MN*SplitK。这种方法确实会带来一些成本,例如通过原子操作增加线程块间通信。然而,这些成本相对于节省的其他受限GPU资源(如共享内存和寄存器)来说是微不足道的。最重要的是,SplitK策略为瘦矩阵(如MoE推理中的情况)提供了优越的负载均衡特性,并且在解码和推理期间是常见的矩阵配置文件。







请到「今天看啥」查看全文