```python
cuda_src = r"""
constexpr int B_r = 16;
constexpr int B_c = 16;
constexpr int d = 128;
constexpr int o_per_thread_x = 1;
constexpr int o_per_thread_y = 128 / 32;

#define NEG_INFINITY __int_as_float(0xff800000)

extern "C" __global__
void silty_attn(float* out, float* out_l, float* K, float* V, float* Q,
                float scaling, int n, int T_r, int T_c) {
  int tid_x = threadIdx.x;
  int tid_y = threadIdx.y;
  __shared__ float Q_i[B_r][d];
  __shared__ float K_j[B_c][d];
  __shared__ float V_j[B_c][d];

  // ... (the outer loop over i < T_r and the online-softmax accumulation of o_i / l_i are omitted here) ...

      // Write the normalized output block and 1 / l_i back to global memory.
      for (int ii = 0; ii < o_per_thread_x; ii++) {
        for (int dd = 0; dd < o_per_thread_y; dd++) {
          out[(ii + blockDim.x * tid_x + i * B_r) * d + dd + blockDim.y * tid_y] = o_i[ii][dd] / l_i[ii];
        }
        out_l[ii + blockDim.x * tid_x + i * B_r] = 1 / l_i[ii];
      }
  }
}
"""
```
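The snippet above only defines the source string; before the launch below can work, the string has to be compiled and a function handle obtained. The lecture's own setup code is not shown here, so the following is only a sketch of one way to do it with NVRTC via cuda-python, mirroring NVIDIA's cuda-python samples; it assumes a CUDA context is already current (e.g. because PyTorch has been initialized) and that the target architecture flag matches your GPU.

```python
import numpy as np
from cuda import cuda, nvrtc

# Compile the raw CUDA string to PTX with NVRTC.
err, prog = nvrtc.nvrtcCreateProgram(str.encode(cuda_src), b"silty_attn.cu", 0, [], [])
opts = [b"--gpu-architecture=compute_80"]  # assumption: pick the architecture of your GPU
err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)

err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
ptx = b" " * ptx_size
err, = nvrtc.nvrtcGetPTX(prog, ptx)

# Load the PTX into the current context and fetch the __global__ function.
err, = cuda.cuInit(0)  # idempotent; the context itself is assumed to come from PyTorch
ptx = np.char.array(ptx)
err, module = cuda.cuModuleLoadData(ptx.ctypes.data)
err, kernel = cuda.cuModuleGetFunction(module, b"silty_attn")
```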
```python
def fn():
    err = cuda.cuLaunchKernel(
        kernel,
        1,   # grid x dim
        1,   # grid y dim
        1,   # grid z dim
        32,  # block x dim
        32,  # block y dim
        1,   # block z dim
        0,   # dynamic shared memory
        torch.cuda.current_stream().stream_id,  # stream
        args.data_ptr(),  # kernel arguments
        0,   # extra (ignore)
    )

fn()
```
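The `args.data_ptr()` passed above has to be the `void**` kernel-parameter array that `cuLaunchKernel` expects: a host array holding, for each kernel parameter, the address of the memory where that parameter's value lives. The construction of `args` is not shown in the excerpt, so the following is just one plausible way to build it, assuming `out`, `out_l`, `K`, `V`, `Q` are CUDA tensors and `scaling`, `n`, `T_r`, `T_c` are Python scalars.

```python
import numpy as np
import torch

# One host-side "holder" per kernel parameter; device pointers are plain 64-bit integers.
holders = [
    np.array([out.data_ptr()],   dtype=np.uint64),   # float* out
    np.array([out_l.data_ptr()], dtype=np.uint64),   # float* out_l
    np.array([K.data_ptr()],     dtype=np.uint64),   # float* K
    np.array([V.data_ptr()],     dtype=np.uint64),   # float* V
    np.array([Q.data_ptr()],     dtype=np.uint64),   # float* Q
    np.array([scaling],          dtype=np.float32),  # float scaling
    np.array([n],                dtype=np.int32),    # int n
    np.array([T_r],              dtype=np.int32),    # int T_r
    np.array([T_c],              dtype=np.int32),    # int T_c
]

# args holds the host address of each holder; args.data_ptr() is then the void**
# that cuLaunchKernel dereferences. The holders must stay alive until the launch.
args = torch.tensor([h.ctypes.data for h in holders], dtype=torch.int64)
```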
The kernel the course author implements here feels rather odd. In particular, the mixed-up indexing bug probably gives it correctness problems, and it is hard to see which computations each thread is actually responsible for. For that reason I add a section below walking through the minimal CUDA implementation of Flash Attention from https://github.com/tspeterkim/flash-attention-minimal, which is very clear and easy to follow.
```python
def flash_attention(Q, K, V, B_r=64, B_c=768):
    """
    Run the Flash Attention algorithm with tiled computation and online softmax correction.
    """
    N, d = Q.shape  # sequence length and head dimension (implicit globals in the lecture code)

    O = torch.zeros((N, d))             # output matrix, line 2 of the pseudocode
    l = torch.zeros((N, 1))             # softmax denominators, line 2 of the pseudocode
    m = torch.full((N, 1), -torch.inf)  # running row maxima for each block, line 2 of the pseudocode

    # Line 5 of the pseudocode: for 1 <= j <= T_c. K and V are split into T_c = ceil(N / B_c)
    # blocks of shape [B_c, d], so the Python implementation simply walks over them with a
    # loop of stride B_c.
    for j in range(0, N, B_c):
        # The next three lines correspond to line 6 of the pseudocode:
        # "Load K_j, V_j from HBM to on-chip SRAM". This is plain Python, so we obviously
        # cannot really move this block from HBM into SRAM; it is only a logical illustration
        # of the pseudocode. (In Triton you actually can express this at the Python level.)
        j_end = j + B_c
        Kj = K[j:j_end, :]
        Vj = V[j:j_end, :]

        # Line 7 of the pseudocode: for 1 <= i <= T_r. Q is split into T_r = ceil(N / B_r)
        # blocks of shape [B_r, d], so again a loop of stride B_r is used.
        for i in range(0, N, B_r):
            i_end = i + B_r
            mi = m[i:i_end, :]
            li = l[i:i_end, :]
            Oi = O[i:i_end, :]
            Qi = Q[i:i_end, :]
```
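The update step inside the inner loop (lines 9 to 13 of the FlashAttention-1 pseudocode) is not shown above. As a minimal sketch, assuming the standard FlashAttention-1 recurrence, a hypothetical helper for one (i, j) tile pair could look like this; the returned values would then be written back with `O[i:i_end, :] = Oi`, `l[i:i_end, :] = li`, `m[i:i_end, :] = mi`, and `scale` is assumed to be `1 / sqrt(d)`.

```python
import torch

def flash_block_update(Qi, Kj, Vj, Oi, li, mi, scale):
    """One inner-loop step of the online-softmax recurrence (hypothetical helper).

    Qi: [B_r, d], Kj/Vj: [B_c, d], Oi: [B_r, d], li/mi: [B_r, 1].
    Returns the updated (Oi, li, mi) for this (i, j) tile pair.
    """
    S_ij = scale * (Qi @ Kj.T)                       # scores, pseudocode line 9
    m_tile = S_ij.max(dim=-1, keepdim=True).values   # row max of this tile, line 10
    P_ij = torch.exp(S_ij - m_tile)                  # line 10
    l_tile = P_ij.sum(dim=-1, keepdim=True)          # line 10
    m_new = torch.maximum(mi, m_tile)                # line 11
    l_new = torch.exp(mi - m_new) * li + torch.exp(m_tile - m_new) * l_tile  # line 11
    # Rescale the old partial output and add the new tile's contribution, line 12.
    O_new = (li * torch.exp(mi - m_new) * Oi + torch.exp(m_tile - m_new) * (P_ij @ Vj)) / l_new
    return O_new, l_new, m_new
```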
The Flash Attention CUDA kernel implemented by the course author is rather odd, so here I recommend a very simple and clear open-source CUDA implementation of Flash Attention: https://github.com/tspeterkim/flash-attention-minimal.
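Before going through the source, here is roughly how the extension is built and called from PyTorch. This is only a sketch following the usual `torch.utils.cpp_extension.load` pattern (the repository's own benchmark script does essentially the same); the file names `main.cpp` / `flash.cu` are those of the repository, while the flags, shapes, and tolerance below are my own choices.

```python
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

# JIT-compile the extension from the repository's sources.
minimal_attn = load(name="minimal_attn",
                    sources=["main.cpp", "flash.cu"],
                    extra_cuda_cflags=["-O2"])

B, nh, N, d = 16, 12, 64, 64  # N should be a multiple of Bc/Br (32); d must fit the SRAM budget
q = torch.randn(B, nh, N, d, device="cuda")
k = torch.randn(B, nh, N, d, device="cuda")
v = torch.randn(B, nh, N, d, device="cuda")

out = minimal_attn.forward(q, k, v)            # the kernel below (no causal mask)
ref = F.scaled_dot_product_attention(q, k, v)  # PyTorch reference
print(torch.allclose(out, ref, rtol=0, atol=1e-2))
```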
```c++
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void forward_kernel(const float* Q, const float* K, const float* V, const int N, const int d,
                               const int Tc, const int Tr, const int Bc, const int Br,
                               const float softmax_scale, float* l, float* m, float* O) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);           // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            // Bc threads: each thread handles one column of K^T (after the transpose this matrix is effectively column-major)
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            // Bc threads: each thread handles one row of V (this matrix is row-major)
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++) {

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            float row_m = -INFINITY;
            // tx indexes the rows of S: (Br, Bc); the loop over y enumerates all columns of S,
            // because each row has to be dotted with every column to produce S.
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    // Each thread computes one row of S: (Br, Bc); the Qi row accessed by this thread starts at tx * d.
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // P = exp(S - row_m), row_l = rowsum(P)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, threads could use the wrong Kj, Vj in the inner loop
    }
}
```
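The O update at the end of the inner loop is the same rescaling trick as in the Python version: the previous partial output is scaled by `row_l_prev * __expf(row_m_prev - row_m_new)`, the new tile's contribution by `__expf(row_m - row_m_new)`, and everything is divided by `row_l_new`. A small PyTorch sanity check of this identity (names and sizes below are illustrative, not from the repository):

```python
import torch

torch.manual_seed(0)
d, Bc = 16, 8
q = torch.randn(d)
k1, v1 = torch.randn(Bc, d), torch.randn(Bc, d)
k2, v2 = torch.randn(Bc, d), torch.randn(Bc, d)

def tile_stats(q, k, v):
    s = k @ q                 # scores of one query row against one K/V tile
    m = s.max()               # row max for this tile
    p = torch.exp(s - m)      # unnormalized probabilities
    return m, p.sum(), p @ v  # (row max, row sum, unnormalized P @ V)

m1, l1, pv1 = tile_stats(q, k1, v1)
m2, l2, pv2 = tile_stats(q, k2, v2)

# Combine the two tiles the way the kernel effectively does across two consecutive j iterations.
m_new = torch.maximum(m1, m2)
l_new = torch.exp(m1 - m_new) * l1 + torch.exp(m2 - m_new) * l2
o_online = (torch.exp(m1 - m_new) * pv1 + torch.exp(m2 - m_new) * pv2) / l_new

# Reference: ordinary softmax over the concatenated scores.
s_full = torch.cat([k1, k2]) @ q
o_ref = torch.softmax(s_full, dim=0) @ torch.cat([v1, v2])
print(torch.allclose(o_online, o_ref, atol=1e-5))  # True
```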
```c++
torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V) {
    // TODO: determine Bc, Br dynamically
    const int Bc = 32; const int Br = 32;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // Initialize O, l, m to HBM
    auto O = torch::zeros_like(Q);
    auto l = torch::zeros({B, nh, N});
    auto m = torch::full({B, nh, N}, -INFINITY);
    torch::Device device(torch::kCUDA);
    l = l.to(device); m = m.to(device);

    // Calculate SRAM size needed per block
    const int sram_size = (3 * Bc * d * sizeof(float)) + (Bc * Br * sizeof(float));
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);

    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Bc);    // Bc threads per block
```