MLNLP
社区是国内外知名的机器学习与自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
社区的愿景
是促进国内外自然语言处理,机器学习学术界、产业界和广大爱好者之间的交流和进步,特别是初学者同学们的进步。
AI 的快速发展,伴随而来的是大计算量。这就自然而然的引出了一个问题:如何减少 AI 对计算的需求,并提高现有 AI 计算效率。
为了回答这一问题,来自斯坦福的研究者在博客《GPUs Go Brrr》中给出了答案。
博客地址:https://hazyresearch.stanford.edu/blog/2024-05-12-tk
文章主要专注于两个问题:一是硬件真正需要什么?二是如何满足硬件需求?
文章用大量篇幅讨论了如何让 GPU 更快的运行,并发布了一个库 ThunderKittens,用户可以很容易地在 CUDA 上编写快速的深度学习内核。其具有以下特点:
-
简单,ThunderKittens 写起来非常简单。
-
可扩展性,如果用户需要 ThunderKittens 无法提供的功能,可以进行功能扩展。
-
GitHub 链接:https://github.com/HazyResearch/ThunderKittens
ThunderKittens 使得一些棘手的事情变得非常简单,从而在现代硬件上实现了非常高的利用率。项目中,作者用 ThunderKittens 编写了一个 RTX 4090 简单的 FlashAttention-2 内核,代码总共有 58 行代码(不包括空格),结果显示,ThunderKittens 在 RTX 4090 上实现了大约 122 TFLOP(理论最大值的 74%)。此外,内核程序只有 100 行的情况下,ThunderKittens 在 H100 上的性能比 FlashAttention-2 高出约 30%。
英伟达 H100 有些小怪癖
该研究重点关注 NVIDIA H100,不过所介绍的内容也适用于其他 GPU。
-
80 GB HBM3,带宽为 3 TB/s(实际上带宽会少一些);
-
50 MB 二级缓存,带宽 12 TB/s,在 GPU 上分成两个 25MB 的部分,通过 crossbar 连接;
-
132 个流多处理器 (SM,streaming multiprocessors)。
除了上述这些,H100 SXM GPU 还有很多可关注的东西,例如内存控制器、指令缓存等。
研究者表示保持张量核心的运行流畅并不容易。他们发现了一些 AI 硬件上的怪癖,这些怪癖中的很多内容也适用于非 H100 GPU,但 H100 尤其棘手。(相比之下,RTX 4090 则非常容易使用),这些怪癖包括:
-
WGMMA 指令是必需的,但使用起来也非常令人恼火;
-
-
-
H100 有一组新指令,称为「warp group matrix multiply accumulate,WGMMA」(PTX 中的 wgmma.mma_async,或 SASS 中的 HGMMA/IGMMA/QGMMA/BGMMA)。以前的 GPU 上可用的张量核心指令是 wmma.mma.sync 和 mma.sync 。通过这些指令,SM 单个象限上的 32 个线程将同步地将其数据块馈送到张量核心并等待结果。
不同的是,wgmma.mma_async 指令并非如此,128 个连续线程(分布在 SM 的所有象限中)协作同步,并直接从共享内存(也可以选择寄存器)异步启动矩阵乘法。
在基准测试中,研究团队发现这些指令对于提取 H100 的完整计算是必要的。如果没有它们,GPU 的峰值利用率似乎只能达到峰值利用率的 63% 左右。
共享内存的单次访问延迟约为 30 个周期,这听起来似乎不算多,但在这段时间内,SM 的张量核心几乎可以完成两个完整的 32x32 矩阵乘法运算。
共享内存处理起来有些棘手,因为它被存储(banked)在 32 个独立的内存存储中。如果不小心,这可能会导致所谓的 bank 冲突,即同一内存 bank 被要求同时提供多个不同的内存片段,导致请求被串行化,这可能会不成比例地减慢内核的速度 - 而 wgmma 和 mma 指令所需的寄存器布局会受到这些 bank 冲突的影响。解决方法是使用各种交错模式重新排列共享内存,以避免这些冲突。
H100 其中一个特点是张量核心和内存都足够快,以至于仅仅生成用于获取数据的内存地址就占据了芯片资源的相当一部分。
NVIDIA 似乎已经意识到了这一点,因为他们赋予了 GPU 张量内存加速器(或称之为 TMA)。TMA 允许用户在全局和共享内存中指定多维张量布局,这节省了所有的地址生成成本,并且还使得构建 pipeline 更加容易。
研究团队还发现 TMA 和 wgmma.mma_async 一样,在实现 H100 的全部潜力方面是完全不可或缺的。
在某些方面,与前几代硬件相比,H100 对占用率的依赖程度较低。NVIDIA 确实在设计 GPU 时考虑了占用率。虽然对于 H100 来说,占用率只能说有用,但作用不大。研究者发现在 A100 和 RTX 4090 上它变得越来越重要。
ThunderKittens
那么,如何才能更轻松地编写内核,同时仍兼具硬件的全部功能?
研究团队设计了一个嵌入 CUDA 中的 DSL,被命名为 ThunderKittens。
ThunderKittens 旨在尽可能简单,并包含四种模板类型:
-
寄存器 tile—— 寄存器文件中的 2D 张量。
-
-
-
tile 通过高度、宽度和布局进行参数化,寄存器向量由长度和布局参数化,共享向量仅由长度参数化。这样通常不会遭受 bank 冲突的困扰。
该研究给出了一个用 ThunderKittens 编写的,用于 RTX 4090 的简单前向 flash attention 内核:
#define NUM_WORKERS 16
using namespace kittens;
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {
auto warpid = kittens::warpid();
auto block_start = blockIdx.x*(n*64);
const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
bf16 *_o = __o__ + block_start;
extern __shared__ alignment_dummy __shm[];
shared_allocator al((int*)&__shm[0]);
st_bf_1x4<:st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate, NUM_WORKERS>();
st_bf_1x4<:st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate, NUM_WORKERS>();
rt_bf_1x4<> q_reg, k_reg, v_reg;
rt_fl_1x1<> att_block;
rt_bf_1x1<> att_block_mma;
rt_fl_1x4<> o_reg;
rt_fl_1x1<>::col_vec max_vec_last, max_vec;
rt_fl_1x1<>::col_vec norm_vec_last, norm_vec;
int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {
load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
mul(q_reg, q_reg, __float2bfloat16(0.125f));
neg_infty(max_vec);
zero(norm_vec);
zero(o_reg);
for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {
load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
__syncthreads();
for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {
load(k_reg, k_smem[subtile]);
zero(att_block);
mma_ABt(att_block, q_reg, k_reg, att_block);
copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);
row_max(max_vec, att_block, max_vec);
sub_row(att_block, att_block, max_vec);
exp(att_block, att_block);
sub(max_vec_last, max_vec_last, max_vec);
exp(max_vec_last, max_vec_last);
mul(norm_vec, norm_vec, max_vec_last);
row_sum(norm_vec, att_block, norm_vec);
div_row(att_block, att_block, norm_vec);
mul(norm_vec_last, norm_vec_last, max_vec_last);
div(norm_vec_last, norm_vec_last, norm_vec);
copy(att_block_mma, att_block);
load(v_reg, v_smem[subtile]);
rt_bf_1x4<:rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg);
mul_row(o_reg, o_reg, norm_vec_last);
mma_AB(o_reg, att_block_mma, v_reg_col, o_reg);
}
__syncthreads();
}
store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols);
}
}
总共大约有 60 行 CUDA 代码,硬件利用率为 75%,虽然非常密集,但大部分复杂性在于算法,而不是混合模式或寄存器布局。
TMA、WGMMA、swizzling 模式和描述符的复杂度又如何呢?如下是用 ThunderKittens 编写的, H100 的 FlashAttention-2 前向传递:
template<int D>
__global__ __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)
void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {
extern __shared__ int __shm[];
tma_swizzle_allocator al((int*)&__shm[0]);
constexpr int tile_width = fwd_attend_ker_tile_dims::tile_width;
constexpr int qo_height = fwd_attend_ker_tile_dims::qo_height;
constexpr int kv_height = fwd_attend_ker_tile_dims::kv_height;
st_bf (&q_smem) [NUM_WARPGROUPS] = al.allocate, NUM_WARPGROUPS>();
st_bf (&k_smem)[2][NUM_WORKERS_KV] = al.allocate, 2, NUM_WORKERS_KV>();
st_bf (&v_smem)[2][NUM_WORKERS_KV] = al.allocate, 2, NUM_WORKERS_KV>();
int tic = 0, toc = 1;
rt_fl<1, kv_height> att_block;
rt_bf<1, kv_height> att_block_mma;
rt_fl<1, qo_height> o_prev;
col_vec1, kv_height>> max_vec_last, max_vec;
col_vec1, kv_height>> norm_vec_last, norm_vec;
int warpid = kittens::warpid();
int warpgroupid = warpid/kittens::WARPGROUP_WARPS;
int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);
__shared__ uint64_t qsmem_barrier, kvsmem_barrier;
int q_phasebit = 0;
int kv_phasebit = 0;
if (threadIdx.x == 0) {
tma::init_barrier, NUM_WARPGROUPS>(qsmem_barrier, 1);
tma::init_barrier, NUM_WORKERS_KV*2>(kvsmem_barrier, 1);
}
if (warpid == 0) {
for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) {
int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;
tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx);
}
for (int w = 0; w < NUM_WORKERS_KV; w++) {
int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w;
tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx);
tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx);
}
}
neg_infty(max_vec);
zero(norm_vec);
zero(o_prev);
__syncthreads();
tma::arrive_and_wait(qsmem_barrier, q_phasebit);
q_phasebit ^= 1;
if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); }
else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }
for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {
tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);
kv_phasebit ^= 1;
__syncthreads();
if (warpid == 0) {
tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));
if (kv_idx + 1 < kv_blocks) {
for (int w = 0; w < NUM_WORKERS_KV; w++) {
int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w;
tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx);
tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);
}
}
}
warpgroup::mma_fence(att_block);
warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);
warpgroup::mma_commit_group();
copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);
warpgroup::mma_async_wait();
row_max(max_vec, att_block, max_vec);
sub_row(att_block, att_block, max_vec);
exp(att_block, att_block);
sub(max_vec_last, max_vec_last, max_vec);
exp(max_vec_last, max_vec_last);
mul(norm_vec, norm_vec, max_vec_last);
row_sum(norm_vec, att_block, norm_vec);
div_row(att_block, att_block, norm_vec);