0x00 前言
在上一篇Prefill优化的文章中,已经详细讲解了vLLM Automatic Prefix Caching(Hash RadixAttention)的原理和Cache调度的实现,包括SGLang RadixAttention原理,并且结合图解和代码,详细分析了vLLM中的Hash RadixAttention实现。vLLM中的Hash RadixAttention内容包括:Hash RadixAttention、Hash Prefix Tree、Prefix/Generate 阶段Hash码处理、Prefix + Generated KV Caching的调度逻辑、边界情况思考、vLLM Automatic Prefix Caching在多轮对话中的应用分析以及代码应用实践。本篇,继续深入,讲解Automatic Prefix Caching中用到的Triton Based Prefix Prefill Kernel。
推荐先阅读完上一篇的Automatic Prefix Caching原理,再来阅读本篇的kernel解读。
DefTruth:[Prefill优化][万字] 原理&图解vLLM Automatic Prefix Cache(RadixAttention): 首Token时延优化
https://zhuanlan.zhihu.com/p/693556044
本文包含内容如下。(
提示:需要先阅读完vLLM Automatic Prefix Cache作为基础
)
0x01 OpenAI Triton: Triton Kernel编程极简入门
0x02 vLLM Prefix Prefill Kernel: Prefix Prefill Kernel与Attention Kernel区别
0x03 vLLM Prefix Prefill Kernel: 先说Tiling分块策略
0x04 vLLM Prefix Prefill Kernel: 再看Kernel调用
0x05 vLLM Prefix Prefill Kernel: 如何确认有多少个Token被Prefix Cache命中?
0x06 vLLM Prefix Prefill Kernel: 通用Head Sizes支持
0x07 vLLM Prefix Prefill Kernel: MQA/GQA支持
0x08 vLLM Prefix Prefill Kernel: Triton Kernel解析
0x09 总结
0x01 OpenAI Triton: Triton Kernel编程极简入门
关于OpenAI Triton,这里只做简单的介绍。网上可以找到大量入门的文章,本文也先不重复了。本文主要关注Prefix Prefill Kernel的实现,而非深挖Triton的底层的原理。
GPU基础架构
传统的基于 CUDA 进行 GPU 编程难度较大,在优化 CUDA 代码时,必须考虑到数据流在DRAM、SRAM 和 ALU之间的Load/Store的问题,还需要仔细考虑到Grid、Block、Thread和Warp等不同级别的调度优化问题。这些问题包括但不限于:
1. 从 DRAM 的内存传输必须合并成大型事务,以利用现代内存接口的大总线宽度(内存合并访问)。
2. 数据必须在重复使用前手动存储到 SRAM 中,并进行管理来最小化bank conflict。
3. 计算必须仔细地进行划分和调度,不仅是在流式多处理器(SMs)之间,还包括在其内部,以促进指令/线程级并行性,并利用专用的 ALU(例如,Tensor Cores)。
因此,哪怕是CUDA熟练工,也得花费不少的精力,才能写出一个性能接近理论峰值的Kernel。(比如像我这种菜菜的,基本上已经放弃手撸CUDA Kernel了)。Triton 的出现,降低了CUDA Kernel编写的难度,它将一些需要精心设计的优化策略进行自动化,比如内存事务合并、SRAM分配和管理、流水线优化等,从而使得编程人员可以将更多的精力放在算法本身。
Triton Complier 编译优化
从官方放出的这个表格中,我们可以看到,如果使用Triton,内存事务合并、SRAM管理以及SM内的线程调度都是自动进行的,我们只需要把精力花在SM之间管理即可,这也就是说,
Triton的编程粒度是Block
(每个Block只会被调度到一个SM上),而不是Thread。我们只需要考虑每个Block需要做什么,至于Thread/Warp的分布和调度,Triton自动给我们处理了。那么,Block这个概念,在Triton中通过什么进行表达呢?答案是:
program
。
Triton Block-wise 编程模型
block -> program,在Triton中,使用
program_id
来标识一个唯一的program。编程人员只需要考虑一个program(block)内的编程逻辑,比如这个最简单的add_kernel。
x_ptr
,
y_ptr
, 和
output_ptr
分别是指向第一个输入向量、第二个输入向量和输出向量的指针。这些向量存储在 GPU 的内存中。比较常见的就是PyTorch和Triton一起使用,Triton将会传入的Tensor当成指针来处理,而非数据张量。
BLOCK_SIZE: tl.constexpr
表示一个triton的编译时常量,表示每个 block需要处理的元素数量。
mask = offsets < n_elements
表示创建一个mask以防止内存操作超出范围。tl.load和tl.store分表表示triton中的数据加载和写入的操作,这也是需要注意的,Triton为了能更好地进行性能优化,它是在指针级别上做操作的,而非数据Tensor级别。
import triton import triton.language as tl @triton.jitdef add_kernel (x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. output_ptr, # *Pointer* to output vector. n_elements, # Size of the vector. BLOCK_SIZE: tl. constexpr, # Number of elements each program should process. # NOTE: `constexpr` so it can be used as a shape value. ): # There are multiple 'programs' processing different data. We identify which program # we are here: # 有多个'程序'(也就是block)处理不同的数据。我们在这里标识我们是哪个程序: pid = tl. program_id(axis= 0 ) # We use a 1D launch grid so axis is 0. # This program will process inputs that are offset from the initial data. # For instance, if you had a vector of length 256 and block_size of 64, the programs # would each access the elements [0:64, 64:128, 128:192, 192:256]. # Note that offsets is a list of pointers: # 该程序将处理与初始数据偏移的输入。 # 例如,如果您有长度为 256 的向量和块大小为 64,程序 # 将分别访问元素[0:64, 64:128, 128:192, 192:256]。 # 请注意,偏移量是指针的列表: block_start = pid * BLOCK_SIZE offsets = block_start + tl. arange(0 , BLOCK_SIZE) # Create a mask to guard memory operations against out-of-bounds accesses. # 创建一个mask以防止内存操作超出范围。 mask = offsets < n_elements # Load x and y from DRAM, masking out any extra elements in case the input is not a # multiple of the block size.
x = tl. load(x_ptr + offsets, mask= mask) y = tl. load(y_ptr + offsets, mask= mask) output = x + y # Write x + y back to DRAM. tl. store(output_ptr + offsets, output, mask= mask)
def add (x: torch. Tensor, y: torch. Tensor): # 我们需要预先分配输出。 output = torch. empty_like(x) assert x. is_cuda and y. is_cuda and output. is_cuda n_elements = output. numel() # SPMD启动网格表示并行运行的内核实例数。 # 它类似于CUDA启动网格。对于add_kernel我们使用一个1D网格,其大小是块的数量: grid = lambda meta: (triton. cdiv(n_elements, meta['BLOCK_SIZE' ]), ) # 注意: # - 每个torch.tensor对象都隐式地转换为指向其第一个元素的指针。 # - `triton.jit`'ed函数可以通过一个启动网格索引来获得一个可调用的GPU内核。 # - 不要忘记将元参数作为关键字参数传递。 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE= 1024 ) # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。 return output
需要注意的是,Triton将会传入的Tensor当成指针来处理,而非数据张量。并且,由于Triton Kernel也是异步调用的,因此在测试性能的时候,需要在函数返回后添加
torch.cuda.synchronize()
。更详细的Triton 入门,推荐阅读:如何入门 OpenAI Triton 编程? 以及 科密中的科蜜:OpenAI Triton 入门教程,讲解地很详细,本文Triton部分内容参考自这两篇文章(侵删)。Anyway,现在只需要记住以下这点,即可继续往下阅读本文了。
1. Program相当于CUDA编程中的Block,program_id相当于block id。
2. CUDA的编程模型从grid-block-thread,被简化为
Block-wise
,kernel启动时,只需要考虑一个grid中block的布局。比如,grid=(M,N,D/BLOCK_K)表示这个gird是一个3D的block布局。
0x02 vLLM Prefix Prefill Kernel: Prefix Prefill Kernel与Attention Kernel区别
回到Prefix Caching这个话题。需要注意的是,在使用了Prefix Caching后,就无法使用常规的Attention kernel来计算Prefill阶段的注意力结果了。这是由于常规kernel都暗含着一个假设,Q_len等于KV_len,同时两者也等于prompt_len。但是在Prefix Caching的背景下,这个假设就不成立了。因为当前请求的prompt中,会有部分被缓存的KV Cache命中,不需要重复计算,也就是说,Q_len
Q_len
,我们需要一个新的kernel来处理这种情况。vLLM中也正是这样处理的,目前prefix prefill kernel的实现在vllm/attention/ops/prefix_prefill.py。如果使用了prefix caching,则会走到这里实现的triton based prefix prefill kernel。
prefix prefill kernel
prefix prefill kernel相关的也有不少细节,这里先简单说明它解决的问题。接下来会继续分析它的源码。
0x03 vLLM Prefix Prefill Kernel: 先说Tiling分块策略
自顶向下,为了更好理解Prefix Prefill Kernel的含义,我们先来看看它的Tiling策略和Block的布局。这部分代码如下。其中
BLOCK_M
: 表示最内层
Q的seq_len行方向并行
,每个block处理BLOCK_M个tokens;
BLOCK_N
: 表示最内层
KV的seq_len列方向并行
,每个block处理BLOCK_N个tokens;
BLOCK_MxBLOCK_N(128x128)
作为一个最小的Tile进行处理;max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是1024/2048/4096等;在目前的vLLM Prefix Prefill Kernel中,有
BLOCK_M=BLOCK_M=BLOCK=128
。
# 其他的代码逻辑先省略,现在只关注Tiling和Block的布局 cap = torch. cuda. get_device_capability() BLOCK = 128 if cap[0 ] >= 8 else 64 # shape constraints Lq, Lk, Lv = q. shape[- 1 ], k. shape[- 1 ], v. shape[- 1 ] sm_scale = 1.0 / (Lq** 0.5 ) batch, head = b_seq_len. shape[0 ], q. shape[1 ] num_queries_per_kv = q. shape[1 ] // k. shape[1 ] # BLOCK_M: 表示Q的seq_len行方向并行,每个block处理BLOCK_M个tokens # BLOCK_N: 表示K的seq_len列方向并行,每个block处理BLOCK_N个tokens # max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是256/512/1024/2048/...等 # BLOCK_M=BLOCK_M=BLOCK=128 grid = (batch, head, triton. cdiv(max_input_len, BLOCK)) # batch, head,
我们可以看到,Prefix Prefill Kernel中,分块的布局采用的是
[batch, heads, max_input_len/BLOCK]
。熟悉FlashAttention V2的同学会马上反应过来,这其实就是FlashAttention V2中的Tiling逻辑。只是在Prefix Prefill Kernel中,还需要额外考虑处理被Prefix Cache命中的KV Cache。Prefix Cache和没有被Prefix Cache命中的New Tokens是分开处理的。以batch_size=8,heads=8,max_input_len=1024,BLOCK_SIZE(=BLOCK)=128为例,其对应的分块策略和Block布局如下(vLLM Prefix Prefill Kernel Tiling)。
vLLM Prefix Prefill Kernel Tiling
[batch, heads, max_input_len/BLOCK],我们可以看到最内层的max_input_len/BLOCK,表示处理一个head上的Attention计算。比如max_input_len=1024, BLOCK=128时,最内层有8=1024/128个program(也就是Thread Block)来负责这个Head的Attention计算,其中又有BLOCK_M=BLOCK_N=BLOCK=128,表示,每个Thread Block处理这个Head的BLOCK_M个New Query Tokens的Attention,并且对于KV按照BLOCK_N=128的块大小进行迭代计算
FlashAttention
,最后,一个Thread Block会得到[BLOCK_M, D]大小的Attention输出O(D表示head size,比如64, 128等);由于有大量的Query Tokens对应的KV Cache被Prefix Cache命中,因此,这些tokens的Attention计算可以直接skip掉,从而节省了大量的计算,只需要对没被Prefix Cache命中的New Tokens进行Attention计算即可。
FlashAttention的原理解析,推荐阅读我写的另一篇文章:
DefTruth:[Attention优化][2w字] 原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3
https://zhuanlan.zhihu.com/p/668888063
我们知道Prefill阶段的目的有两个:(1)产生Prompt Tokens的KV Cache;(2)生成首Token;通常,我们会用TTFT(Time To First Token)来评估Prefill的耗时。整合一下Prefix Cache到Prefill的流程,大概长这样;紫色部分表示会被Prefix Cache命中的Tokens,这部分的KV Cache直接使用Prefix Cache中保存的即可。绿色部分为当前输入的Prompt中没被Prefix Cache命中的New Tokens,这部分New Tokens需要计算KV Cache;在获得Prompt中所有Tokens的KV Cache之后,就可以生成首Token了。
Prefill with Prefix Cache
0x04 vLLM Prefix Prefill Kernel: 再看Kernel调用
理解一个Kernel最基本的方法,就是先看下它的接口,以及它是怎么使用的。context_attention_fwd是_fwd_kernel的封装函数,需要传入的参数如下:
def context_attention_fwd (q, # new tokens对应的query Tensor, 没有被prefix cache命中的部分 k, # new tokens对应的keys Tensor v, # new tokens对应的values Tensor
o, # Attention输出 k_cache, # prefix cache命中的tokens对应的keys v_cache, # prefix cache命中的tokens对应的values b_loc, # new tokens对应的block_table b_start_loc, # new tokens len的前缀和cumsum(query_lens) b_seq_len, # 实际的seq_len=new_tokens len + b_ctx_len b_ctx_len, # 命中了prefix cache的token数 max_input_len, # 最大seq_len限制,比如1024、4096等 alibi_slopes= None , sliding_window= None ):
不再单独说明了,直接看代码中的注释吧。prefix prefill kernel的单测代码在test_prefix_prefill.py(
https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_prefix_prefill.py)
,部分代码以及注释如下:
import torch , random from vllm.attention.ops.prefix_prefill import context_attention_fwddef test_contexted_kv_attention ( num_heads: int , num_queries_per_kv: int , head_size: int , sliding_window: int , dtype: torch. dtype, device: str ,) -> None : random. seed(0 ) torch. manual_seed(0 ) if torch. cuda. is_available(): torch. cuda. manual_seed(0 ) torch. set_default_device(device) torch. cuda. set_device(device) MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 cache_size = 640 # 表示有640个cache blocks block_size = 32 # 每个block保存32个tokens的KV Cache max_block_per_request = 64 # 每个request最多使用多少个cache blocks # query_lens表示没被prefix cache命中的token query_lens = [random. randint(16 , MAX_SEQ_LEN) for _ in range (BS)] # ctx_lens被prefix cache命中的token ctx_lens = [random. randint(16 , MAX_CTX_LEN) for _ in range (BS)] # 当前seq_len(即prompt_len)=ctx_len+query_len seq_lens = [a + b for a, b in zip (query_lens, ctx_lens)] # MQA/GQA处理num_queries_per_kv表示一个kv head对应多少个query head # num_queries_per_kv=1为MHA,num_queries_per_kv>1表示MQA/GQA num_kv_heads = num_heads // num_queries_per_kv # Batch内New Tokens总和,prefix prefill kernel是假设一个BS的seq拼接在一起作为输入的,不是padding。 num_tokens = sum (query_lens) query = torch. empty(num_tokens, num_heads, head_size, dtype= dtype) query. uniform_(- 1e-3 , 1e-3 ) # Attention输出的tokens数和query tensor中使用的tokens数相同 output = torch. empty(num_tokens, num_heads, head_size, dtype= dtype) # 总共需要的KV Cache,是按照输入seq_len计算的。无论是否在Prefix Cache中,都是需要显存来存放的。 kv = torch. empty(sum (seq_lens), 2 , num_kv_heads, head_size, dtype= dtype) kv. uniform_(- 1e-3 , 1e-3 ) key, value = kv. unbind(dim= 1 ) # 这里表示的模拟Prefix Cache中的K Cache k_cache = torch. zeros(cache_size, block_size, num_kv_heads, head_size, dtype= dtype) # 这里表示的模拟Prefix Cache中的V Cache v_cache = torch. zeros(cache_size, block_size, num_kv_heads, head_size, dtype= dtype) # 这里表示模拟没被Prefix Cache命中的New Tokens的KV Cache,使用的是query_lens。 k = torch. zeros(sum (query_lens), num_kv_heads, head_size, dtype= dtype) v = torch. zeros(sum (query_lens), num_kv_heads, head_size, dtype= dtype) # 将640个block id随机打乱,并给BS中的每个seq选择一部分作为block ids,构建block_table values = torch. arange(0 , cache_size, dtype= torch. long) values = values[torch. randperm(cache_size)] block_table = values[:BS * max_block_per_request]. view( BS, max_block_per_request) # 记录batch中每个seq的长度 b_seq_len = torch. tensor(seq_lens, dtype= torch. long) # 记录batch中每个seq中命中了prefix cache的token数 b_ctx_len = torch. tensor(ctx_lens, dtype= torch. long) # b_start_loc表示query tensor的start_loc,是query_lens的前缀和,因为query tensor是bs中所有new tokens # 拼接在一起作为输入的,在kernel中我们需要知道每个seq的new tokens对应的start位置 b_start_loc = torch. cumsum(torch. tensor([0 ] + query_lens[:- 1 ], dtype= torch. long), dim= 0 ) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch. cumsum(torch. tensor([0 ] + seq_lens[:- 1 ], dtype= torch. long), dim= 0 ) # 将key, values Tensor中全量的cache,按照BS中每个seq的block_table,拷贝到 # Prefix cache对应的k_cache和v_cache中;以及将New Tokens对应的kv cache # 拷贝到k和v Tensor中。 for i in range (BS): for j in range (query_lens[i]): k[b_start_loc[i] + j]. copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j]. copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: start_loc = b_seq_start_loc[i] + cur_ctx if cur_ctx + block_size > b_ctx_len[i]: end_loc = b_seq_start_loc[i] + b_ctx_len[i] else : end_loc = start_loc + block_size start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc k_cache. view(- 1 , num_kv_heads, head_size)[start_slot:end_slot]. copy_( key[start_loc:end_loc]) v_cache. view(- 1 , num_kv_heads, head_size)[start_slot:end_slot]. copy_( value[start_loc:end_loc]) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] # k_cache的MQA/GQA处理,以及按照8对tensor进行重新布局(估计和提升kernel IO access效率有关) k_cache = k_cache. view(- 1 , block_size, num_kv_heads, head_size // 8 , 8 ). permute(0 , 2 , 3 , 1 , 4 ). contiguous() # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size] # v_cache的MQA/GQA处理 v_cache = v_cache. view(- 1 , block_size, num_kv_heads, head_size). permute(0 , 2 , 3 , 1 ). contiguous() # 调用prefix prefill triton kernel context_attention_fwd(query, # new tokens对应的query Tensor, 没有被prefix cache命中的部分 k, # new tokens对应的keys Tensor v, # new tokens对应的values Tensor output, # Attention输出 k_cache, # prefix cache命中的tokens对应的keys v_cache, # prefix cache命中的tokens对应的values block_table, # new tokens对应的block_table b_start_loc, # new tokens len的前缀和cumsum(query_lens) b_seq_len, # 实际的seq_len=new_tokens len + b_ctx_len b_ctx_len, # 命中了prefix cache的token数 max_input_len, # 最大seq_len限制,比如1024、4096等 sliding_window= sliding_window) torch. cuda. synchronize()
在源码编译安装完vllm后,可以直接跑一下这个单测:
cd tests/kernels && pytest -v -s test_prefix_prefill.py
0x05 vLLM Prefix Prefill Kernel: 如何确认有多少个Token被Prefix Cache命中?
Anyway,如果你正在使用Prefix Caching这个功能,一定对它到底能命中多少个tokens感兴趣,因为这个信息对于我们debug和调试是比较重要的。目前vllm的logger日志中还没有提供这样的信息,但是我们可以在context_attention_fwd中加一行代码就可以看到这个信息,先作为一种临时的方案吧。其中b_seq_len表示输入的tokens数,b_ctx_len表示命中prefix cache的tokens数。
prefix cache debug
0x06 vLLM Prefix Prefill Kernel: 通用Head Sizes支持
Triton based的prefix prefill kernel,一开始有non-power-of-2的限制,也就是说,只支持head size(或者head dim)为2的次方,比如64、128等,但是对于80、96等不是2的次方数的head size是不支持的。不过有几位老哥提了PR对head size为non-power-of-2的情况进行了支持,也包括含有alibi的情况。因此,本小节就先解析一下prefix prefill kernel中对于通用head size的支持。
支持non-power-of-2的通用head size,主要还是通过增加对head size的mask来处理的。一个很直观的想法就是,对于80、96,我们可以考虑将其round up到2的次方数,比如128;然后在实际使用的时候,我们可以根据实际的head size,构造load/store mask,防止内存越界即可。这段逻辑如下:
使用dim_mask
比如,对于query tensor的load逻辑,则需要通过以下代码构造mask:
# [D]; starts at 0 表示head size,这里使用BLOCK_DMODEL_PADDED 为 2的次方 offs_d = tl. arange(0 , BLOCK_DMODEL_PADDED) # [M]; starts at current position in query # offs_m仅表示当前的query,一个batch有多个query offs_m = start_m * BLOCK_M + tl. arange(0 , BLOCK_M) # offs_m与off_d代表当前program要处理的块index,我们还需要将这个index转换成数据指针的位置 # [M,D] 获取全局的数据指针 off_q = ( (cur_batch_in_all_start_index + offs_m[:, None ]) * stride_qbs + cur_head * stride_qh + offs_d[None , :] * stride_qd) # offs_m[:, None]) [BLOCK_M, 1] offs_d[None, :] [1, D] # broadcast -> [M,D]? cur_head已经考虑在内 # [D] mask处理,因为pad成2的次方了。而实际可能不是2的次方,比如80<128 dim_mask = tl. where( tl. arange(0 , BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1 , 0 ). to(tl. int1) # Q看做全局数据指针,off_q看做全局索引 Q:(num_tokens, heads, head_size) # dim_mask[None, :] [1,D] -> 对D做mask,并且broadcast为[M,D]; # cur_batch_seq_len - cur_batch_ctx_len表示实际query_len,超过的不用考虑计算 # offs_m[:, None] [M, 1] # dim_mask[None, :] & offs_m[:, None] -> [1,D] & [M, 1] # >>> dim_mask = torch.tensor([1,1,1,1,0,0]) # [1,6] # >>> offs_m = torch.tensor([8,9,10,11]) # [4,1] # >>> mask=dim_mask[None, :] & (offs_m[:, None] < 12) # >>> mask # tensor([[1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 0, 0], # [1, 1, 1, 1, 0, 0]]) # >>> mask.shape # torch.Size([4, 6]) # [M,D] q = tl. load(Q + off_q, # [M,D] mask= dim_mask[None , :] & (offs_m[:, None ] < cur_batch_seq_len - cur_batch_ctx_len), other= 0.0 )
我们可以看到,dim_mask结合seq_len维度的mask(
offs_m[:,None]< cur_batch_seq_len - cur_batch_ctx_len)
)就可以得到query tensor最终需要的mask,以确保内存访问不会越界。对于,k和v tensor的逻辑也类似。最终,我们得到了的mask长这样:
mask for query tensor
0x07 vLLM Prefix Prefill Kernel: MQA/GQA支持
接下来再看下prefix prefill kernel对于MQA/GQA的支持。逻辑其实很简单,对于MQA/GQA的情况,只需要计算清楚当前的query head需要用到的kv head的索引即可。这段代码逻辑为:
cur_batch = tl. program_id(0 ) # batch中的不同query放在不同的block进行组织 cur_head = tl. program_id(1 ) # 当前query中的不同head也在不同的block进行组织 # ..... # MQA/GQA处理:比如num_queries_per_kv=8,表示GQA,8个query head共享一个kv head,那么在kernel # 计算中实际使用到的kv head,则需要用当前query head的索引,除以8,比如query head的索引为14,则对应 # 需要使用的kv head的索引为14//8=1(索引从0开始) cur_kv_head = cur_head // num_queries_per_kv
0x08 vLLM Prefix Prefill Kernel: Triton Kernel解析
最后,贴上带注释的_fwd_kernel_alibi源码。理解了前边所有的技术点后,接下来的这段代码应该很好理解了,不再赘述了。(后续有更新再继续补充注释)
@triton.jit def _fwd_kernel_alibi ( Q, # new tokens对应的query Tensor K, # new tokens对应的keys Tensor V, # new tokens对应的values Tensor K_cache, # K prefix cache, 是已经准备好的Tensor V_cache, # V prefix cache, 是已经准备好的Tensor B_Loc, # block table sm_scale, # scale factor B_Start_Loc, # new tokens len的前缀和cumsum(query_lens) B_Seqlen, # 实际的seq_len=new_tokens len + b_ctx_len B_Ctxlen, # 命中了prefix cache的token数 Alibi_slopes, block_size, # block size,比如32 x, # 8, 对应k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8).permute(0, 2, 3, 1, 4).contiguous() Out, # attention输出 stride_b_loc_b, # 各种tensor不同维度上的stride stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv: int , # MQA/GQA 一个kv head对应query head数量 BLOCK_M: tl. constexpr, # triton常量 此处BLOCK_M=BLOCK=128,表示Q的seq_len方向的分块大小 BLOCK_DMODEL: tl. constexpr, # head size 表示实际的head size BLOCK_DMODEL_PADDED: tl. constexpr, # head size padded to a power of 2 BLOCK_N: tl. constexpr, # triton常量 此处BLOCK_N=BLOCK=128,表示KV的seq_len方向的分块大小 ): # attn_bias[] cur_batch = tl. program_id(0 ) # batch中的不同query放在不同的block进行组织 cur_head = tl. program_id(1 ) # 当前query中的不同head也在不同的block进行组织 # BLOCK_M: 表示Q的seq_len行方向并行,每个block处理BLOCK_M个tokens # BLOCK_N: 表示K的seq_len列方向并行,每个block处理BLOCK_N个tokens # max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是256/512/1024/2048/...等 # BLOCK_M=BLOCK_M=BLOCK=128 # start_m: 当前属于哪一个block,其实就是grid最内层的block id. start_m = tl. program_id(2 ) # MQA/GQA处理:比如num_queries_per_kv=8,表示GQA,8个query head共享一个kv head,那么在kernel # 计算中实际使用到的kv head,则需要用当前query head的索引,除以8,比如query head的索引为14,则对应 # 需要使用的kv head的索引为14//8=1(索引从0开始) cur_kv_head = cur_head // num_queries_per_kv # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 cur_batch_ctx_len = tl. load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl. load(B_Seqlen + cur_batch) # 拿到当前样本在batch中对应的start location,因为输入不是pad的,而是拼接在一起的 # 需要记录当前的样本在batch中真正的start location # B_Start_Loc: 前缀和 torch.cumsum(torch.tensor([0] + query_lens[:-1])) cur_batch_in_all_start_index = tl. load(B_Start_Loc + cur_batch) # start position inside of the query # generally, N goes over kv, while M goes over query_len # block_start_loc: 当前block需要处理的Q的开始索引 block_start_loc = BLOCK_M * start_m # initialize offsets # [N]; starts at 0 kv的分块偏移量 offs_n = tl. arange(0 , BLOCK_N) # [D]; starts at 0 表示head size,这里使用BLOCK_DMODEL_PADDED 为 2的次方 offs_d = tl. arange(0