专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
最高人民法院  ·  Yo!听说代表法要修改啦? ·  4 天前  
51好读  ›  专栏  ›  GiantPandaCV

FasterTransformer Decoding 源码分析(四)-SelfAttention实现介绍

GiantPandaCV  · 公众号  ·  · 2024-05-08 22:04

主要观点总结

本文是对FasterTransformer Decoding源码中selfAttention部分的详细分析,涉及模块介绍、设计优化和源码分析。selfAttention是decoder中对inputs进行编码生成feature的模块,文中分析了其公式原理、设计策略及代码实现。同时,也探讨了其性能优化方法,如KV Cache、Cache Layout和Block/Thread设计等,并详细介绍了源码中相关函数的调用和具体实现。

关键观点总结

关键观点1: 模块介绍

SelfAttention模块位于DecoderLayer的第二个模块,输入为经过LayerNorm正则化后的结果,处理后进行残差连接再输入LayerNorm中。其功能可简单理解为对inputs进行编码生成feature,并与encoder产生的feature进行crossAttention生成最终的结果。

关键观点2: 设计&优化

FasterTransformer中针对SelfAttention模块设计了KV Cache、Cache Layout和Block/Thread设计等优化策略,这些策略实现了空间换时间的优化,提高了计算效率。

关键观点3: 源码分析

文中详细分析了源码中相关函数的调用和具体实现,包括主体框架代码、kernel函数调用及实现,解析了selfAttention公式的处理流程,并探讨了crossAttention和quant功能的代码实现。


正文



作者丨进击的Killua
来源丨https://zhuanlan.zhihu.com/p/669648527
编辑丨GiantPandaCV


本文是FasterTransformer Decoding源码分析的第四篇,也是该系列文章中最核心的一篇。笔者试图去分析selfAttention部分的代码实现和优化,内容较多也比较复杂,笔者会尽最大努力把原理阐述清楚。

一、模块介绍

如下图所示,SelfAttention模块位于DecoderLayer的第二个模块,输入为decoder inputs经过LayerNorm正则化后的结果,经过该模块处理后进行残差连接再输入LayerNorm中。SelfAttention可以简单理解为decoder中对 inputs进行编码生成feature的模块,在后面的流程中会和encoder产生的feature进行crossAttention生成最终的结果。

SelfAttention在decoder中的位置

那么SelfAttention模块本质上就是要实现如下几个公式,这里简单介绍下这几个公式,分别是:

  1. 线性化生成Q、K、V。

  2. 矩阵乘(Q*KT)得到attention Score 。

  3. 对attention Score进行softmax化得到logits。

  4. 使用logits和V进行点乘,再线性化获得最终输出。

SelfAttention 公式

对这几个公式不清楚的可以去看李宏毅老师的讲解视频,每个步骤做了非常详细的介绍。

二、设计&优化

我们先来看下FasterTransformer中针对SelfAttention这个模块设计了哪些优化策略,再来看代码是如何实现的。

1. KV Cache

我们知道在Decoder中解码是逐步进行的,先来看下原始的解码步骤,如下图和文字描述。

  1. step=1,输入= [s], 输出 = 我

  2. step=2,输入= [s] 我, 输出 = [s] 我 有

  3. step=3,输入= [s] 我 有, 输出 = [s] 我 有 猫

  4. step=4,输入= [s] 我 有 猫, 输出 = [s] 我 有 猫 [e]

Decoder 逐步解码过程

因此在逐步解码过程中,针对SelfAttention中Q,K,V矩阵的生成过程如下图所示。(这里仅展示了deocder中首层的SelfAttention,每层的计算逻辑类似)

原始的QKV生成过程

通过观察这个计算过程我们可以发现,每个step的结果中绝大部分的数值都被上个step计算得到过,如下图红框所示。

原始的QKV生成过程,大量重复计算

