# Naive reference implementation: recurrent linear-attention decode over the
# sequence dimension.  Per step i: kv <- ratio * kv + k_i^T v_i (decayed outer
# product), then o_i = q_i @ kv, accumulated in fp32 for numerical stability.
# NOTE(review): q/k/v and the running state `kv`, plus `n` and `ratio`, come
# from the enclosing scope (not visible in this chunk) — q/k/v appear to be
# (batch, heads, seq, dim)-shaped; confirm at the caller.
output = []
for i in range(n):
    # State update: decay the previous kv and add the rank-1 update k_i^T v_i.
    kv = ratio * kv.to(torch.float32) + torch.einsum(
        "... n d, ... n e -> ... d e",
        k[:, :, i : i + 1],
        v[:, :, i : i + 1],
    )
    # Step output: project the query through the updated state.
    qkv = torch.einsum(
        "... n e, ... e d -> ... n d",
        q[:, :, i : i + 1].to(torch.float32),
        kv.to(torch.float32),
    )
    output.append(qkv)
# Stitch the per-step outputs back along the sequence axis.
output = torch.concat(output, dim=-2)
constint32_t tid = threadIdx.x; constint32_t current_head = blockIdx.x; constint32_t b = current_head / num_heads; constint32_t h = current_head % num_heads;
if (b >= batch_size) return;
constint32_t qk_offset = b * num_heads * dim + h * dim; constint32_t v_offset = b * num_heads * embed_dim + h * embed_dim; constint32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;
// 1. 计算新的kv: new_kv = ratio * past_kv + k * v^T constfloat ratio = expf(-1.0f * slope[h]); for (int d = tid; d T k_value = k[qk_offset + d]; for (int e = 0; e constint32_t kv_index = kv_offset + d * embed_dim + e; new_kv[kv_index] = ratio * past_kv[kv_index] + k_value * v[v_offset + e]; } }
__syncthreads(); // 确保所有线程完成new_kv的计算
// 2. 计算qkv attention输出: output = q * new_kv for (int e = tid; e float sum = 0.0f;
# Store KV with 2D masking tl.store( KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], kv.to(KV.dtype.element_ty), mask=(d_mask[:, None] & e_mask[None, :]), )
# Compute matrix-vector multiplication using element-wise operations and reduction o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def triton_lightning_attn_decode(q, k, v, kv, s):
    """Triton implementation of Lightning Attention decode operation.

    NOTE(review): only the head of this function is visible in this chunk;
    the remainder (grid setup and kernel launch) is cut off by the extraction.
    `next_power_of_2` is defined elsewhere in the file.
    """
    b, h, n, d = q.shape
    e = v.shape[-1]
    assert n == 1, "Sequence length must be 1 in decode mode"

    # Get padded dimensions (power of 2) — Triton block shapes must be
    # powers of two, so d/e are padded and masked inside the kernel.
    d_padded = next_power_of_2(d)
    e_padded = next_power_of_2(e)
constint32_t tid = threadIdx.x; constint32_t current_head = blockIdx.x; constint32_t b = current_head / num_heads; constint32_t h = current_head % num_heads;
if (b >= batch_size) return;
constint32_t qk_offset = b * num_heads * dim + h * dim; constint32_t v_offset = b * num_heads * embed_dim + h * embed_dim; constint32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;
for (int d = tid; d q_shared[d] = q[qk_offset + d]; k_shared[d] = k[qk_offset + d]; } for (int e = tid; e v_shared[e] = v[v_offset + e]; }
__syncthreads();
constfloat ratio = expf(-1.0f * slope[h]);
for (int d = tid; d T k_val = k_shared[d]; for (int e = 0; e int past_kv_idx = kv_offset + d * embed_dim + e; T v_val = v_shared[e]; float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; int shared_idx = d * (embed_dim + 1) + e; new_kv_shared[shared_idx] = new_val; } }
__syncthreads();
for (int idx = tid; idx int d = idx / embed_dim; int e = idx % embed_dim; int shared_idx = d * (embed_dim + 1) + e; int global_idx = kv_offset + idx; new_kv[global_idx] = new_kv_shared[shared_idx]; }
__syncthreads();
for (int e = tid; e float sum = 0.0f; for (int d = 0; d int shared_idx = d * (embed_dim + 1) + e; sum += q_shared[d] * new_kv_shared[shared_idx]; } output_shared[e] = static_cast(sum); }
__syncthreads();
if (tid == 0) { for (int e = 0; e output[v_offset + e] = output_shared[e]; } } }
把这个结果反馈给Cursor,Cursor现在可以知道主要问题是写new_kv的时候内部循环 `for (int e = 0; e < embed_dim; ++e)` 导致线程在访问全局内存时stride太大,内存没有合并访问,且每个线程需要写入多次全局内存,增加了内存事务数。这也是我们看到这个kernel写全局内存的次数比Triton多了几倍的原因。知道原因之后Cursor就可以改成正确的代码了。代码如下:
constint32_t tid = threadIdx.x; constint32_t current_head = blockIdx.x; constint32_t b = current_head / num_heads; constint32_t h = current_head % num_heads;
if (b >= batch_size) return;
constint32_t qk_offset = b * num_heads * dim + h * dim; constint32_t v_offset = b * num_heads * embed_dim + h * embed_dim; constint32_t kv_offset = b * num_heads * dim * embed_dim + h * dim * embed_dim;
for (int d = tid; d q_shared[d] = q[qk_offset + d]; k_shared[d] = k[qk_offset + d]; } for (int e = tid; e v_shared[e] = v[v_offset + e]; }
__syncthreads();
constfloat ratio = expf(-1.0f * slope[h]);
for (int d = tid; d T k_val = k_shared[d]; for (int e = 0; e int past_kv_idx = kv_offset + d * embed_dim + e; T v_val = v_shared[e]; float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; int shared_idx = d * (embed_dim + 1) + e; new_kv_shared[shared_idx] = new_val; } }
__syncthreads();
for (int idx = tid; idx int d = idx / embed_dim; int e = idx % embed_dim; int shared_idx = d * (embed_dim + 1) + e; int global_idx = kv_offset + idx; new_kv[global_idx] = new_kv_shared[shared_idx]; }
__syncthreads();
for (int e = tid; e float sum = 0.0f; for (int d = 0; d int shared_idx = d * (embed_dim + 1) + e; sum += q_shared[d] * new_kv_shared[shared_idx]; } output_shared[e] = static_cast(sum); }
__syncthreads();
if (tid == 0) { for (int e = 0; e output[v_offset + e] = output_shared[e]; } } }