这个演讲介绍了如何使用CUTLASS 3.x风格的代码在Hopper架构上实现输入为FPA+INTB混合精度矩阵乘法,内容包括:1.使用CuTe进行数据传输。2. FPA+INTB矩阵乘法案例讲解。Slides来自BiliBili NVIDIA英伟达频道 上传的《TensorRT-LLM 中的 Hopper Mixed GEMM 的 CUTLASS 3.x 实现讲解》视频讲解。这里参考视频并更详细记录了每一页Slides的要点,通过这个视频了解下CuTe的基本概念和CuTe实现GEMM的数据流动,以及从更High Level的角度看CUTLASS 3.x是如何实现Mixed GEMM的。
总览&目录 这个演讲主要会分成三部分,首先是对CuTe的介绍,然后介以GEMM数据传输为例展示它是如何用Cute来做的。安排这两节是因为是在CUTLASS 3.x的底层实现中,无论你的数据在各个层级的管理,还是真正做一个GEMM运算,都是需要大量的使用到CuTe的API的。不熟悉CUTLASS的开发者初次见到CuTe会比较陌生,所以对其进行介绍。最后一个会具体的去浏览一下Mixed GEMM的CUTLASS 3.x实现。
CuTe 介绍 CuTe其实就是用来管理CUDA Tensors计算的一个工具库,它最主要的概念就是Layout 和Tensor 。其中Layout是由Shape 和Stride 这两个概念组成的,可以把它理解为一个函数,作用就是把一个N维的逻辑坐标映射到真实的一维的连续的索引上去。有了这个Layout之后,再把一个真正的内存的指针传给Tensor的模板参数,这就构成了一个真正的Tensor。这里可以发现对于同一个内存指针指向的连续空间,给它不同的Layout就可以得到不同视角的Tensor,这就让CuTe具有了很大的灵活性,让它可以去处理一些复杂的索引问题。
CuTe提供了对Layout的形式代数操作:Layout可以被组合、操作、进行平铺和分区等。
CuTe提供了很多的API,包含一些基本的变换的一些API,这里列出了几个CuTe的操作函数,如get、rank、depth、shape、stride和size等。最后,Slides提到了一些其他相关概念,如Composition(组合)、Complement(补集)、Inverse(逆)、Product(乘积)和Divide(除法)。这张Slides的最后一个链接是CuTe的官方文档有更详细的CuTe的概念以及API介绍。
这张Slides讲解了在CUDA张量操作中Layout的表示方法,主要通过Shape(形状)和Stride(步长)来描述。以下是主要内容:
Layout表示:通过Shape和Stride来定义多维数组在内存中的排列方式。 使用公式:offset = inner_product(coord, stride) 例如,对于元素f,其逻辑坐标为(0,1,1),物理偏移量计算为40 + 11 + 2*1 = 3 展示了同样的2x3矩阵,但按列优先顺序存储在内存中。 展示了如何将2x3的矩阵按行优先顺序存储在内存中。 仍然来看刚才的三维的例子,我们可以把它看成是两个2x2的矩阵,也可以把后面的矩阵放在前面的矩阵的下面。这样我们就得到了一个4x2的二维矩阵,它的shape可以认为是4x2,但是我们这里并不能直接写成4,因为ac/ce/eg之间的距离不是固定的,如果写成4就没办法用一个数字去表示矩阵的Stride。注意到,a和c之间,e和g之间的距离都是4,而a和e,c和g之间的距离都是2。所以我们需要用到一个嵌套的表达。如下图红色部分:
我们需要把这个4x2矩阵的第一维的Shape写成(2,2),Stride写成(4,2),可以参考Slides中的左下角的图进行理解,在水平方向上可以理解为每个元素有2个子元素,每个子元素之间的距离是4,所以第一个Stride就是4,然后在z方向会重复两次,所以第二个Stride就是2。因此,通过嵌套形式的表达我们就可以表达出更复杂的一些Layout的例子。
这张Slides展示了多种不同的Layout。每种Layout都有其特定的用途和优势。以下是对每种Layout的简要解释:
Column-Major Padded (列优先带填充) Column-Major Interleaved (列优先交错): 形状: (4,(4,2)),步长: (4,(1,16)) Row-Major Pitch-Linear (行优先带间距): 形状: (4,(2,4)),步长: (8,(4,1)) 形状: ((2,2),(2,4)),步长: ((1,8),(16,2)) 这张Slides介绍了CUTLASS 3.x中的CuTe的使用示例。主要内容如下:
标题为"LAYOUT USAGE EXAMPLE"(Layout使用示例)。 定义了一个形状(Shape)为(8, (2, 2))和步长(Stride)为(2, (1, 16))的布局。 展示了CuTe中如何使用make_layout函数创建布局,以及如何使用make_tensor函数创建张量。 提供了一个8x4的矩阵图示,展示了元素在内存中的排列。 图中用不同颜色标注了这些访问和切片操作对应的矩阵区域。 这张Slides可以帮助大家理解为什么我们需要CuTe,在CuTe之前(CUTLASS 2.x)我们实现一个地址准换需要Slides右边展示的这么多代码,我们需要理解每一行代码的作用。有了CuTe之后我们只需要Slides左边的几行代码就可以完成了。而且在CUTLASS 2.x中,定义一个Layout每个都需要有自己的实现,现在我们只需要用Shape和Stride就可以拿到任意想要的Layout。
GEMM Data Flow with Cute 接下来了解一下GEMM中是如何用CuTe做数据的传输的。
在讲解GEMM数据传输之前需要讲一下Copy这个API,这个API是用CuTe做数据传输时一定会用到的。左边的API比较简单,把src Tensor和dst Tensor传禁区就可以完成数据拷贝,会自动根据GPU的架构,数据的存储位置去自动选择用UniversalCopy或者SM80_CP_ASYNC_CACHEALWAYS。它只会在这两个里面选择,如果我们想要得到更好的性能,建议使用右边的API。右边的copy API我们需要在第一个参数中显示指定一个copy_atom,它就是CuTe会为各种不同架构中的数据传输指令做一个封装。这里列出了不同架构下的数据传输指令,如果我们想用第二个API需要了解一下每个指令的作用和使用场景。另外,注意到这里copy的都是Tensor,所以我们可以用一些神奇的Layout达到一些除了数据拷贝之外的其它的一些变换的效果。
这张Slides举了一个用Copy做矩阵转置的简单例子,虽然不是最佳性能的实现,但可以让大家看到CuTe的魅力。右边的上两个图分别是从逻辑和物理的角度来看矩阵转置在做什么,物理角度来看就是要把abcd...->aeim...的顺序。我们在构造Tensor的时候就可以让iTensor和oTensor的shape都是一样的mxn,只不过在读进iTensor的时候让它以一个Column Major的方式读进来,所以我们构造Stride的时候传入(1, m),右边的图里把iTensor.layout也画出来了,我们再以Row Major的方式写出去就达到了一个转置的效果,因此oTensor的stride就是(n, 1)。
到了这一步如果不看Tile/Partition相关的代码,我们直接调用一个COPY就完成了转置。然后现在我们希望把它并行起来,要并行起来就需要不同的Block和Thread去负责不同区域的矩阵转置。所以,我们需要继续调用local_tile去给不同的Block分配该负责哪个区域,local_partion就是给Block里面每个Thread去分配区域。
代码里的local_tile有三个参数,第一个参数就是待划分的Tensor,这里就是iTensor。第二个参数就是希望一个Block的大小是多少,例如这个例子中BLOCK_TILE设成2,这里只是举例子,实际上不可能这么小,只是为了方面画图帮助读者理解,所以一个Block的大小就是2x2的区域。第三个参数会传入当前Block的坐标,这样通过local_tile这个API我们就可以得到一个gI Tensor,它就可以拿到当前Block(Block 0)负责的区域。在右边的图中就是绿色标注出来的元素,也就是左上角的2x2小矩阵。那么我们一共需要4个Block去把这个转置的任务完成。
现在我们得到Block需要的Tile之后,我们就可以把它传入给local_paritition,第一个参数就是local_tile得到的Tensor,第二个参数就是Block内的线程排布的情况,这里以blockDim.x=2,blockDim.y=1举例子,再次强调这里只是为了举例子。这样,我们的一个Tile是2x2,而线程的排布是2x1,那么一个线程就要负责一个1x2的Tile。Slides中图的红色部分就表示第0个Block的第0号线程要负责的数据。然后调用copy就可以做到并行的拷贝了。在这个例子中,0号线程会把offset是0的idata就是a拷贝到odata的offset 0号位置,1号线程会把offset是4的idata就是e拷贝到odata的offset 1号位置。
Slides最下方的表格展示了通过Layout我们还可以做到COPY,BROADCAST,GATHER等等操作。另外,你做这些操作你几乎不用改左边的任何的实现代码,只需要把Layout改成你需要的形式就可以了。
TiledCopy就是我们用来以Tile为单位的Copy时去构造source/dest tensor需要用到的东西。包括MMA的话,也是不同的MMA实现需要不同的Tile的形式,也需要TiledCopy去做。想构造一个Tiled Copy需要用到make_tiled_copy这个API。第一个参数传的还是Copy_Atom,第二个参数就是Dest Tensor的Stride的Layout,第三个参数是Dest Tensor的Value Layout。
Value Layout不好理解,可以用print_latex把你构造好的一个Tiled Copy打印出来,就是Slides下面的图。最左边的图就是每一个线程在它的Source Tensor里面去负责读取哪些数据,比如在这个例子里面T0需要读取列方向最开始的4个数据,T1是列方向接下来的4个。右边的图就是在Dest Tensor里面每个Thread需要去写哪些数据,在这个例子里和读Tensor的位置是一样的。
通过这个图来理解代码,首先32x8就是说线程在M这个方向是32个线程,然后在K这个方向是8个线程。Value Layout的4x1就是在M方向上要读连续的4个数据,在K方向只需要读一个数据。所以这里构造出来的Tiled Copy它的基本的Copy单位就是一个(32x4, 8x1)=(128, 8)这样的Tile。
拿到Tiled Copy之后首先需要get_slice把当前的线程号传进去,这样会得到一个Thread Copy表示当前线程需要Copy的Tile,然后用它去做partition_S并且把Source Tensor传进去,就可以直接拿到当前线程需要负责拷贝的Source Tensor的数据有哪些。它的Shape就是CPY_M和CPY_K,然后CPY就是我们刚才说的128x8的这个Tile大小,然后CPY_M和CPY_K分别表示它需要在M方向以及K方向做这么多次Copy,才能把gA这个Tensor完整的拷贝过去。同理对于Dest Tensor来说,我们需要调用一个partition_D同样可以得到(CPY, CPY_M, CPY_N)这个shape的Tensor,然后再调用copy这个API就可以了。
最右边的图画了一下Tiled Copy的封装层级,最底层它有Copy_Op和Copy_Traits这两个概念,Copy_Op就是底层的数据传输指令,是PTX的代码,然后Copy_Traits是关于代码的元信息,比如线程的Layou是什么样的。这两个就可以封装层我们最常用的Copy_Atom,CopyAtom再去封装得到TiledCopy,然后我们去划分Source Tensor和Dest Tensor需要通过get_slice拿到ThreadCopy,去做一个划分。
我们想构造一个Tiled Copy应该怎么设置Thread Layout和Value Layout呢?这个和Copy_Atom的指令是相关的。这里讲一个LDSM的例子展示一下应该怎么设置这个参数。LDSM就是ld.maxtrix这个指令,这个指令就是以一个warp为单位去load 1个/2个/4个 8x8矩阵的指令。这个指令有2种形式,一种是Trans的,一种是非Trans的。在CuTe里面把它封装成LDSM_后缀,LDSM_N代表非Trans的类型。对于非Trans类型可以在这里打印一下Copy_Atom,非Trans表示Source Thread会读取一列连续的数据,然后Dest里的Thread会拿连续的一列里面的2个元素。对于Trans类型,不同在于Source一个线程仍然是拿8个连续元素,但是Dest这边会把Source这边的一个线程连续的8个元素分给8个不同的线程,而一个线程的元素会来自两个不同的Source的线程。如果使用非Trans类型的话,Layout一定要传一个col-major进来,同理对于Trans类型一定需要一个raw-major的Source Tensor。
接下来看Stride/Value Layout的设置,这是一个Warp级别的指令,线程数一定要是32的倍数。我们先看非Trans的情况,Dest Tensor这里是在m方向上4个线程去负责连续的8个数据, 意味着这里的线程Layout一定要是4的倍数,并且M方向一定是取2个连续的数据。另外,对于Thread Layout的另外一个没有标注为红色的数字,我们是可以扩大的然后得到一个更大的Tile。
同理,对于Trans类型,它是8个连续的数据分给8个不同的线程,每个线程拿1个数据,所以在K维度上线程必须是8的倍数,并且K方向的Value Layout必须是1了。然后,M方向的Value Layout值是2,这里Dest Tensor的某个线程在M方向上拿到的2个数据是不连续的2个数据。
Tiled Copy我听得比较迷糊,建议大家学习下reed佬的CuTe文章。
接下来开始看GEMM中的数据传输是怎么做的,我们以矩阵A为例想一下怎么把数据从Global Memory传输到Shared Memory。
首先,我们需要用make_tiled_copy来构造一个Tiled Copy,然后get_slice拿到对应的Thread Copy。接着,我们来构造Copy的Source Tensor,这个时候Soruce Tensor是来自Global Memory的,然后我们需要以Block的形式把它拷贝到Shared Memory,这里就需要用到local_tile这个指令,第一个参数传的就是Global Memory里面的Tensor mA,然后把Block Shape/Thread传进去,Step是<_1 x _1>{} , 这是因为CUTLASS里面会把Block写成M, N, K三维结构,对于A来说没有N这个维度,所以这里就设置为X,表示这个维度不参与计算。通过这个local_tile我们就得到了gA,它就是当前Block需要负责的一个Source Tensor的表示。它的Shape就是(BLK_M, BLK_K, k),BLK_M和BLK_K就是这个Tile的Shape,k就表示一共有k个这个Shape的Tile要做拷贝。然后构造Dest Tensor的时候,由于Dest Tensor是在Shared Memory上,我们直接用make_tensor就可以了。这里的gA和sA拿到的是当前Block负责的数据区域,我们还需要进一步使用我们刚才获得的Thread Copy,分别用partition_S和partition_D得到当前线程负责的区域,来看Shape的话,对于partition_S得到的就是(ACPY, ACPY_M, ACPY_K, k),这里的k还是保持不变,即一共有k个Tile,只不过在拷贝当前这个Tile的时候需要以ACPY为单位,然后分别在M和K方向上拷贝这么多次。对于Dest Tensor的话,Shape的最后一个维度就不是k了,是PIPE,因为我们需要用到Pipline,PIPE的意思就是一共有多少个Stage。
接着是Shared Memory到Register File的拷贝,因为Register File是直接要用来做MMA的,所以数据的排布是不能随便设置的。我们可以直接使用CuTe提供的make_tiled_copy_A这个API来构造Tiled Copy,这样就不需要设置它的Thread Layout和Value Layout了,只需要把我们构造的一个用于做MMA的tile_mma传过去就可以自动为我们计算我们需要用什么样的Layout去做Copy。然后同样还是用get_slice拿到Thread Copy。然后Source Tensor因为它不涉及Register File,所以还是partition_S就可以完成。然后在Dest Tensor涉及到MMA所以有一些不一样,所以我们需要用MMA的get_thread_slice去拿到thread_mma,然后使用partition_fragment_A去拿到MMA视角的当前Thread需要负责的Tensor是什么。最后还需要使用retile_D才可以得到我们的Copy视角下我们需要负责的一个Tensor是什么。最后还是一样的copy完成数据拷贝。
Mixed GEMM Walk Through 这部分更多的是在Concept级别的事情去怎么做,它是会组合上面提到的CuTe相关的代码以及《TensorRT-LLM中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》里面讲到的Fast Convert的相关代码,主要是展示怎么做。
这张Slides展示了Hopper上的WGMMA的PTX,定义了Hopper上的WGMMA在做什么。《TensorRT-LLM中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》里面介绍到的CUTLASS 2.x的Ampere的这些Tensor Core是同步的。同步意味着输入输出A,B,C都是在寄存器层面发射一条同步指令。Hopper上这个指令变成异步之后它可以接收来自Shared Memory的矩阵A,B了。然后,在Hopper架构中我们仍然没有一种(除了FP8这种指令)FP8和FP16直接计算的指令,所以我们要做的事情和《TensorRT-LLM中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》中介绍的差不多。我们的数据是Mixed的,但是我们读上来的数据要做Conversion。然后我们可以把weight权重放到矩阵A那边,然后读出来之后我们做一些Conversion,这个数据留存在寄存器上,然后我们把原来的矩阵A放在矩阵B的位置,这样直接去读就可以了。
然后这张Slides的内容也比较多,看图片比较模糊,这里简要说明下。这张Slides是关于如何实现 Mixed 数据类型的通用矩阵乘加(GEMM)操作,特别是在使用Hopper架构下的异步Warp Group矩阵乘加累积(MMA)操作时。内容涵盖了异步Warp Group级的矩阵乘法和累积操作的执行方法,支持的数据类型和矩阵形状,以及相关的编程指令。具体内容包括:
执行fence操作,确保寄存器/共享内存中的操作对warp组可见。 执行wmma.fence
操作,确保所有在warp组内对寄存器或共享内存的写入都已经完成,这是确保数据一致性的关键步骤。 执行fence.proxy.async
操作,这是一个代理操作,用于在异步代理中使通用代理操作可见。 通过wmma.mma_async
指令进行异步矩阵乘加操作。这个指令允许在不阻塞其他GPU操作的情况下,进行矩阵乘法和累加运算。 使用wmma.commit_group
指令来创建一个wmma操作组,并提交所有挂起的wmma.mma_async
操作到这个组中。 确保在继续之前,所需的wmma组已经完成所有操作。 一旦wmma组完成,所有的wmma.mma_async
操作也就全部执行完毕。 操作类型 :介绍了两种基本操作,一种是矩阵D作为输入和累积器被禁用的情况,另一种是常规的矩阵乘法和累积。异步Multiply-and-Accumulate指令 wmma.mma_async
指令:具体介绍了如何使用此指令执行矩阵乘加操作。语法:提供了不同数据类型(如半精度浮点型)的语法示例。 数据类型:列出了支持的多种数据类型,如半精度浮点、整数等。 矩阵形状:详细列出了操作支持的矩阵形状,如16x16x16、32x8x16等,这有助于开发者选择适合其特定应用需求的矩阵形状。 这张Slides讲解了如何实现混合数据类型的GEMM(通用矩阵乘法)。主要内容包括:
A和B矩阵的数据类型不同,例如A(激活)用FP16/FP8,B(权重)用INT8/INT4/FP8 将高精度数据加载到smem (不需要做任何Conversion) 将低精度数据加载到smem (这个是必然要求,我们要做MultiStage,数据一定要从Gloabl中Stream到各个Stage的Shared Memory中) 补充:
WGMMA: Hopper Warp Group MMA (矩阵乘累加操作)
解释:这是Hopper架构上的Warp组矩阵乘累加操作 WS: Warp Specialized
解释:表示warp专业化,即不同的warp执行不同的专门任务 SS: Src operator of GMMA are both from SMEM
解释:GMMA操作的两个源操作数都来自共享内存(SMEM) RS: Src operator A of GMMA is from RF, Src operator B is from SMEM 下面这张Slides的大部分内容是在《CUTLASS 2.x & CUTLASS 3.x Intro 学习笔记》里面有讲过的,学习之前贴到下面再复习一下再贴这节课的Slides方便大家比较;
这张Slides介绍了一下Hopper架构上的Warp Specialized GEMM实现,采用了生产者-消费者模型。内容如下:
源代码位置:cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_mixed_input.hpp 总体架构分为Producer Warps (TMA Warps) 和 Consumer Warps (TC Warps),通过共享内存进行数据交换。 Producer Warps (TMA Warps): 使用CollectiveMma::load(...) & Persistent方法 发出TMA指令加载A和B矩阵,更新smem_full barrier 更新传输字节数并到达smem_full barrier Consumer Warps (TC Warps): 使用CollectiveMma::mma(...) & Persistent方法 使用SWIZZLE将寄存器文件(RF)写入共享内存(SMEM) 包含Mbarrier和Data Buffer两部分 每个stage有两个buffer:Mat A MtilexKtile 和 Mat B NtilexKtile 使用smem_empty和smem_full标志来同步Producer和Consumer Producer和Consumer交替工作,通过共享内存和 barrier机制同步 多个stage (0 到 N-1) 用于流水线操作 对比上面的Hopper架构上的Warp Specialized GEMM实现,这里的Consumer Warps少了一个Persistent方法,只关注CollectiveMma::mma(...)。中间的Shared Memory之前是两个Data Type一模一样的,但在这里我们交换了A,B矩阵,并且它们的Data Type是不一样的。所以,这里故意画的矩阵A的这个Buffer的长度要短一点,表示低精度。
Slides的流程图展示了数据处理的pipeline,包括: 生产者线程束(Producer Warps / TMA Warps) 消费者线程束(Consumer Warps / TC Warps) 发出TMA指令加载A和B矩阵,并更新smem_full barrier 更新传输字节并到达smem_full barrier 还需要注意下这张Slides里面去掉了K方向的循环,但实际实现中仍然是有的。
这两张Slides分别对生产者线程束 (Producer Warps) 和消费者线程束(TC Warps)对应流程的一部分底层代码进行了解释,这些代码比较抽象和复杂,视频里面也没怎么讲,感兴趣的可以去深入CutLass的源代码研究。
最后这张Slides是介绍消费者线程束(TC Warps)中将低精度数据转换为高精度并保存在寄存器文件(RF)中的实现细节,以及具体是怎么做的Copy的细节。
总结 总的来说,这个演讲是一个针对CUTLASS 3.x版本在Hopper架构上实现混合精度矩阵乘法的技术演讲的总结。主要内容包括:
CuTe工具库的介绍,其核心概念是Layout和Tensor,可以灵活处理复杂的索引问题。 以矩阵转置和GEMM数据传输为例,展示了CuTe在数据操作和并行计算中的强大功能。 详细讲解了Hopper架构上异步Warp Group级矩阵乘加操作的执行方法和支持的数据类型。 介绍了如何利用异步WGMMA指令以及CuTe实现Mixed GEMM的一些细节。 从这个笔记可以简要了解一些概念,起一个科普和熟悉CuTe相关API的作用,如果要进一步学习则需要继续深入学习CuTe和CutLass,当然这里的内容对学习CutLass也是有帮助的。