// SGEMM: Block Tile + K Tile, with smem
// Block Tile (BM, BN) + K Tile (BK=32)
// grid((N + BN - 1) / BN, (M + BM - 1) / BM), block(BN, BM)
// a: MxK, b: KxN, c: MxN, compute: c = a * b, all row major
__global__ void sgemm(float* a, float* b, float* c, int M, int N, int K) {
  // [1] Block Tile: each 32x32 block computes one 32x32 tile of c
  // [2] K Tile: use shared memory and split K into chunks of size BK
  constexpr int BM = 32;
  constexpr int BN = 32;
  constexpr int BK = 32;
  __shared__ float s_a[BM][BK], s_b[BK][BN];

  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  // Load values to shared memory: 32x32 threads work together to fetch
  // data along the row direction of a and b for both s_a and s_b
  // (32x32x4x2 = 8KB). The 32x32 threads within the block load the 32x32
  // elements from global memory to shared memory, i.e. each thread loads
  // exactly 1 element.
  int load_smem_a_m = tid / 32; // 0~31, tid / BM, i.e. threadIdx.y
  int load_smem_a_k = tid % 32; // 0~31, tid % BK, i.e. threadIdx.x
  int load_smem_b_k = tid / 32; // 0~31, tid / BK, i.e. threadIdx.y
  int load_smem_b_n = tid % 32; // 0~31, tid % BN, i.e. threadIdx.x
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
  // if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;

  float sum = 0.f;
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    int load_gmem_a_k = bk * BK + load_smem_a_k;
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    s_a[load_smem_a_m][load_smem_a_k] = a[load_gmem_a_addr];
    int load_gmem_b_k = bk * BK + load_smem_b_k;
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    s_b[load_smem_b_k][load_smem_b_n] = b[load_gmem_b_addr];
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; ++k) {
      int comp_smem_a_m = load_smem_a_m;
      int comp_smem_b_n = load_smem_b_n;
      sum += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
    }
    __syncthreads();
  }
  int store_gmem_c_m = load_gmem_a_m;
  int store_gmem_c_n = load_gmem_b_n;
  int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
  c[store_gmem_c_addr] = sum;
}
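As a concrete reference, a minimal host-side launch for this kernel could look like the sketch below. This is an assumption rather than part of the original article: launch_sgemm, d_a, d_b and d_c are hypothetical names, the device buffers are presumed already allocated and filled, and M, N, K are taken as multiples of 32 so the commented-out boundary check is not needed.

// Hypothetical host-side wrapper (illustrative only).
// d_a, d_b, d_c: device buffers allocated via cudaMalloc and filled with
// row-major data; M, N, K assumed to be multiples of 32.
void launch_sgemm(float* d_a, float* d_b, float* d_c, int M, int N, int K) {
  constexpr int BM = 32, BN = 32;
  dim3 block(BN, BM);                               // 32x32 = 1024 threads per block
  dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);  // one block per 32x32 tile of c
  sgemm<<<grid, block>>>(d_a, d_b, d_c, M, N, K);
  cudaDeviceSynchronize();
}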
// SGEMM: Block Tile + Thread Tile + K Tile + Vec4, with smem
// BK:TILE_K=8 BM=BN=128
// TM=TN=8 to increase arithmetic intensity, BM/TM=16 BN/TN=16
// dim3 blockDim(BN/TN, BM/TM);
// dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM)
__global__ void sgemm_thread_tile_vec4(
  float* a, float* b, float* c, int M, int N, int K) {
  // [1] Block Tile: one 16x16 block handles a 128x128 target tile of C
  // [2] Thread Tile: each thread computes TM*TN (8x8) elements, increasing arithmetic intensity
  // [3] K Tile: split K into chunks of size BK and iterate (K + BK - 1) / BK times,
  //     each iteration accumulating the partial products of the TM*TN elements
  // [4] Vectorize: use float4 to reduce the number of load and store instructions
  constexpr int BM = 128;
  constexpr int BN = 128;
  constexpr int BK = 8;
  constexpr int TM = 8;
  constexpr int TN = 8;

  int bx = blockIdx.x;
  int by = blockIdx.y;
  int tx = threadIdx.x;
  int ty = threadIdx.y;
  int tid = threadIdx.y * blockDim.x + tx; // tid within the block
  __shared__ float s_a[BM][BK], s_b[BK][BN]; // 2*128*8*4=8KB

  // 0. First compute the indices into shared memory
  // Mapping from tid to the smem tile s_a[BM][BK]: BM=128, BK=8, read row by row, A is row major
  // Each row of s_a has 8 elements and each thread loads 4, so one row needs 2 threads;
  // 128 rows in total need 128x2 = exactly 256 threads
  int load_smem_a_m = tid / 2;                 // (128/8)*(128/8)=256 threads per block, tid/2 -> [0,128), row of s_a 0~127
  int load_smem_a_k = (tid % 2 == 0) ? 0 : 4;  // col of s_a: 0 or 4
  // Mapping from tid to the smem tile s_b[BK][BN]: BK=8, BN=128, read row by row, B is row major
  // Each row of s_b has 128 elements and each thread loads 4, so one row needs 32 threads;
  // 8 rows in total need 32x8 = 256 threads
  int load_smem_b_k = tid / 32;        // row of s_b, 256/32 = 8 rows, 0~7
  int load_smem_b_n = (tid % 32) * 4;  // col of s_b: 0, 4, ..., 124
  // 1. Then compute the indices into global memory
  // Row of global A corresponding to the element loaded into s_a;
  // each block produces one BM*BN tile of C
  int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
  int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c

  float r_c[TM][TN] = {0.0}; // 8x8
  // 2. Split K into chunks of size BK
  for (int bk = 0; bk < (K + BK - 1) / BK; ++bk) {
    // load data into smem s_a (BM*BK = 128*8), vectorized with float4
    int load_gmem_a_k = bk * BK + load_smem_a_k; // global col of a
    int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
    FLOAT4(s_a[load_smem_a_m][load_smem_a_k]) = FLOAT4(a[load_gmem_a_addr]);
    // load data into smem s_b (BK*BN = 8*128), vectorized with float4
    int load_gmem_b_k = bk * BK + load_smem_b_k; // global row of b
    int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
    FLOAT4(s_b[load_smem_b_k][load_smem_b_n]) = FLOAT4(b[load_gmem_b_addr]);
    __syncthreads();
    #pragma unroll
    for (int k = 0; k < BK; ++k) {
      // 3. Each thread computes its TM*TN (8x8) elements of the BM*BN (128x128) tile
      #pragma unroll
      for (int m = 0; m < TM; ++m) {
        #pragma unroll
        for (int n = 0; n < TN; ++n) {
          // k ranges over 0~BK-1, ty and tx range from 0 to 15, 16x8=128
          int comp_smem_a_m = ty * TM + m; // BM/TM = 128/8 = 16 threads along M
          int comp_smem_b_n = tx * TN + n; // BN/TN = 128/8 = 16 threads along N
          r_c[m][n] += s_a[comp_smem_a_m][k] * s_b[k][comp_smem_b_n];
        }
      }
    }
    __syncthreads();
  }

  #pragma unroll
  for (int m = 0; m < TM; ++m) {
    int store_gmem_c_m = by * BM + ty * TM + m;
    #pragma unroll
    for (int n = 0; n < TN; n += 4) {
      int store_gmem_c_n = bx * BN + tx * TN + n;
      int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
      FLOAT4(c[store_gmem_c_addr]) = FLOAT4(r_c[m][n]);
    }
  }
}
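The FLOAT4 macro used above is not defined in this excerpt; a common definition (an assumption here, chosen to match how the macro is used) reinterprets the address of a float as a float4 so that one 128-bit load/store is issued instead of four 32-bit ones:

// Assumed definition of the FLOAT4 vectorization macro (not shown in the original listing).
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])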
These gemm implementations are fairly simple: they only use CUDA Cores, and only the Block Tile + K Tile and the Block Tile + K Tile + Thread Tile + vectorization versions are provided. The key point is how to load data from gmem into smem, i.e. how to map global-memory indices onto shared-memory indices. The core idea: treat the thread ids within a block as one linear index, then match that linear id against both the global-memory index and the shared-memory index. For example, in the Block Tile + K Tile implementation there are 32x32 threads per block and the data to be loaded into smem is also 32x32, so the simplest scheme is for each thread to load exactly one distinct element; see the small check after this paragraph. NOTE: the gemm kernels in this article are adapted from 紫气东来: CUDA(三):通用矩阵乘法:从入门到熟练 (https://zhuanlan.zhihu.com/p/657632577).
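To make the linear-id mapping concrete, the small host-side check below (illustrative only, not from the article) prints which shared-memory element each tid loads in the two cases discussed above:

// Illustrative only: print the linear-tid -> smem index mapping described above.
#include <cstdio>

int main() {
  // Block Tile + K Tile case: 32x32 threads, each loads one distinct element of the 32x32 tile.
  for (int tid = 0; tid < 4; ++tid) {
    printf("32x32 case: tid=%d -> s_a[%d][%d]\n", tid, tid / 32, tid % 32);
  }
  // Thread Tile + Vec4 case: 256 threads, each loads one float4 of s_a[128][8].
  for (int tid = 0; tid < 4; ++tid) {
    int m = tid / 2;
    int k = (tid % 2 == 0) ? 0 : 4;
    printf("128x8 case: tid=%d -> s_a[%d][%d..%d]\n", tid, m, k, k + 3);
  }
  return 0;
}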
// Block All Reduce Sum
// grid(N/128), block(128)
// a: Nx1, y=sum(a)
template<const int NUM_THREADS = 128>
__global__ void block_all_reduce_sum(float* a, float* y, int N) {
  int tid = threadIdx.x;
  int idx = blockIdx.x * NUM_THREADS + tid;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];
  // keeping the data in a register is enough for the warp operation.
  float sum = (idx < N) ? a[idx] : 0.0f;
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce.
  sum = warp_reduce_sum(sum);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = sum;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final sum.
  sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) sum = warp_reduce_sum(sum);
  if (tid == 0) atomicAdd(y, sum);
}
// Block All Reduce Sum + float4
// grid(N/128), block(128/4)
// a: Nx1, y=sum(a)
template<const int NUM_THREADS = 128/4>
__global__ void block_all_reduce_sum_vec4(float* a, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 4;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];

  float4 reg_a = FLOAT4(a[idx]);
  // keeping the data in a register is enough for the warp operation.
  float sum = (idx < N) ? (reg_a.x + reg_a.y + reg_a.z + reg_a.w) : 0.0f;
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce.
  sum = warp_reduce_sum(sum);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = sum;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final sum.
  sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) sum = warp_reduce_sum(sum);
  if (tid == 0) atomicAdd(y, sum);
}
Block all reduce is built on top of warp reduce. The shared-memory allocation for reduce_smem cannot be avoided: it is what lets the warps exchange their partial results. Note that an atomicAdd is still needed at the end as a block-level atomic operation to obtain the global sum. The float4 vectorization optimizes memory access and eases the pressure on the warp scheduler to issue instructions.
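The warp_reduce_sum helper and the WARP_SIZE constant that these kernels rely on are not shown in this excerpt. A common shuffle-based implementation (an assumption here, not necessarily the article's exact code) is sketched below; the template parameter lets callers reduce over a sub-group of lanes, which the sgemv_k16 kernel later makes use of:

#define WARP_SIZE 32

// Butterfly (xor) reduction over kWarpSize lanes using register shuffles;
// after the loop every participating lane holds the sum of its group.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum(float val) {
  #pragma unroll
  for (int mask = kWarpSize >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}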
0x05 sgemv k32/k128/k16 kernel
// SGEMV: Warp SGEMV K32
// Assumes K is a multiple of 32; each warp handles one row
// grid(M/4), block(32,4) blockDim.x=32, blockDim.y=4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void sgemv_k32(float* a, float* x, float* y, int M, int K) {
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = bx * blockDim.y + ty; // (0~M/4) * 4 + (0~3)
  if (m < M) {
    float sum = 0.0f;
    int NUM_WARPS = (K + WARP_SIZE - 1) / WARP_SIZE;
    #pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      // if NUM_WARPS >= 2, the chunks of the current row are first
      // accumulated into this warp's lanes, one WARP_SIZE chunk per iteration
      int k = w * WARP_SIZE + lane;
      sum += a[m * K + k] * x[k];
    }
    sum = warp_reduce_sum(sum);
    if (lane == 0) y[m] = sum;
  }
}
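A possible launch configuration matching the comments above, as a hypothetical wrapper (illustrative only; assumes M is a multiple of 4 and K a multiple of 32):

// Hypothetical host-side wrapper for sgemv_k32 (illustrative only).
void launch_sgemv_k32(float* d_a, float* d_x, float* d_y, int M, int K) {
  dim3 block(32, 4); // one warp per row, 4 rows per block
  dim3 grid(M / 4);  // M/4 blocks cover all M rows
  sgemv_k32<<<grid, block>>>(d_a, d_x, d_y, M, K);
  cudaDeviceSynchronize();
}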
// SGEMV: Warp SGEMV K128 + Vec4
// Assumes K is a multiple of 128, uses float4
// grid(M/4), block(32,4) blockDim.x=32, blockDim.y=4
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
__global__ void sgemv_k128(float* a, float* x, float* y, int M, int K) {
  // each thread handles 4 elements, so one warp covers 128 elements
  int tx = threadIdx.x;         // 0~31
  int ty = threadIdx.y;         // 0~3
  int bx = blockIdx.x;          // 0~M/4
  int lane = tx % WARP_SIZE;    // 0~31
  int m = blockDim.y * bx + ty; // (0~M/4) * 4 + (0~3)

  if (m < M) {
    float sum = 0.0f;
    // process 4*WARP_SIZE elements per warp.
    int NUM_WARPS = (((K + WARP_SIZE - 1) / WARP_SIZE) + 4 - 1) / 4;
    #pragma unroll
    for (int w = 0; w < NUM_WARPS; ++w) {
      int k = (w * WARP_SIZE + lane) * 4;
      float4 reg_x = FLOAT4(x[k]);
      float4 reg_a = FLOAT4(a[m * K + k]);
      sum += (reg_a.x * reg_x.x + reg_a.y * reg_x.y +
              reg_a.z * reg_x.z + reg_a.w * reg_x.w);
    }
    sum = warp_reduce_sum(sum);
    if (lane == 0) y[m] = sum;
  }
}
// SGEMV: Warp SGEMV K16
// Assumes K is 16
// NUM_THREADS=128, NUM_WARPS=NUM_THREADS/WARP_SIZE;
// NUM_ROWS=NUM_WARPS * ROW_PER_WARP, grid(M/NUM_ROWS), block(32,NUM_WARPS)
// a: MxK, x: Kx1, y: Mx1, compute: y = a * x
template<const int ROW_PER_WARP = 2>
__global__ void sgemv_k16(float* A, float* x, float* y, int M, int K) {
  constexpr int K_WARP_SIZE = (WARP_SIZE + ROW_PER_WARP - 1) / ROW_PER_WARP;
  int tx = threadIdx.x;       // 0~31
  int ty = threadIdx.y;       // 0~NUM_WARPS
  int bx = blockIdx.x;        // 0~M/NUM_ROWS (NUM_ROWS=NUM_WARPS * ROW_PER_WARP)
  int lane = tx % WARP_SIZE;  // 0~31
  int k = lane % K_WARP_SIZE; // 0~15
  // global row of a: MxK and y: Mx1, blockDim.y=NUM_WARPS
  int m = (blockDim.y * bx + ty) * ROW_PER_WARP + lane / K_WARP_SIZE;
  if (m < M) {
    float sum = A[m * K + k] * x[k];
    // reduce within each K_WARP_SIZE-lane group, so each row gets its own sum
    sum = warp_reduce_sum<K_WARP_SIZE>(sum);
    // note: the condition is k == 0, not lane == 0
    if (k == 0) y[m] = sum;
  }
}
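To show how the template parameter and the NUM_ROWS bookkeeping fit together, a hypothetical launch wrapper (illustrative only; assumes M is a multiple of NUM_ROWS):

// Hypothetical host-side wrapper for sgemv_k16 with ROW_PER_WARP=2 (illustrative only).
void launch_sgemv_k16(float* d_a, float* d_x, float* d_y, int M) {
  constexpr int ROW_PER_WARP = 2;
  constexpr int NUM_THREADS  = 128;
  constexpr int NUM_WARPS    = NUM_THREADS / 32;         // 4 warps per block
  constexpr int NUM_ROWS     = NUM_WARPS * ROW_PER_WARP; // 8 rows per block
  dim3 block(32, NUM_WARPS);
  dim3 grid(M / NUM_ROWS);
  sgemv_k16<ROW_PER_WARP><<<grid, block>>>(d_a, d_x, d_y, M, 16);
  cudaDeviceSynchronize();
}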
// Dot Product
// grid(N/128), block(128)
// a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))
template<const int NUM_THREADS = 128>
__global__ void dot(float* a, float* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = blockIdx.x * NUM_THREADS + tid;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];

  // keeping the data in a register is enough for the warp operation.
  float prod = (idx < N) ? a[idx] * b[idx] : 0.0f;
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce.
  prod = warp_reduce_sum(prod);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = prod;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final sum.
  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) prod = warp_reduce_sum(prod);
  if (tid == 0) atomicAdd(y, prod);
}
// Dot Product + Vec4
// grid(N/128), block(128/4)
// a: Nx1, b: Nx1, y=sum(elementwise_mul(a,b))
template<const int NUM_THREADS = 128/4>
__global__ void dot_vec4(float* a, float* b, float* y, int N) {
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 4;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];

  float4 reg_a = FLOAT4(a[idx]);
  float4 reg_b = FLOAT4(b[idx]);
  float prod = (idx < N) ? (reg_a.x * reg_b.x + reg_a.y * reg_b.y +
                            reg_a.z * reg_b.z + reg_a.w * reg_b.w) : 0.0f;
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  // perform warp sync reduce.
  prod = warp_reduce_sum(prod);
  // warp leaders store the data to shared memory.
  if (lane == 0) reduce_smem[warp] = prod;
  __syncthreads(); // make sure the data is in shared memory.
  // the first warp computes the final sum.
  prod = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  if (warp == 0) prod = warp_reduce_sum(prod);
  if (tid == 0) atomicAdd(y, prod);
}
The core of the dot product kernel is the block reduce, so there is not much more to say about it.
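One practical detail is worth noting, though: because every block accumulates its partial result into y with atomicAdd, the output scalar has to be zeroed before each launch. A hypothetical host-side usage sketch (names are illustrative, not from the article):

// Hypothetical host-side usage of dot (illustrative only).
void launch_dot(float* d_a, float* d_b, float* d_y, int N) {
  cudaMemset(d_y, 0, sizeof(float)); // y accumulates via atomicAdd, so clear it first
  constexpr int NUM_THREADS = 128;
  dim3 block(NUM_THREADS);
  dim3 grid((N + NUM_THREADS - 1) / NUM_THREADS);
  dot<NUM_THREADS><<<grid, block>>>(d_a, d_b, d_y, N);
  cudaDeviceSynchronize();
}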
0x07 elementwise, elementwise + vec4
// ElementWise Add
// grid(N/128), block(128)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add(float* a, float* b, float* c, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) c[idx] = a[idx] + b[idx];
}
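The heading above also mentions an elementwise + vec4 variant, which does not appear in this excerpt; a sketch following the same float4 pattern as the other _vec4 kernels here (an assumption, not the article's exact code):

// ElementWise Add + Vec4 (sketch; assumes N is a multiple of 4)
// grid(N/128), block(128/4)
__global__ void elementwise_add_vec4(float* a, float* b, float* c, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx < N) {
    float4 reg_a = FLOAT4(a[idx]);
    float4 reg_b = FLOAT4(b[idx]);
    float4 reg_c;
    reg_c.x = reg_a.x + reg_b.x;
    reg_c.y = reg_a.y + reg_b.y;
    reg_c.z = reg_a.z + reg_b.z;
    reg_c.w = reg_a.w + reg_b.w;
    FLOAT4(c[idx]) = reg_c;
  }
}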
// Softmax x: N, y: N
// grid(N/128), block(K=128)
template<const int NUM_THREADS = 128>
__global__ void softmax(float* x, float* y, float* total, int N) {
  const int tid = threadIdx.x;
  const int idx = blockIdx.x * blockDim.x + tid;
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float reduce_smem[NUM_WARPS];

  float sum = (idx < N) ? expf(x[idx]) : 0.0f;
  int warp = tid / WARP_SIZE;
  int lane = tid % WARP_SIZE;
  sum = warp_reduce_sum(sum);
  if (lane == 0) reduce_smem[warp] = sum;
  __syncthreads();
  // compute the final sum in each warp
  sum = (lane < NUM_WARPS) ? reduce_smem[lane] : 0.0f;
  sum = warp_reduce_sum(sum); // sum(e^x_0,...,e^x_n-1)
  // get the total sum of all blocks.
  if (tid == 0) atomicAdd(total, sum);
  __threadfence(); // grid-level memory fence; note that grid-level memory synchronization is needed here
  // e^x_i/sum(e^x_0,...,e^x_n-1)
  if (idx < N) y[idx] = expf(x[idx]) / (*total);
}
// Softmax x: N, y: N
// grid(N/128), block(K=128)
template<const int NUM_THREADS = 128>
__global__ void softmax_v2(float* x, float* y, float* total, int N) {
  const int tid = threadIdx.x;
  const int idx = blockIdx.x * blockDim.x + tid;

  float exp_val = (idx < N) ? expf(x[idx]) : 0.0f;
  float sum = block_reduce_sum<NUM_THREADS>(exp_val);
  // get the total sum of all blocks.
  if (tid == 0) atomicAdd(total, sum);
  __threadfence(); // grid-level memory fence; note that grid-level memory synchronization is needed here
  // e^x_i/sum(e^x_0,...,e^x_n-1)
  if (idx < N) y[idx] = exp_val / (*total);
}
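softmax_v2 above and the layer_norm kernel below call a block_reduce_sum helper that is not listed in this excerpt. A common implementation (an assumption here) simply packages the two-level pattern already spelled out in block_all_reduce_sum: a warp-level shuffle reduce, partial sums staged in shared memory, then a final reduce by the first warp:

// Assumed block-level sum reduction (not shown in the original listing).
// Only warp 0 ends up holding the full block sum; the callers here read the
// result from tid == 0 only, so that is sufficient.
template<const int NUM_THREADS = 128>
__device__ float block_reduce_sum(float val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  static __shared__ float shared[NUM_WARPS];
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;
  val = warp_reduce_sum(val);        // reduce within each warp
  if (lane == 0) shared[warp] = val; // stage per-warp partial sums
  __syncthreads();
  val = (lane < NUM_WARPS) ? shared[lane] : 0.0f;
  val = warp_reduce_sum(val);        // first warp reduces the partials
  return val;
}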
// Softmax Vec4 x: N, y: N
// grid(N/128), block(128/4)
template<const int NUM_THREADS = 128/4>
__global__ void softmax_v2_vec4(float* x, float* y, float* total, int N) {
  const int tid = threadIdx.x;
  const int idx = (blockIdx.x * blockDim.x + tid) * 4;

  float4 reg_x = FLOAT4(x[idx]);
  float4 reg_exp;
  reg_exp.x = (idx < N) ? expf(reg_x.x) : 0.0f;
  reg_exp.y = (idx < N) ? expf(reg_x.y) : 0.0f;
  reg_exp.z = (idx < N) ? expf(reg_x.z) : 0.0f;
  reg_exp.w = (idx < N) ? expf(reg_x.w) : 0.0f;
  float exp_val = (reg_exp.x + reg_exp.y + reg_exp.z + reg_exp.w);
  float sum = block_reduce_sum<NUM_THREADS>(exp_val);
  // get the total sum of all blocks.
  if (tid == 0) atomicAdd(total, sum);
  __threadfence(); // grid-level memory fence; note that grid-level memory synchronization is needed here
  // e^x_i/sum(e^x_0,...,e^x_n-1)
  if (idx < N) {
    float4 reg_y;
    reg_y.x = reg_exp.x / (*total);
    reg_y.y = reg_exp.y / (*total);
    reg_y.z = reg_exp.z / (*total);
    reg_y.w = reg_exp.w / (*total);
    FLOAT4(y[idx]) = reg_y;
  }
}
// Relu x: N, y: N y=max(0,x)
// grid(N/128), block(K=128)
__global__ void relu(float* x, float* y, int N) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) y[idx] = fmaxf(0.0f, x[idx]);
}
// Relu x: N, y: N y=max(0,x) Vec4
// grid(N/128), block(128/4)
__global__ void relu_vec4(float* x, float* y, int N) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx < N) {
    float4 reg_x = FLOAT4(x[idx]);
    float4 reg_y;
    reg_y.x = fmaxf(0.0f, reg_x.x);
    reg_y.y = fmaxf(0.0f, reg_x.y);
    reg_y.z = fmaxf(0.0f, reg_x.z);
    reg_y.w = fmaxf(0.0f, reg_x.w);
    FLOAT4(y[idx]) = reg_y;
  }
}
0x0c layer_norm, layer_norm + vec4
// Layer Norm: x: NxK(K=128<1024), y': NxK, y' = (x - mean(x)) / std(x) per row
// mean(x) = sum(x)/K, 1/std(x) = rsqrtf( sum( (x-mean(x))^2 )/K ) per row
// grid(N*K/K), block(K<1024) N=batch_size*seq_len, K=hidden_size
// y = y'*g + b (g: scale, b: bias)
template<const int NUM_THREADS=128>
__global__ void layer_norm(float* x, float* y, float g, float b, int N, int K) {
  int tid = threadIdx.x; // 0..K-1
  int bid = blockIdx.x;  // 0..N-1
  int idx = bid * blockDim.x + threadIdx.x;
  const float epsilon = 1e-5f;

  __shared__ float s_mean;     // shared within block
  __shared__ float s_variance; // shared within block
  float value = (idx < N * K) ? x[idx] : 0.0f; // load once only
  float sum = block_reduce_sum<NUM_THREADS>(value);
  if (tid == 0) s_mean = sum / (float) K;
  // wait for s_mean in shared memory to be ready for all threads
  __syncthreads();
  float variance = (value - s_mean) * (value - s_mean);
  variance = block_reduce_sum<NUM_THREADS>(variance);
  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  // wait for s_variance in shared memory to be ready for all threads