专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
51好读  ›  专栏  ›  GiantPandaCV

[Prefill优化]图解vLLM Prefix Prefill Triton Kernel

GiantPandaCV  · 公众号  ·  · 2024-06-14 10:31

正文



作者丨DefTruth
来源丨https://zhuanlan.zhihu.com/p/695799736
编辑丨GiantPandaCV


0x00 前言

在上一篇Prefill优化的文章中,已经详细讲解了vLLM Automatic Prefix Caching(Hash RadixAttention)的原理和Cache调度的实现,包括SGLang RadixAttention原理,并且结合图解和代码,详细分析了vLLM中的Hash RadixAttention实现。vLLM中的Hash RadixAttention内容包括:Hash RadixAttention、Hash Prefix Tree、Prefix/Generate 阶段Hash码处理、Prefix + Generated KV Caching的调度逻辑、边界情况思考、vLLM Automatic Prefix Caching在多轮对话中的应用分析以及代码应用实践。本篇,继续深入,讲解Automatic Prefix Caching中用到的Triton Based Prefix Prefill Kernel。 推荐先阅读完上一篇的Automatic Prefix Caching原理,再来阅读本篇的kernel解读。

DefTruth:[Prefill优化][万字] 原理&图解vLLM Automatic Prefix Cache(RadixAttention): 首Token时延优化 https://zhuanlan.zhihu.com/p/693556044