因此我们可以将每个step中计算过的结果缓存起来,在后续步骤中可以跳过这些内容的计算,只计算增量部分的内容。优化后的计算流程如下图所示,每个step只需计算本次新输入词的Q,K,V,大部分key和value结果均来自前面step计算结果的缓存,这就是KV cache的设计原理,也是经典的空间换时间的优化方法。该例子中仅包含3个step,实际解码过程可能包含上千个step,因此可以节省大量的矩阵计算量,代价就是需要更多的全局内存空间。

优化后的QKV生成过程

2. Cache Layout

根据第一点优化设计,有了Cache后每个step都需要Load Key Cache,FasterTransformer设计了高效的存储layout来支持多轮数据的读写,由于场景上是多读少写(每个key写入一次,需要在多个step中读取),所以设计的初衷是通过牺牲一些写入的效率来最大化读取效率。我们看下Key Cache的shape是:

[num_layer, batch, head_num, size_per_head // x, max_seq_len, x]

这里x是根据数据类型来确定的,比如处理的数据类型是FP32类型(4 bytes),则x=4,即保证最后一个维度的size是16 bytes。为什么要保证16 bytes?因为在很多情况下GPU 的全局内存对齐要求是 128 位(16 bytes),这样首先访问是对齐的,其次同个warp的多个相邻线程可以进行内存联合访问,这样就可以提高存储的访问速度。

num_layer,batch和head_num这前三个维度参数含义比较明确,在具体的核函数内部执行时其为固定值,现对后三个维度的使用和优化进行说明。假设 x=4, max_seq_len=6, size_per_head=8 ,即后三维是[2, 6, 4],如下图所示。对于一个warp中的线程:

写场景

其在生成了当前词的key后需要将其写入cache中,由thread1负责前16B写入,thead2负责后16B写入,由于中间隔了96B,所以无法做写入合并。

读场景

在需要读入cache中的历史key时,thread1首次循环负责读入第一个key的前16B,thread2首次循环负责读入第二个key的前16B,这两个读请求地址是连续的可以合并请求;同理第二轮循环thread1和thread2也可以合并请求,由此最大化了读取的效率,降低整体耗时。

key cache中 read/write过程

3. Block和Thread设计

我们再来从block和thread的视图来看下是如何实现selfAttention中的公式的。

block视图

每个block负责的运算任务如下图所示,其只负责一个词(即当前需要解码的新词)在一个head中一个step的kqv运算,输出output。

block计算视图

thread视图

具体到每个thread,它会负责该词query化后和 当前key + cache中的某些历史key进行kqv运算,最后在block维度上进行全局归约。

thread计算视图

三、源码分析

1. 方法入口

SelfAttention的调用入口如下,代码,解释下这里的输入和输出,具体逻辑在后面。

输入Tensor

  1. input_query:normalize之后的decoder_input,大小是[batch_size,hidden_units_]

  2. finished: 解码是否结束的标记,大小是[batch_size]

  3. sequence_lengths: 每个句子的长度,大小是[batch_size]

  4. step: 当前解码的步数

  5. cache_indirection(option):记录了解码到当前句子中每个词在前序步骤中的beam_index

输出Tensor

  1. hidden_features: SelfAttention的输出feature,大小是[batch_size,hidden_units_],和input_query大小一致。

  2. key_cache: SelfAttention中存储key的cache,用于后续step的计算。

  3. value_cache: SelfAttention中存储Value的cache,用于后续step的计算。

    // input tensors:    //      decoder_input [batch_size, hidden_dimension],    //      encoder_output [batch_size, mem_max_seq_len, memory_hidden_dimension],    //      encoder_sequence_length [batch_size],    //      finished [batch_size],    //      step [1] on cpu    //      sequence_lengths [batch_size]    //      cache_indirection [local_batch_size / beam_width, beam_width, max_seq_len]    // output tensors:    //      decoder_output [batch_size, hidden_dimension],    //      key_cache [num_layer, batch, head_num, size_per_head // x, max_seq_len, x]    //      value_cache [num_layer, batch, head_num, max_seq_len, size_per_head]    //      key_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension],    //      value_mem_cache [num_layer, batch_size, mem_max_seq_len, hidden_dimension] 
      TensorMap self_attention_input_tensors{
           {"input_query", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, decoder_normed_input_}},
           {"finished", input_tensors->at(3)},
           {"sequence_lengths", input_tensors->at(5)},
           {"step", input_tensors->at(4)}};        
       self_attention_input_tensors.insertIfValid("cache_indirection", input_tensors->at(6));


       TensorMap self_attention_output_tensors{
           {"hidden_features", Tensor{MEMORY_GPU, data_type, {batch_size, hidden_units_}, self_attn_output_}},
           {"key_cache",
            Tensor{MEMORY_GPU,
                   data_type,
                   std::vector<size_t>(output_tensors->at(1).shape.begin() + 1, output_tensors->at(1).shape.end()),
                   output_tensors->at(1).getPtrWithOffset(self_key_cache_offset)}},
           {"value_cache",
            Tensor{MEMORY_GPU,
                   data_type,
                   std::vector<size_t>(output_tensors->at(2).shape.begin() + 1, output_tensors->at(2).shape.end()),
                   output_tensors->at(2).getPtrWithOffset<T>(self_value_cache_offset)}}};


       self_attention_layer_->forward(&self_attention_output_tensors,
                                      &self_attention_input_tensors,
                                      &decoder_layer_weight->at(l).self_attention_weights);

