cublas_wrapper_->Gemm(CUBLAS_OP_N,
                      CUBLAS_OP_N,
                      d_model_,  // n
                      batch_size,
                      local_hidden_units_,  // k
                      attention_weights->attention_output_weight.kernel,
                      d_model_,  // n
                      context_buf_,
                      local_hidden_units_,  // k
                      attention_out,
                      d_model_ /* n */);
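For orientation, here is a rough row-major reference of what this column-major cuBLAS call computes, assuming context_buf_ holds the per-token attention outputs as [batch_size, local_hidden_units_] and the output weight is stored as [local_hidden_units_, d_model_]. This is a readability sketch, not FasterTransformer code.

// Naive reference for the output projection performed by the Gemm call above (a sketch):
// attention_out[b][d] = sum_h context_buf[b][h] * W[h][d].
void attention_output_projection_ref(const float* context_buf,    // [batch_size, local_hidden_units]
                                     const float* weight,         // [local_hidden_units, d_model]
                                     float*       attention_out,  // [batch_size, d_model]
                                     int batch_size, int local_hidden_units, int d_model)
{
    for (int b = 0; b < batch_size; ++b) {
        for (int d = 0; d < d_model; ++d) {
            float acc = 0.f;
            for (int h = 0; h < local_hidden_units; ++h) {
                acc += context_buf[b * local_hidden_units + h] * weight[h * d_model + d];
            }
            attention_out[b * d_model + d] = acc;
        }
    }
}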
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern __shared__ char smem_[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float* qk_smem = reinterpret_cast<float*>(smem_);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char* logits_smem_ = smem_;
if (sizeof(Tk) != 4) {
    // TODO - change to tlength
    const int max_timesteps = min(params.timestep, params.memory_max_len);
    logits_smem_ +=
        (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
}
Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
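To make the pointer arithmetic concrete, here is a tiny stand-alone check for the self-attention case; the sizes (timestep = 21, memory_max_len = 1024) are hypothetical, not taken from the kernel. qk_smem holds max_timesteps + 1 FP32 logits, and the lower-precision logits buffer is placed right after it, padded up to a 16-byte boundary.

#include <algorithm>
#include <cstdio>

// Stand-alone arithmetic for the logits_smem offset (hypothetical sizes).
static int div_up(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int timestep = 21, memory_max_len = 1024;  // assumed example values
    const int max_timesteps = std::min(timestep, memory_max_len);
    // 22 FP32 logits = 88 bytes, rounded up to the next 16-byte boundary = 96 bytes.
    const int offset_bytes = div_up(max_timesteps + 1, 4) * 16;
    std::printf("logits_smem starts %d bytes into smem_\n", offset_bytes);  // prints 96
    return 0;
}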
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
Tk* out_smem = reinterpret_cast<Tk*>(smem_);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__ float red_smem[WARPS_PER_BLOCK * 2];
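As a reminder of the pattern this buffer serves, here is a minimal sketch of a two-level block reduction: each warp reduces with shuffles, one slot per warp goes to shared memory, and the per-warp results are folded. It is illustrative, not the exact FT reduction helpers, and assumes blockDim.x == WARPS_PER_BLOCK * 32. The softmax later performs one such reduction for the max and one for the sum, hence the two halves of red_smem.

// Minimal block-wide max reduction in the spirit of how red_smem is used (a sketch).
template<int WARPS_PER_BLOCK>
__device__ float block_max(float* red_smem, float val)
{
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    // Intra-warp max via butterfly shuffles.
#pragma unroll
    for (int mask = 16; mask >= 1; mask /= 2) {
        val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, mask));
    }
    // Lane 0 of each warp publishes its warp's maximum.
    if (lane == 0) {
        red_smem[warp] = val;
    }
    __syncthreads();

    // Every thread folds the per-warp maxima (kept simple for the sketch).
    float block_val = red_smem[0];
#pragma unroll
    for (int w = 1; w < WARPS_PER_BLOCK; ++w) {
        block_val = fmaxf(block_val, red_smem[w]);
    }
    return block_val;
}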
// A vector of Q or K elements for the current timestep.
using Qk_vec_k = typename Qk_vec_k_<T, Dh_MAX>::Type;  // with kernel-used precision
using Qk_vec_m = typename Qk_vec_m_<T, Dh_MAX>::Type;  // with memory-used precision
// Use alignment for safely casting the shared buffers as Qk_vec_k.
// Shared memory to store Q inputs.
__shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX];
// This is one of the reasons we should have a separate kernel for cross attention
__shared__ __align__(sizeof(Qk_vec_k)) Tk bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
This snippet documents the cache layout and computes the related sizes; the detailed design is covered in the optimization-design section.
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the position of
// the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
// The number of K vectors in 16B.
constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m);
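A host-side sketch of the indexing that the [B, H, Dh/x, L, x] layout implies may help; the names mirror the kernel's variables, but the function itself is illustrative rather than FT code.

#include <cstddef>

// Map (batch b, head h, element dh_idx of the head dimension, timestep t) to a linear index
// in a K cache laid out as [B, H, Dh/x, L, x], where x = elements per 16B chunk
// (4 for FP32, 8 for FP16). "chunk" and "in_chunk" correspond to co and ci in the kernel.
size_t k_cache_index(size_t b, size_t h, size_t dh_idx, size_t t,
                     size_t num_heads, size_t Dh, size_t memory_max_len, size_t x)
{
    const size_t chunk    = dh_idx / x;  // which 16B chunk of the head dimension (co)
    const size_t in_chunk = dh_idx % x;  // element position inside that chunk (ci)
    return ((b * num_heads + h) * (Dh / x) + chunk) * memory_max_len * x  // [B, H, Dh/x] part
           + t * x                                                        // timestep (L) inside the chunk
           + in_chunk;                                                    // element (x) inside the 16B chunk
}

Expanding the product shows it matches the offset computed below: bhi * memory_max_len * Dh + co * memory_max_len * QK_ELTS_IN_16B + t * QK_ELTS_IN_16B + ci.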
Qk_vec_k k;
zero(k);
if (DO_CROSS_ATTENTION) {
    // The 16B chunk written by the thread.
    int co = tidx / QK_VECS_IN_16B;
    // The position of the thread in that 16B chunk.
    int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
    // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
    int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                 // params.timestep*QK_ELTS_IN_16B +
                 tlength * QK_ELTS_IN_16B + ci;
    k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
            vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_cache[offset])) :
            k;
}
else {
    if (params.int8_mode == 2) {
        using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec_m>::value>::type;
        using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec_m>::value>::type;
        const auto k_scaling = params.qkv_scale_out[1];
        const auto k_quant =
            *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
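The int8_mode == 2 branch reads the quantized K values and rescales them with the per-tensor scale params.qkv_scale_out[1] (k_scaling above). In scalar form the dequantization is just a multiply; this is a sketch of the idea only, since the kernel operates on packed Packed_Int8_t / Packed_Float_t vectors rather than single values.

#include <cstdint>

// Scalar sketch of per-tensor INT8 dequantization (illustrative, not the kernel's packed code).
__device__ inline float dequantize_int8(int8_t q, float scale)
{
    return static_cast<float>(q) * scale;  // 'scale' plays the role of k_scaling
}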
if (!is_masked) {
    // Store the Q values to shared memory.
    *reinterpret_cast<Qk_vec_k*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
    // Store Dh values of k_bias into smem, since will need to add later
    // if params.timestep == 0
    if (DO_CROSS_ATTENTION && params.timestep == 0) {
        *reinterpret_cast<Qk_vec_k*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
    }
    // Write the K values to the global memory cache.
    //
    // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
    // system. We designed it this way as it allows much better memory loads (and there are many
    // more loads) + the stores are really "write and forget" since we won't need the ack before
    // the end of the kernel. There's plenty of time for the transactions to complete.
    // The 16B chunk written by the thread.
    int co = tidx / QK_VECS_IN_16B;
    // The position of the thread in that 16B chunk.
    int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
    // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
    int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                 // params.timestep*QK_ELTS_IN_16B +
                 tlength_circ * QK_ELTS_IN_16B + ci;
    if (handle_kv) {
        // Trigger the stores to global memory.
        if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
            *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
        }
    }
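A quick stride check makes the NOTE above concrete (memory_max_len is a hypothetical size, not from the kernel): at store time, the thread groups handling adjacent 16B chunks (co and co + 1) write addresses memory_max_len * 16 bytes apart, so the stores scatter; at load time, the Q*K^T loop walks consecutive timesteps of a chunk, which sit only 16 bytes apart, so the far more numerous loads stream through contiguous memory.

// Illustrative strides for the K-cache layout (hypothetical memory_max_len).
constexpr int memory_max_len     = 1024;
constexpr int chunk_bytes        = 16;                            // one x-element chunk of K
constexpr int store_stride_bytes = memory_max_len * chunk_bytes;  // between adjacent chunks' stores
constexpr int load_stride_bytes  = chunk_bytes;                   // between consecutive timesteps' loads
static_assert(store_stride_bytes == 16384 && load_stride_bytes == 16, "stores scatter, loads stream");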
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec_k>::Type;
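MMHA_USE_FP32_ACUM_FOR_FMA switches the accumulation type for this dot product to FP32 even when Q and K are FP16, which protects the long sum over Dh from half-precision accumulation error. A scalar sketch of the idea follows; the kernel's version is vectorized over the Qk_vec types, this is not its actual code.

#include <cuda_fp16.h>

// Scalar sketch of an FP32-accumulated FP16 dot product (illustrative only).
__device__ float qk_dot_fp32_acum(const __half* q, const __half* k, int dh)
{
    float acc = 0.f;  // accumulate in FP32 even though the operands are FP16
    for (int i = 0; i < dh; ++i) {
        acc = fmaf(__half2float(q[i]), __half2float(k[i]), acc);
    }
    return acc;  // the subsequent softmax max/sum reductions also run in FP32
}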