专栏名称: 机器学习算法与自然语言处理
一个有情怀的公众号。机器学习、自然语言处理、算法等知识集中营、期待与你相遇~
目录
相关文章推荐
连云港市场监管  ·  连云港在全省率先出台《数据知识产权保护指南》 ... ·  12 小时前  
连云港市场监管  ·  连云港在全省率先出台《数据知识产权保护指南》 ... ·  12 小时前  
知识产权那点事  ·  【案例报告】AI一键生成的图片版权归属 ·  4 天前  
太格有物  ·  新品快讯|沃尔沃发布越野电动SUV,adid ... ·  2 天前  
51好读  ›  专栏  ›  机器学习算法与自然语言处理

只需百行代码,让H100提速30%,斯坦福开源全新AI加速框架

机器学习算法与自然语言处理  · 公众号  ·  · 2024-05-13 20:40

正文


MLNLP 社区是国内外知名的机器学习与自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
社区的愿景 是促进国内外自然语言处理,机器学习学术界、产业界和广大爱好者之间的交流和进步,特别是初学者同学们的进步。
转载自 | 机器之心
提高 GPU 利用率,就是这么简单。
AI 的快速发展,伴随而来的是大计算量。这就自然而然的引出了一个问题:如何减少 AI 对计算的需求,并提高现有 AI 计算效率。
为了回答这一问题,来自斯坦福的研究者在博客《GPUs Go Brrr》中给出了答案。
博客地址:https://hazyresearch.stanford.edu/blog/2024-05-12-tk
文章主要专注于两个问题:一是硬件真正需要什么?二是如何满足硬件需求?
文章用大量篇幅讨论了如何让 GPU 更快的运行,并发布了一个库 ThunderKittens,用户可以很容易地在 CUDA 上编写快速的深度学习内核。其具有以下特点:
  • 简单,ThunderKittens 写起来非常简单。
  • 可扩展性,如果用户需要 ThunderKittens 无法提供的功能,可以进行功能扩展。
  • 速度快。
GitHub 链接:https://github.com/HazyResearch/ThunderKittens
ThunderKittens 使得一些棘手的事情变得非常简单,从而在现代硬件上实现了非常高的利用率。项目中,作者用 ThunderKittens 编写了一个 RTX 4090 简单的 FlashAttention-2 内核,代码总共有 58 行代码(不包括空格),结果显示,ThunderKittens 在 RTX 4090 上实现了大约 122 TFLOP(理论最大值的 74%)。此外,内核程序只有 100 行的情况下,ThunderKittens 在 H100 上的性能比 FlashAttention-2 高出约 30%。

英伟达 H100 有些小怪癖

该研究重点关注 NVIDIA H100,不过所介绍的内容也适用于其他 GPU。
H100 SXM GPU 包含:
  • 80 GB HBM3,带宽为 3 TB/s(实际上带宽会少一些);
  • 50 MB 二级缓存,带宽 12 TB/s,在 GPU 上分成两个 25MB 的部分,通过 crossbar 连接;
  • 132 个流多处理器 (SM,streaming multiprocessors)。
除了上述这些,H100 SXM GPU 还有很多可关注的东西,例如内存控制器、指令缓存等。
研究者表示保持张量核心的运行流畅并不容易。他们发现了一些 AI 硬件上的怪癖,这些怪癖中的很多内容也适用于非 H100 GPU,但 H100 尤其棘手。(相比之下,RTX 4090 则非常容易使用),这些怪癖包括:
  • WGMMA 指令是必需的,但使用起来也非常令人恼火;
  • 共享内存实际上并没有那么快,并且需要非常小心;
  • 地址生成成本很高;
  • 占用率仍然有帮助,寄存器通常是关键资源。