2. 主体框架

主体框架代码由三部分构成,分别是该step的QKV生成、output生成和Linear输出,详见代码。其中第一部分和第三部分都使用了cublas的封装矩阵乘方法gemm,这里就不多介绍了,主要功能逻辑在第二部分output生成。

第一部分:QKV生成

公式里需要做三次乘法,这里直接用了一次矩阵乘就把QKV的结果都生成了,原理是将权重矩阵concat起来再做乘法,这样可以减少kernel调用次数,降低耗时。

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                 CUBLAS_OP_N,
                                 3 * local_hidden_units_,  // n                                  batch_size,
                                 d_model_,  // k                                  attention_weights->query_weight.kernel,
                                 3 * local_hidden_units_,  // n                                  attention_input,
                                 d_model_,  // k                                  qkv_buf_,
                                 3 * local_hidden_units_ /* n */);

第二部分:output生成

核心函数调用,这里参数较多不一一介绍了,非常多(像一些has_ia3等参数应该是在不断迭代的过程中加入的),在后面函数实现中会将重点参数进行阐述。

    fusedQKV_masked_attention_dispatch<T>(
       qkv_buf_,
       attention_weights->query_weight.bias,
       relative_attention_bias,
       key_cache,
       value_cache,
       cache_indir,
       context_buf_,
       finished,
       sequence_lengths,  // NOTE: current seq len including padding (fixed after meeting the finished id)        batch_size,
       batch_size,
       beam_width,
       local_head_num_,
       size_per_head_,
       rotary_embedding_dim_,
       neox_rotary_style_,
       memory_max_len,
       d_prefix_prompt_lengths,
       max_prefix_prompt_length,
       input_tensors->getVal<int>("max_input_length", 0),
       input_tensors->getPtr<int>("total_padding_tokens", nullptr),
       input_tensors->getVal<int>("step"),
       q_scaling_,
       relative_attention_bias_stride,
       linear_bias_slopes,
       masked_tokens,
       input_tensors->getPtr<int>("ia3_tasks", nullptr),
       has_ia3 ? attention_weights->ia3_key_weight.kernel : nullptr,
       has_ia3 ? attention_weights->ia3_value_weight.kernel : nullptr,
       int8_mode_ == 2 ? attention_weights->query_weight.scale_out : nullptr,
       int8_mode_ == 2 ? attention_weights->attention_output_weight.scale : nullptr,
       int8_mode_,
       stream_);

第三部分:Linear输出

这里就是简单地对上步输出结果乘以一个权重矩阵。

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                 CUBLAS_OP_N,
                                 d_model_,  // n                                  batch_size,
                                 local_hidden_units_,  // k                                  attention_weights->attention_output_weight.kernel,
                                 d_model_,  // n                                  context_buf_,
                                 local_hidden_units_,  // k                                  attention_out,
                                 d_model_ /* n */);