本文包含内容如下。( 提示:需要先阅读完vLLM Automatic Prefix Cache作为基础

  • 0x01 OpenAI Triton: Triton Kernel编程极简入门

  • 0x02 vLLM Prefix Prefill Kernel: Prefix Prefill Kernel与Attention Kernel区别

  • 0x03 vLLM Prefix Prefill Kernel: 先说Tiling分块策略

  • 0x04 vLLM Prefix Prefill Kernel: 再看Kernel调用

  • 0x05 vLLM Prefix Prefill Kernel: 如何确认有多少个Token被Prefix Cache命中?

  • 0x06 vLLM Prefix Prefill Kernel: 通用Head Sizes支持

  • 0x07 vLLM Prefix Prefill Kernel: MQA/GQA支持

  • 0x08 vLLM Prefix Prefill Kernel: Triton Kernel解析

  • 0x09 总结

0x01 OpenAI Triton: Triton Kernel编程极简入门

关于OpenAI Triton,这里只做简单的介绍。网上可以找到大量入门的文章,本文也先不重复了。本文主要关注Prefix Prefill Kernel的实现,而非深挖Triton的底层的原理。

GPU基础架构

传统的基于 CUDA 进行 GPU 编程难度较大,在优化 CUDA 代码时,必须考虑到数据流在DRAM、SRAM 和 ALU之间的Load/Store的问题,还需要仔细考虑到Grid、Block、Thread和Warp等不同级别的调度优化问题。这些问题包括但不限于:

1. 从 DRAM 的内存传输必须合并成大型事务,以利用现代内存接口的大总线宽度(内存合并访问)。
2. 数据必须在重复使用前手动存储到 SRAM 中,并进行管理来最小化bank conflict。
3. 计算必须仔细地进行划分和调度,不仅是在流式多处理器(SMs)之间,还包括在其内部,以促进指令/线程级并行性,并利用专用的 ALU(例如,Tensor Cores)。

因此,哪怕是CUDA熟练工,也得花费不少的精力,才能写出一个性能接近理论峰值的Kernel。(比如像我这种菜菜的,基本上已经放弃手撸CUDA Kernel了)。Triton 的出现,降低了CUDA Kernel编写的难度,它将一些需要精心设计的优化策略进行自动化,比如内存事务合并、SRAM分配和管理、流水线优化等,从而使得编程人员可以将更多的精力放在算法本身。

Triton Complier 编译优化

从官方放出的这个表格中,我们可以看到,如果使用Triton,内存事务合并、SRAM管理以及SM内的线程调度都是自动进行的,我们只需要把精力花在SM之间管理即可,这也就是说, Triton的编程粒度是Block (每个Block只会被调度到一个SM上),而不是Thread。我们只需要考虑每个Block需要做什么,至于Thread/Warp的分布和调度,Triton自动给我们处理了。那么,Block这个概念,在Triton中通过什么进行表达呢?答案是: program

Triton Block-wise 编程模型

block -> program,在Triton中,使用 program_id 来标识一个唯一的program。编程人员只需要考虑一个program(block)内的编程逻辑,比如这个最简单的add_kernel。 x_ptr , y_ptr , 和 output_ptr 分别是指向第一个输入向量、第二个输入向量和输出向量的指针。这些向量存储在 GPU 的内存中。比较常见的就是PyTorch和Triton一起使用,Triton将会传入的Tensor当成指针来处理,而非数据张量。 BLOCK_SIZE: tl.constexpr 表示一个triton的编译时常量,表示每个 block需要处理的元素数量。 mask = offsets < n_elements 表示创建一个mask以防止内存操作超出范围。tl.load和tl.store分表表示triton中的数据加载和写入的操作,这也是需要注意的,Triton为了能更好地进行性能优化,它是在指针级别上做操作的,而非数据Tensor级别。

  • Triton Based Kernel

import tritonimport triton.language as tl@triton.jitdef add_kernel(x_ptr,  # *Pointer* to first input vector.
              y_ptr,  # *Pointer* to second input vector.
              output_ptr,  # *Pointer* to output vector.
              n_elements,  # Size of the vector.
              BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
              # NOTE: `constexpr` so it can be used as a shape value.
              ):
   # There are multiple 'programs' processing different data. We identify which program
   # we are here:
   #  有多个'程序'(也就是block)处理不同的数据。我们在这里标识我们是哪个程序:
   pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
   # This program will process inputs that are offset from the initial data.
   # For instance, if you had a vector of length 256 and block_size of 64, the programs
   # would each access the elements [0:64, 64:128, 128:192, 192:256].
   # Note that offsets is a list of pointers:
   # 该程序将处理与初始数据偏移的输入。
   # 例如,如果您有长度为 256 的向量和块大小为 64,程序
   # 将分别访问元素[0:64, 64:128, 128:192, 192:256]。
   # 请注意,偏移量是指针的列表:
   block_start = pid * BLOCK_SIZE
   offsets = block_start + tl.arange(0, BLOCK_SIZE)
   # Create a mask to guard memory operations against out-of-bounds accesses.
   # 创建一个mask以防止内存操作超出范围。
   mask = offsets < n_elements
   # Load x and y from DRAM, masking out any extra elements in case the input is not a
   # multiple of the block size.
   x = tl.load(x_ptr + offsets, mask=mask)
   y = tl.load(y_ptr + offsets, mask=mask)
   output = x + y
   # Write x + y back to DRAM.
   tl.store(output_ptr + offsets, output, mask=mask)
  • PyTorch调用 (提示:Triton将会传入的Tensor当成指针来处理,而非数据张量)

def add(x: torch.Tensor, y: torch.Tensor):
   # 我们需要预先分配输出。
   output = torch.empty_like(x)
   assert x.is_cuda and y.is_cuda and output.is_cuda
   n_elements = output.numel()
   # SPMD启动网格表示并行运行的内核实例数。
   # 它类似于CUDA启动网格。对于add_kernel我们使用一个1D网格,其大小是块的数量:
   grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
   # 注意:
   #  - 每个torch.tensor对象都隐式地转换为指向其第一个元素的指针。
   #  - `triton.jit`'ed函数可以通过一个启动网格索引来获得一个可调用的GPU内核。
   #  - 不要忘记将元参数作为关键字参数传递。
   add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
   # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。
   return output

需要注意的是,Triton将会传入的Tensor当成指针来处理,而非数据张量。并且,由于Triton Kernel也是异步调用的,因此在测试性能的时候,需要在函数返回后添加 torch.cuda.synchronize() 。更详细的Triton 入门,推荐阅读:如何入门 OpenAI Triton 编程? 以及 科密中的科蜜:OpenAI Triton 入门教程,讲解地很详细,本文Triton部分内容参考自这两篇文章(侵删)。Anyway,现在只需要记住以下这点,即可继续往下阅读本文了。

1. Program相当于CUDA编程中的Block,program_id相当于block id。
2. CUDA的编程模型从grid-block-thread,被简化为 Block-wise ,kernel启动时,只需要考虑一个grid中block的布局。比如,grid=(M,N,D/BLOCK_K)表示这个gird是一个3D的block布局。

0x02 vLLM Prefix Prefill Kernel: Prefix Prefill Kernel与Attention Kernel区别

回到Prefix Caching这个话题。需要注意的是,在使用了Prefix Caching后,就无法使用常规的Attention kernel来计算Prefill阶段的注意力结果了。这是由于常规kernel都暗含着一个假设,Q_len等于KV_len,同时两者也等于prompt_len。但是在Prefix Caching的背景下,这个假设就不成立了。因为当前请求的prompt中,会有部分被缓存的KV Cache命中,不需要重复计算,也就是说,Q_len Q_len ,我们需要一个新的kernel来处理这种情况。vLLM中也正是这样处理的,目前prefix prefill kernel的实现在vllm/attention/ops/prefix_prefill.py。如果使用了prefix caching,则会走到这里实现的triton based prefix prefill kernel。

prefix prefill kernel

prefix prefill kernel相关的也有不少细节,这里先简单说明它解决的问题。接下来会继续分析它的源码。

0x03 vLLM Prefix Prefill Kernel: 先说Tiling分块策略

  • 图解Tiling分块策略

自顶向下,为了更好理解Prefix Prefill Kernel的含义,我们先来看看它的Tiling策略和Block的布局。这部分代码如下。其中 BLOCK_M : 表示最内层 Q的seq_len行方向并行 ,每个block处理BLOCK_M个tokens; BLOCK_N : 表示最内层 KV的seq_len列方向并行 ,每个block处理BLOCK_N个tokens; BLOCK_MxBLOCK_N(128x128) 作为一个最小的Tile进行处理;max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是1024/2048/4096等;在目前的vLLM Prefix Prefill Kernel中,有 BLOCK_M=BLOCK_M=BLOCK=128

        # 其他的代码逻辑先省略,现在只关注Tiling和Block的布局
       cap = torch.cuda.get_device_capability()
       BLOCK = 128 if cap[0] >= 8 else 64
       # shape constraints
       Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
       sm_scale = 1.0 / (Lq**0.5)
       batch, head = b_seq_len.shape[0], q.shape[1]
       num_queries_per_kv = q.shape[1] // k.shape[1]
       # BLOCK_M: 表示Q的seq_len行方向并行,每个block处理BLOCK_M个tokens
       # BLOCK_N: 表示K的seq_len列方向并行,每个block处理BLOCK_N个tokens
       # max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是256/512/1024/2048/...等
       # BLOCK_M=BLOCK_M=BLOCK=128
       grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

我们可以看到,Prefix Prefill Kernel中,分块的布局采用的是 [batch, heads, max_input_len/BLOCK] 。熟悉FlashAttention V2的同学会马上反应过来,这其实就是FlashAttention V2中的Tiling逻辑。只是在Prefix Prefill Kernel中,还需要额外考虑处理被Prefix Cache命中的KV Cache。Prefix Cache和没有被Prefix Cache命中的New Tokens是分开处理的。以batch_size=8,heads=8,max_input_len=1024,BLOCK_SIZE(=BLOCK)=128为例,其对应的分块策略和Block布局如下(vLLM Prefix Prefill Kernel Tiling)。

vLLM Prefix Prefill Kernel Tiling

[batch, heads, max_input_len/BLOCK],我们可以看到最内层的max_input_len/BLOCK,表示处理一个head上的Attention计算。比如max_input_len=1024, BLOCK=128时,最内层有8=1024/128个program(也就是Thread Block)来负责这个Head的Attention计算,其中又有BLOCK_M=BLOCK_N=BLOCK=128,表示,每个Thread Block处理这个Head的BLOCK_M个New Query Tokens的Attention,并且对于KV按照BLOCK_N=128的块大小进行迭代计算 FlashAttention ,最后,一个Thread Block会得到[BLOCK_M, D]大小的Attention输出O(D表示head size,比如64, 128等);由于有大量的Query Tokens对应的KV Cache被Prefix Cache命中,因此,这些tokens的Attention计算可以直接skip掉,从而节省了大量的计算,只需要对没被Prefix Cache命中的New Tokens进行Attention计算即可。 FlashAttention的原理解析,推荐阅读我写的另一篇文章:

DefTruth:[Attention优化][2w字] 原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3 https://zhuanlan.zhihu.com/p/668888063
  • Prefix Caching下的Prefill阶段

我们知道Prefill阶段的目的有两个:(1)产生Prompt Tokens的KV Cache;(2)生成首Token;通常,我们会用TTFT(Time To First Token)来评估Prefill的耗时。整合一下Prefix Cache到Prefill的流程,大概长这样;紫色部分表示会被Prefix Cache命中的Tokens,这部分的KV Cache直接使用Prefix Cache中保存的即可。绿色部分为当前输入的Prompt中没被Prefix Cache命中的New Tokens,这部分New Tokens需要计算KV Cache;在获得Prompt中所有Tokens的KV Cache之后,就可以生成首Token了。

Prefill with Prefix Cache

0x04 vLLM Prefix Prefill Kernel: 再看Kernel调用

理解一个Kernel最基本的方法,就是先看下它的接口,以及它是怎么使用的。context_attention_fwd是_fwd_kernel的封装函数,需要传入的参数如下:

    def context_attention_fwd(q, # new tokens对应的query Tensor, 没有被prefix cache命中的部分
                             k, # new tokens对应的keys Tensor
                             v, # new tokens对应的values Tensor
                             o, # Attention输出
                             k_cache, # prefix cache命中的tokens对应的keys
                             v_cache, # prefix cache命中的tokens对应的values
                             b_loc,   # new tokens对应的block_table
                             b_start_loc,   # new tokens len的前缀和cumsum(query_lens)
                             b_seq_len,     # 实际的seq_len=new_tokens len + b_ctx_len
                             b_ctx_len,     # 命中了prefix cache的token数
                             max_input_len, # 最大seq_len限制,比如1024、4096等
                             alibi_slopes=None,
                             sliding_window=None):

不再单独说明了,直接看代码中的注释吧。prefix prefill kernel的单测代码在test_prefix_prefill.py( https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_prefix_prefill.py) ,部分代码以及注释如下:

import torch, randomfrom vllm.attention.ops.prefix_prefill import context_attention_fwddef test_contexted_kv_attention(
   num_heads: int,
   num_queries_per_kv: int,
   head_size: int,
   sliding_window: int,
   dtype: torch.dtype,
   device: str,) -> None:
   random.seed(0)
   torch.manual_seed(0)
   if torch.cuda.is_available():
       torch.cuda.manual_seed(0)
   torch.set_default_device(device)
   torch.cuda.set_device(device)

   MAX_SEQ_LEN = 1024
   MAX_CTX_LEN = 1024
   BS = 10
   cache_size = 640 # 表示有640个cache blocks
   block_size = 32  # 每个block保存32个tokens的KV Cache
   max_block_per_request = 64 # 每个request最多使用多少个cache blocks
   # query_lens表示没被prefix cache命中的token
   query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
   # ctx_lens被prefix cache命中的token
   ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
   # 当前seq_len(即prompt_len)=ctx_len+query_len
   seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
   # MQA/GQA处理num_queries_per_kv表示一个kv head对应多少个query head
   # num_queries_per_kv=1为MHA,num_queries_per_kv>1表示MQA/GQA
   num_kv_heads = num_heads // num_queries_per_kv
   # Batch内New Tokens总和,prefix prefill kernel是假设一个BS的seq拼接在一起作为输入的,不是padding。
   num_tokens = sum(query_lens)
   query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
   query.uniform_(-1e-3, 1e-3)
   # Attention输出的tokens数和query tensor中使用的tokens数相同
   output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
   # 总共需要的KV Cache,是按照输入seq_len计算的。无论是否在Prefix Cache中,都是需要显存来存放的。
   kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
   kv.uniform_(-1e-3, 1e-3)
   key, value = kv.unbind(dim=1)
   # 这里表示的模拟Prefix Cache中的K Cache
   k_cache = torch.zeros(cache_size,
                         block_size,
                         num_kv_heads,
                         head_size,
                         dtype=dtype)
   # 这里表示的模拟Prefix Cache中的V Cache
   v_cache = torch.zeros(cache_size,
                         block_size,
                         num_kv_heads,
                         head_size,
                         dtype=dtype)
   # 这里表示模拟没被Prefix Cache命中的New Tokens的KV Cache,使用的是query_lens。
   k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
   v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
   # 将640个block id随机打乱,并给BS中的每个seq选择一部分作为block ids,构建block_table
   values = torch.arange(0, cache_size, dtype=torch.long)
   values = values[torch.randperm(cache_size)]
   block_table = values[:BS * max_block_per_request].view(
       BS, max_block_per_request)
   # 记录batch中每个seq的长度
   b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
   # 记录batch中每个seq中命中了prefix cache的token数
   b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
   # b_start_loc表示query tensor的start_loc,是query_lens的前缀和,因为query tensor是bs中所有new tokens
   # 拼接在一起作为输入的,在kernel中我们需要知道每个seq的new tokens对应的start位置
   b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                           dtype=torch.long),
                              dim=0)
   max_input_len = MAX_SEQ_LEN
   # copy kv to cache
   b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                               dtype=torch.long),
                                  dim=0)
   # 将key, values Tensor中全量的cache,按照BS中每个seq的block_table,拷贝到
   # Prefix cache对应的k_cache和v_cache中;以及将New Tokens对应的kv cache
   # 拷贝到k和v Tensor中。
   for i in range(BS):
       for j in range(query_lens[i]):
           k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
                                           j])
           v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
                                             b_ctx_len[i] + j])
       cur_ctx = 0
       block_id = 0
       while cur_ctx < b_ctx_len[i]:
           start_loc = b_seq_start_loc[i] + cur_ctx
           if cur_ctx + block_size > b_ctx_len[i]:
               end_loc = b_seq_start_loc[i] + b_ctx_len[i]
           else:
               end_loc = start_loc + block_size
           start_slot = block_table[i, block_id] * block_size
           end_slot = start_slot + end_loc - start_loc
           k_cache.view(-1, num_kv_heads,
                        head_size)[start_slot:end_slot].copy_(
                            key[start_loc:end_loc])
           v_cache.view(-1, num_kv_heads,
                        head_size)[start_slot:end_slot].copy_(
                            value[start_loc:end_loc])
           cur_ctx += block_size
           block_id += 1
   # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
   # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
   # k_cache的MQA/GQA处理,以及按照8对tensor进行重新布局(估计和提升kernel IO access效率有关)
   k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                          8).permute(0, 2, 3, 1, 4).contiguous()
   # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
   # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
   # v_cache的MQA/GQA处理
   v_cache = v_cache.view(-1, block_size, num_kv_heads,
                          head_size).permute(0, 2, 3, 1).contiguous()
   # 调用prefix prefill triton kernel
   context_attention_fwd(query, # new tokens对应的query Tensor, 没有被prefix cache命中的部分
                         k,     # new tokens对应的keys Tensor
                         v,     # new tokens对应的values Tensor
                         output,  # Attention输出
                         k_cache, # prefix cache命中的tokens对应的keys
                         v_cache, # prefix cache命中的tokens对应的values
                         block_table, # new tokens对应的block_table
                         b_start_loc,   # new tokens len的前缀和cumsum(query_lens)
                         b_seq_len,     # 实际的seq_len=new_tokens len + b_ctx_len
                         b_ctx_len,     # 命中了prefix cache的token数
                         max_input_len, # 最大seq_len限制,比如1024、4096等
                         sliding_window=sliding_window)
   torch.cuda.synchronize()