文章进一步描述了 GPU 这些怪癖的具体内容。
WGMMA 指令令人恼火
H100 有一组新指令,称为「warp group matrix multiply accumulate,WGMMA」(PTX 中的 wgmma.mma_async,或 SASS 中的 HGMMA/IGMMA/QGMMA/BGMMA)。以前的 GPU 上可用的张量核心指令是 wmma.mma.sync 和 mma.sync 。通过这些指令,SM 单个象限上的 32 个线程将同步地将其数据块馈送到张量核心并等待结果。
不同的是,wgmma.mma_async 指令并非如此,128 个连续线程(分布在 SM 的所有象限中)协作同步,并直接从共享内存(也可以选择寄存器)异步启动矩阵乘法。
在基准测试中,研究团队发现这些指令对于提取 H100 的完整计算是必要的。如果没有它们,GPU 的峰值利用率似乎只能达到峰值利用率的 63% 左右。
共享内存
共享内存的单次访问延迟约为 30 个周期,这听起来似乎不算多,但在这段时间内,SM 的张量核心几乎可以完成两个完整的 32x32 矩阵乘法运算。
共享内存处理起来有些棘手,因为它被存储(banked)在 32 个独立的内存存储中。如果不小心,这可能会导致所谓的 bank 冲突,即同一内存 bank 被要求同时提供多个不同的内存片段,导致请求被串行化,这可能会不成比例地减慢内核的速度 - 而 wgmma 和 mma 指令所需的寄存器布局会受到这些 bank 冲突的影响。解决方法是使用各种交错模式重新排列共享内存,以避免这些冲突。
地址生成
H100 其中一个特点是张量核心和内存都足够快,以至于仅仅生成用于获取数据的内存地址就占据了芯片资源的相当一部分。
NVIDIA 似乎已经意识到了这一点,因为他们赋予了 GPU 张量内存加速器(或称之为 TMA)。TMA 允许用户在全局和共享内存中指定多维张量布局,这节省了所有的地址生成成本,并且还使得构建 pipeline 更加容易。
研究团队还发现 TMA 和 wgmma.mma_async 一样,在实现 H100 的全部潜力方面是完全不可或缺的。
占用
在某些方面,与前几代硬件相比,H100 对占用率的依赖程度较低。NVIDIA 确实在设计 GPU 时考虑了占用率。虽然对于 H100 来说,占用率只能说有用,但作用不大。研究者发现在 A100 和 RTX 4090 上它变得越来越重要。

ThunderKittens

那么,如何才能更轻松地编写内核,同时仍兼具硬件的全部功能?
研究团队设计了一个嵌入 CUDA 中的 DSL,被命名为 ThunderKittens。
ThunderKittens 旨在尽可能简单,并包含四种模板类型:
  • 寄存器 tile—— 寄存器文件中的 2D 张量。
  • 寄存器向量 —— 寄存器文件中的 1D 张量。
  • 共享 tile—— 共享内存中的 2D 张量。
  • 共享向量 —— 共享内存中的 1D 张量。
tile 通过高度、宽度和布局进行参数化,寄存器向量由长度和布局参数化,共享向量仅由长度参数化。这样通常不会遭受 bank 冲突的困扰。
研究团队还提供了一些必要操作:
初始化,如将共享向量清零
  • 一元运算,如 exp
  • 二元运算,如 mul
  • 行 / 列操作,如 row_sum