3. kernel函数调用

上述output生成步骤中会调用如下代码,这里针对每个head中需要处理的层数进行了分类,这个也是大量优化中的常用方案,针对不同的入参大小选择不同size和配置的kernel函数进行处理,这里有经验的一些成分在里面,我们常用的case是hidden_size_per_head=64(head=8)的情况。

template<typename T, typename KERNEL_PARAMS_TYPE>void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream){
   switch (params.hidden_size_per_head) {
       case 32:
           mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 48:
           mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 64:
           mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 80:
           mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 96:
           mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 112:
           mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 128:
           mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 144:
           mmha_launch_kernel<T, 144, 256, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 160:
           mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 192:
           mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 224:
           mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       case 256:
           mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
           break;
       default:
           assert(false);
   }}

接着终于进入到调用核函数的步骤,代码如下,根据是否提供了 cache_indir (即方法入口参数的cache_indirection)分成2类, cache_indir 为空可以认为没有beam_size=1,不为空则beam_size > 1,会对下游索引计算产生影响。函数调用参数主要会根据句子长度来确定单个block中需要的线程数,长度越大需要分配的线程数越多,但也是一些经验值。代码中通过 dim3 grid(params.num_heads, params.batch_size); 定义了grid,所以每个block只负责一个新词的kvq运算。

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream){
   constexpr int  THREADS_PER_VALUE  = threads_per_value_t<T, Dh_MAX>::value;
   constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
   int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
   if (params.cache_indir == nullptr) {
       if (tlength < 32) {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
       }
       else if (tlength < 2048) {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128 , DO_CROSS_ATTENTION, false, stream);
       }
       else {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
       }
   }
   else {
       if (tlength < 32) {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
       }
       else if (tlength < 2048) {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
       }
       else {
           MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
       }
   }}#define MMHA_LAUNCH_KERNEL(                                                                                            \    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream)                \    size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK);          \    dim3   grid(params.num_heads, params.batch_size);                                                                  \    mmha::masked_multihead_attention_kernel                                            Dh,                                                                        \                                            Dh_MAX,                                                                    \                                            THDS_PER_KEY,                                                              \                                            THDS_PER_VALUE,                                                            \                                            THDS_PER_BLOCK,                                                            \                                            DO_CROSS_ATTENTION,                                                        \                                            HAS_BEAMS><<>>(params)


4. kernel函数实现

终于进入到kernel函数的实现部分,这里开始要起飞了。由于这个函数代码量有800+,这里只展示核心代码,完整的代码可到这里浏览。这里可以拆成6个子步骤进行阐述,分别是:

kernel实现过程拆解


1. 共享内存和变量的定义和初始化

代码实现上第一部分是共享内存和变量的定义初始化。这里共享存储的使用上有一些tricky,定义的 smem_ 共享存储变量,第一阶段会用于存储Q*K的中间结果 qk_smem ,第二阶段用于存储 logits 的结果 logits_smem (即softmax算出来的概率),第三阶段用于存储最终的输出结果 out_smem 。好处嘛不言而喻就是节省共享存储的空间,毕竟这部分资源非常有限,能使用这个方式的前提是这些内容在逻辑处理流程中没有overlap。

接着还定义了 red_smem q_smem bias_smem 共享内存来存储归约和一些中间结果值,这里的注释都还比较详细。

    // Use smem_size_in_bytes (above) to determine the amount of shared memory.    extern __shared__ char smem_[];

   // The shared memory for the Q*K^T values and partial logits in softmax.    float* qk_smem = reinterpret_cast<float*>(smem_);

   // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.    char* logits_smem_ = smem_;
   if (sizeof(Tk) != 4) {
       // TODO - change to tlength        const int max_timesteps = min(params.timestep, params.memory_max_len);
       logits_smem_ +=
           (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
   }
   Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);

   // The shared memory to do the final reduction for the output values. Reuse qk_smem.    Tk* out_smem = reinterpret_cast<Tk*>(smem_);

   // The shared memory buffers for the block-wide reductions. One for max, one for sum.    __shared__ float red_smem[WARPS_PER_BLOCK * 2];

   // A vector of Q or K elements for the current timestep.    using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type;  // with kernel-used precision    using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type;  // with memory-used precision
   // Use alignment for safely casting the shared buffers as Qk_vec_k.    // Shared memory to store Q inputs.    __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX];

   // This is one of the reasons we should have a separate kernel for cross attention    __shared__ __align__(sizeof(Qk_vec_k)) Tk bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];