在源码编译安装完vllm后,可以直接跑一下这个单测:

cd tests/kernels && pytest -v -s test_prefix_prefill.py

0x05 vLLM Prefix Prefill Kernel: 如何确认有多少个Token被Prefix Cache命中?

Anyway,如果你正在使用Prefix Caching这个功能,一定对它到底能命中多少个tokens感兴趣,因为这个信息对于我们debug和调试是比较重要的。目前vllm的logger日志中还没有提供这样的信息,但是我们可以在context_attention_fwd中加一行代码就可以看到这个信息,先作为一种临时的方案吧。其中b_seq_len表示输入的tokens数,b_ctx_len表示命中prefix cache的tokens数。

prefix cache debug

0x06 vLLM Prefix Prefill Kernel: 通用Head Sizes支持

Triton based的prefix prefill kernel,一开始有non-power-of-2的限制,也就是说,只支持head size(或者head dim)为2的次方,比如64、128等,但是对于80、96等不是2的次方数的head size是不支持的。不过有几位老哥提了PR对head size为non-power-of-2的情况进行了支持,也包括含有alibi的情况。因此,本小节就先解析一下prefix prefill kernel中对于通用head size的支持。

  • load/store Mask

支持non-power-of-2的通用head size,主要还是通过增加对head size的mask来处理的。一个很直观的想法就是,对于80、96,我们可以考虑将其round up到2的次方数,比如128;然后在实际使用的时候,我们可以根据实际的head size,构造load/store mask,防止内存越界即可。这段逻辑如下:

使用dim_mask

比如,对于query tensor的load逻辑,则需要通过以下代码构造mask:

        # [D]; starts at 0 表示head size,这里使用BLOCK_DMODEL_PADDED 为 2的次方
       offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
       # [M]; starts at current position in query
       # offs_m仅表示当前的query,一个batch有多个query
       offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
       # offs_m与off_d代表当前program要处理的块index,我们还需要将这个index转换成数据指针的位置
       # [M,D] 获取全局的数据指针
       off_q = (
           (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
           cur_head * stride_qh + offs_d[None, :] * stride_qd)
       # offs_m[:, None]) [BLOCK_M, 1] offs_d[None, :] [1, D]
       # broadcast -> [M,D]? cur_head已经考虑在内
       # [D] mask处理,因为pad成2的次方了。而实际可能不是2的次方,比如80<128
       dim_mask = tl.where(
           tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
       
       # Q看做全局数据指针,off_q看做全局索引 Q:(num_tokens, heads, head_size)
       # dim_mask[None, :] [1,D] -> 对D做mask,并且broadcast为[M,D];
       # cur_batch_seq_len - cur_batch_ctx_len表示实际query_len,超过的不用考虑计算
       # offs_m[:, None] [M, 1]
       # dim_mask[None, :] & offs_m[:, None] -> [1,D] & [M, 1]
       # >>> dim_mask = torch.tensor([1,1,1,1,0,0]) # [1,6]
       # >>> offs_m = torch.tensor([8,9,10,11]) # [4,1]
       # >>> mask=dim_mask[None, :] & (offs_m[:, None] < 12)
       # >>> mask
       # tensor([[1, 1, 1, 1, 0, 0],
       #         [1, 1, 1, 1, 0, 0],
       #         [1, 1, 1, 1, 0, 0],
       #         [1, 1, 1, 1, 0, 0]])
       # >>> mask.shape
       # torch.Size([4, 6]) # [M,D]
       q = tl.load(Q + off_q, # [M,D]
                   mask=dim_mask[None, :] &
                   (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len),
                   other=0.0)

我们可以看到,dim_mask结合seq_len维度的mask( offs_m[:,None]< cur_batch_seq_len - cur_batch_ctx_len) )就可以得到query tensor最终需要的mask,以确保内存访问不会越界。对于,k和v tensor的逻辑也类似。最终,我们得到了的mask长这样:

mask for query tensor

0x07 vLLM Prefix Prefill Kernel: MQA/GQA支持

接下来再看下prefix prefill kernel对于MQA/GQA的支持。逻辑其实很简单,对于MQA/GQA的情况,只需要计算清楚当前的query head需要用到的kv head的索引即可。这段代码逻辑为:

        cur_batch = tl.program_id(0) # batch中的不同query放在不同的block进行组织
       cur_head = tl.program_id(1) # 当前query中的不同head也在不同的block进行组织
       # .....
       # MQA/GQA处理:比如num_queries_per_kv=8,表示GQA,8个query head共享一个kv head,那么在kernel
       # 计算中实际使用到的kv head,则需要用当前query head的索引,除以8,比如query head的索引为14,则对应
       # 需要使用的kv head的索引为14//8=1(索引从0开始)
       cur_kv_head = cur_head // num_queries_per_kv

0x08 vLLM Prefix Prefill Kernel: Triton Kernel解析

最后,贴上带注释的_fwd_kernel_alibi源码。理解了前边所有的技术点后,接下来的这段代码应该很好理解了,不再赘述了。(后续有更新再继续补充注释)

    @triton.jit
   def _fwd_kernel_alibi(
       Q, # new tokens对应的query Tensor
       K, # new tokens对应的keys Tensor
       V, # new tokens对应的values Tensor
       K_cache,  # K prefix cache, 是已经准备好的Tensor
       V_cache,  # V prefix cache, 是已经准备好的Tensor
       B_Loc,    # block table
       sm_scale, # scale factor
       B_Start_Loc, # new tokens len的前缀和cumsum(query_lens)
       B_Seqlen,    # 实际的seq_len=new_tokens len + b_ctx_len
       B_Ctxlen,    # 命中了prefix cache的token数
       Alibi_slopes,
       block_size,  # block size,比如32
       x,           # 8, 对应k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8).permute(0, 2, 3, 1, 4).contiguous()
       Out,         # attention输出
       stride_b_loc_b, # 各种tensor不同维度上的stride
       stride_b_loc_s,
       stride_qbs,
       stride_qh,
       stride_qd,
       stride_kbs,
       stride_kh,
       stride_kd,
       stride_vbs,
       stride_vh,
       stride_vd,
       stride_obs,
       stride_oh,
       stride_od,
       stride_k_cache_bs,
       stride_k_cache_h,
       stride_k_cache_d,
       stride_k_cache_bl,
       stride_k_cache_x,
       stride_v_cache_bs,
       stride_v_cache_h,
       stride_v_cache_d,
       stride_v_cache_bl,
       num_queries_per_kv: int,  # MQA/GQA 一个kv head对应query head数量
       BLOCK_M: tl.constexpr,    # triton常量 此处BLOCK_M=BLOCK=128,表示Q的seq_len方向的分块大小
       BLOCK_DMODEL: tl.constexpr,  # head size 表示实际的head size
       BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
       BLOCK_N: tl.constexpr,   # triton常量 此处BLOCK_N=BLOCK=128,表示KV的seq_len方向的分块大小
   ):
       # attn_bias[]
       cur_batch = tl.program_id(0) # batch中的不同query放在不同的block进行组织
       cur_head = tl.program_id(1) # 当前query中的不同head也在不同的block进行组织
       # BLOCK_M: 表示Q的seq_len行方向并行,每个block处理BLOCK_M个tokens
       # BLOCK_N: 表示K的seq_len列方向并行,每个block处理BLOCK_N个tokens
       # max_input_len/BLOCK: 表示最多需要多少个block,max_input_len由用户指定,可以是256/512/1024/2048/...等
       # BLOCK_M=BLOCK_M=BLOCK=128
       # start_m: 当前属于哪一个block,其实就是grid最内层的block id.
       start_m = tl.program_id(2)
       # MQA/GQA处理:比如num_queries_per_kv=8,表示GQA,8个query head共享一个kv head,那么在kernel
       # 计算中实际使用到的kv head,则需要用当前query head的索引,除以8,比如query head的索引为14,则对应
       # 需要使用的kv head的索引为14//8=1(索引从0开始)
       cur_kv_head = cur_head // num_queries_per_kv

       # cur_batch_seq_len: the length of prompts
       # cur_batch_ctx_len: the length of prefix
       # cur_batch_in_all_start_index: the start id of the dim=0
       cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
       cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
       # 拿到当前样本在batch中对应的start location,因为输入不是pad的,而是拼接在一起的
       # 需要记录当前的样本在batch中真正的start location
       # B_Start_Loc: 前缀和 torch.cumsum(torch.tensor([0] + query_lens[:-1]))
       cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
       
       # start position inside of the query
       # generally, N goes over kv, while M goes over query_len
       # block_start_loc: 当前block需要处理的Q的开始索引
       block_start_loc = BLOCK_M * start_m

       # initialize offsets
       # [N]; starts at 0 kv的分块偏移量
       offs_n = tl.arange(0, BLOCK_N)
       # [D]; starts at 0 表示head size,这里使用BLOCK_DMODEL_PADDED 为 2的次方
       offs_d = tl.arange(0






请到「今天看啥」查看全文