该研究给出了一个用 ThunderKittens 编写的,用于 RTX 4090 的简单前向 flash attention 内核:
#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {
auto warpid = kittens::warpid(); auto block_start = blockIdx.x*(n*64); const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start; bf16 *_o = __o__ + block_start;
extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory shared_allocator al((int*)&__shm[0]); // K and V live in shared memory -- this is about all that will fit. st_bf_1x4<:st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate, NUM_WORKERS>(); st_bf_1x4<:st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate, NUM_WORKERS>();
// Initialize all of the register tiles. rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l rt_fl_1x1<> att_block; rt_bf_1x1<> att_block_mma; rt_fl_1x4<> o_reg; rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {
// each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d) load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment
// zero flash attention L, M, and O registers. neg_infty(max_vec); // zero registers for the Q chunk zero(norm_vec); zero(o_reg);
// iterate over k, v for these q's that have been loaded for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {
// each warp loads its own chunk of k, v into shared memory load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols); __syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase
// now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg. for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {
load(k_reg, k_smem[subtile]); // load k from shared into registers
zero(att_block); // zero 16x16 attention tile mma_ABt(att_block, q_reg, k_reg, att_block); // [email protected]
copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec);
row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0 exp(att_block, att_block); // exponentiate the block in-place.
sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization. exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by. mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.
row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized
mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm
copy(att_block_mma, att_block); // convert to bf16 for mma_AB
load(v_reg, v_smem[subtile]); // load v from shared into registers. rt_bf_1x4<:rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg
mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul. } __syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk }
store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/ }}
总共大约有 60 行 CUDA 代码,硬件利用率为 75%,虽然非常密集,但大部分复杂性在于算法,而不是混合模式或寄存器布局。
TMA、WGMMA、swizzling 模式和描述符的复杂度又如何呢?如下是用 ThunderKittens 编写的, H100 的 FlashAttention-2 前向传递:
template<int D>__global__  __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {    extern __shared__ int __shm[]; // this is the CUDA shared memory    tma_swizzle_allocator al((int*)&__shm[0]);
constexpr int tile_width = fwd_attend_ker_tile_dims::tile_width; // constants constexpr int qo_height = fwd_attend_ker_tile_dims::qo_height; constexpr int kv_height = fwd_attend_ker_tile_dims::kv_height;
st_bf (&q_smem) [NUM_WARPGROUPS] = al.allocate, NUM_WARPGROUPS>(); st_bf (&k_smem)[2][NUM_WORKERS_KV] = al.allocate, 2, NUM_WORKERS_KV>(); st_bf (&v_smem)[2][NUM_WORKERS_KV] = al.allocate, 2, NUM_WORKERS_KV>();
int tic = 0, toc = 1; rt_fl<1, kv_height> att_block; rt_bf<1, kv_height> att_block_mma; rt_fl<1, qo_height> o_prev; col_vec1, kv_height>> max_vec_last, max_vec; col_vec1, kv_height>> norm_vec_last, norm_vec;
int warpid = kittens::warpid(); int warpgroupid = warpid/kittens::WARPGROUP_WARPS;
int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);
__shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;
int q_phasebit = 0; int kv_phasebit = 0;
if (threadIdx.x == 0) { tma::init_barrier, NUM_WARPGROUPS>(qsmem_barrier, 1); tma::init_barrier, NUM_WORKERS_KV*2>(kvsmem_barrier, 1); }
if (warpid == 0) { for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg; tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx); } for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w; tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx); tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx); } }
neg_infty(max_vec); // zero registers for the Q chunk zero(norm_vec); zero(o_prev); __syncthreads();
tma::arrive_and_wait(qsmem_barrier, q_phasebit); q_phasebit ^= 1;
if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); } else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }
for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) { tma::arrive_and_wait(kvsmem_barrier, kv_phasebit); kv_phasebit ^= 1;
__syncthreads(); if (warpid == 0) { tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));
if (kv_idx + 1 < kv_blocks) { for (int w = 0; w < NUM_WORKERS_KV; w++) { int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w; tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx); tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx); } } }
warpgroup::mma_fence(att_block); warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]); warpgroup::mma_commit_group();
copy(norm_vec_last, norm_vec); copy(max_vec_last, max_vec);
warpgroup::mma_async_wait();
row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec sub_row(att_block, att_block, max_vec); exp(att_block, att_block);
sub(max_vec_last, max_vec_last, max_vec); exp(max_vec_last, max_vec_last); mul(norm_vec, norm_vec, max_vec_last);
row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec div_row(att_block, att_block, norm_vec);






请到「今天看啥」查看全文