这段代码对cache layout进行了一些说明和size计算,具体设计参考优化设计部分。

    // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread    // owns x elements, we have to decompose the linear index into chunks of x values and the posi-    // tion of the thread in that chunk.
   // The number of elements in a chunk of 16B (that's the x in the above formula).    constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
   // The number of K vectors in 16B.    constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m);


2. 获取本轮query和本轮key

这个代码片段主要是实现了从第一大步骤生成的query和key中取值的过程,掺杂了一些crossAttention和quant的代码,因为这个kernel是和crossAttention复用的,这里关于crossAttention和quant的介绍放到后面进行,我们先聚焦在selfAttention本身上。这段代码看上去比较长,其实就是2次Load操作,做了一些类型转换便于后续处理统一方便。

    // Trigger the loads from the Q and K buffers.    Qk_vec_k q;
   zero(q);
   if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
       if (params.int8_mode == 2) {
           using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
           using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
           const auto q_scaling = params.qkv_scale_out[0];
           const auto q_quant =
               *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);

           convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
       }
       else {
           q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset]));
       }
   }

   Qk_vec_k k;
   zero(k);
   if (DO_CROSS_ATTENTION) {
       // The 16B chunk written by the thread.        int co = tidx / QK_VECS_IN_16B;
       // The position of the thread in that 16B chunk.        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

       // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                    // params.timestep*QK_ELTS_IN_16B +                     tlength * QK_ELTS_IN_16B + ci;
       k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
               vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_cache[offset])) :
               k;
   }
   else {
       if (params.int8_mode == 2) {
           using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
           using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
           const auto k_scaling = params.qkv_scale_out[1];
           const auto k_quant =
               *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);

           convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
       }
       else {
           k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
                   vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :
                   k;
       }
   }


3. 获取本轮query和本轮key的点乘结果

  • 将query存到共享内存中,这里主要是为了减少访问全局内存的次数,因为后面对cache的key做运算还会用到。

  • 接着就是key cache的index计算,把key存到全局cache中,因为key cache的设计(设计部分提到)导致这里的计算逻辑比较复杂,但耐心推导还是可以算出来。

  • 接着进行了本轮step的query和key的点乘运算。 qk = dot (q, k);

  • 在warp维度和block维度归约求和,得到完整的query和key的点乘结果。

  • 将本轮qk结果记录到 qk_max 用于后续求全局最大值,将本轮qk结果记录到 qk_smem 用于后续softmax计算。

    if (!is_masked) {
       // Store the Q values to shared memory.
       *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q;

       // Store Dh values of k_bias into smem, since will need to add later
       // if params.timestep == 0
       if (DO_CROSS_ATTENTION && params.timestep == 0) {
           *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
       }

       // Write the K values to the global memory cache.
       //
       // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
       // system. We designed it this way as it allows much better memory loads (and there are many
       // more loads) + the stores are really "write and forget" since we won't need the ack before
       // the end of the kernel. There's plenty of time for the transactions to complete.

       // The 16B chunk written by the thread.
       int co = tidx / QK_VECS_IN_16B;
       // The position of the thread in that 16B chunk.
       int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

       // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
       int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                    // params.timestep*QK_ELTS_IN_16B +
                    tlength_circ * QK_ELTS_IN_16B + ci;

       if (handle_kv) {
           // Trigger the stores to global memory.
           if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
               *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k);
           }
       }

       // Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
       using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type;






请到「今天看啥」查看全文