我的课程笔记,欢迎关注:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/cuda-mode
CUDA-MODE课程笔记 第7课: Quantization Cuda vs Triton
适配课件详细解读
作者课件可以在这里找到:https://github.com/cuda-mode/lectures 。我也下载里一份放在 https://github.com/BBuf/how-to-optim-algorithm-in-cuda/tree/master/cuda-mode/ppt 这里。
PyTorch最近一年发布了一些生成式AI模型的案例研究,这些模型运行速度极快,且代码及其精简。这些模型比如GPT-FAST,SAM-FAST都应用了量化技术,Charles很大程度上是这些量化Kernel的主要开发者。因此,这节课由Charles来分享量化技术。
这张Slides介绍了演讲者的背景和最近的研究重点,内容如下:
AO (Architecture Optimization) 团队
Segment-anything-fast, gpt-fast, sdxl-fast 等项目
TorchAO - 提供了一个 GitHub 链接 (https://github.com/pytorch-labs/ao)
i8i8->i32 vs i8i8bf16->bf16
这张Slides介绍了三种不同的量化技术:
动态量化流程 (Dynamic Quantization Flow):
仅权重量化 (Weight Only Quantization):
总的来说,这张Slides展示了这三种技术在处理神经网络计算时的不同流程。动态量化通过在计算过程中使用整数运算来提高效率,而仅权重量化则只对权重进行压缩,在实际计算时仍使用浮点数。未量化的方法则完全使用浮点数,可能提供最高的精度但计算效率较低。
这张Slides进一步说明了动态量化(Dynamic Quantization)的概念和流程:
量化后的公式:Y = (Sx*Xint).(Wint * Sw)
重排后的公式:Y = Sx * (Xint.Wint) * Sw
这里,Sx 和 Sw 是缩放因子,Xint 和 Wint 是量化后的整数值。
开始于浮点权重(Float Weight)和浮点激活值(Float Activation)
权重在预处理阶段进行量化(Quantize (preprocess))
使用 Int8 进行乘法运算(Multiplication (Int8))
使用 Int32 进行累加运算(Accumulation (Int32))
最后将结果重新缩放(Rescale (Float))回浮点数
输出浮点激活值(Float Activation)
这张Slides展示了逐张量量化(per-tensor quantization)和逐token量化 + 逐通道量化(per-token + per-channel quantization)两种动态量化方式。性能比较(以SAM模型为例,vit_h, bsz=16):
无量化:运行时间 785.313 ms,峰值内存 15.279(单位未指明,可能是GB)
动态量化:运行时间 731.649 ms,峰值内存 18.631
另外,这里的链接是Triton的矩阵乘法教程。
结论:动态量化可以提高计算效率,在这个例子中,运行时间减少了约7%。不同的量化策略(逐张量、逐token、逐通道)可以应用于不同的张量,以优化性能和精度。虽然动态量化提高了计算速度,但它的显存占用却更多了大概是15%-20%。
这张Slides指出,显存增加的原因是要把int8的结果累加到int32类型中,因此相比于BFloat1增加了额外的显存。
这张Slides进一步详细介绍了动态量化(Dynamic Quantization)的概念、方法和性能比较:
量化公式:Y = (Sx*Xint).(Wint * Sw)
重排公式:Y = Sx * (Xint.Wint) * Sw
其中使用了不同的数据类型:- int8:用于Xint和Wint
- bf16:用于Sx和Sw
- int32:用于中间计算结果XWint
性能比较(以SAM模型为例,vit_h, bsz=16):
无量化:运行时间 785.313 ms,峰值内存 15.279 GB
动态量化:运行时间 731.649 ms,峰值内存 18.631 GB
动态量化with fusion:运行时间 695.115 ms,峰值内存 14.941 GB
结论:动态量化可以显著提高计算效率,运行时间减少约7%。动态量化with fusion进一步优化了性能,运行时间比无量化减少约11.5%,同时还略微降低了内存使用。
这里展示的是要在Torch Compile中实现动态量化with fusion需要做出的努力,因为Torch Compile并不愿意融合乘法操作,所以作者不得不在Torch Compile的矩阵乘法kernel后强制添加一个乘法的epilogue(实际上这是一个编译器的PASS,需要匹配到矩阵乘法+乘法才能生效)。图片比较难看代码,这里贴一下:
# This op is a special case of the int_mm op which we use based on the pattern # _int_mm -> mul (defined in ../fx_passes/post_grad.py) in order to prevent # realization of the int32 _int_mm output by forcing fusion with the mul op. # This is only used when config.force_fuse_int_mm_with_mul = True def tuned_fused_int_mm_mul (mat1, mat2, mat3, out_dtype, *, layout=None) : out_dtype = ( torch.promote_types(mat3.get_dtype(), torch.int32) if out_dtype is None else out_dtype ) m, n, k, layout, mat1, mat2, mat3 = mm_args( mat1, mat2, mat3, layout=layout, out_dtype=out_dtype ) choices: List[Dict[Any, Any]] = [] for config in int8_mm_configs(m, n, k): mm_template.maybe_append_choice( choices, input_nodes=(mat1, mat2, mat3), layout=layout, **dict(mm_options(config, m, n, k, layout), ACC_TYPE="tl.int32" ), suffix_args=1 , epilogue_fn=V.ops.mul, ) return autotune_select_algorithm("int_mm" , choices, [mat1, mat2, mat3], layout)
然后,Triton在实现这个需求时相比于Torch Compile会很简单,一行代码即可。
这张Slides介绍了Int8权重量化(Int8 Weight Only Quantization)的概念和流程。主要内容:
反量化(Dequantize)步骤:将量化后的权重转回浮点
浮点激活(Float Activation)保持不变
乘法运算使用浮点(Multiplication (Float))
累加使用fp32(Accumulation (fp32))
最后输出浮点激活(Float Activation)
这张Slides展示了Int8权重量化(Int8 Weight Only Quantization)的性能表现,无量化: 93.08 tokens/s,int8权重量化: 40.59 tokens/s,可以看到int8权重量化反而降低了处理速度,约为无量化版本的43.6%。
在图表中,对比了Batch size 1: cublas 和 int8 weight only quantized matmul。蓝线: cublas A16W16 matmul (使用16位精度的cublas矩阵乘法)。红线: A16W8 matmul (使用16位激活和8位权重的矩阵乘法)
这张Slides讲到如果按照普通的gemm triton kernel模板,上面的Int8权重量化的性能低于预期的原因是:
执行了比基础matmul更多的工作,展示了一段代码,显示了额外的加载和类型转换操作,这些额外操作可能导致性能下降
块大小被限制为大于等于16,当前配置只执行64个块,少于A 100GPU的108个多处理器,这可能导致一些多处理器未被充分利用
然后Torch Compile通过链接里的代码解决了这个问题,贴一下:
@register_decomposition([aten.mm]) @pw_cast_for_opmath def mm (self, input2) : # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. # todo: Look into why and fix it (hopefully) if config.coordinate_descent_tuning: if guard_size_oblivious(self.shape[0
] == 1 ) or guard_size_oblivious( input2.shape[1 ] == 1 ): return (self.unsqueeze(2 ) * input2.unsqueeze(0 )).sum(dim=1 ) ... return NotImplemented
实际上这个操作就是让GEMV用Cuda Core而不是Tensor Core来完成计算,具体做法就是把GEMV操作等价为一个element-wise乘法加一个reduce操作。这个操作通过Torch Compile生成的Triton Kernel代码如下:
这张Slides展示了一个名为 triton_() 的函数(由Torch编译器生成),该函数实现了 Int8 权重量化的GEMV操作。完整流程为:
xnumel 和 rnumel 都设置为 4096
使用 program_id(0) 和 XBLOCK 计算偏移量
XBLOCK 始终为 1,每个 program_id 处理输出的单个值
加载权重的一列的一个chunk(可能是 int8 格式)
def triton_ (in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr) : xnumel = 4096 rnumel = 4096 xoffset = tl.program_id(0 ) * XBLOCK xindex = xoffset + tl.arange(0 , XBLOCK)[:, None ] xmask = xindex rbase = tl.arange(0 , RBLOCK)[None , :] x0 = xindex _tmp6 = tl.full([XBLOCK, RBLOCK], 0 , tl.float32) for roffset in range(0 , rnumel, RBLOCK): rindex = roffset + rbase rmask = rindex r1 = rindex tmp0 = tl.load(in_ptr0 + (r1), None , eviction_policy='evict_last' ).to(tl.float32) tmp2 = tl.load(in_ptr1 + (r1 + (4096 *x0)), xmask, eviction_policy='evict_first' , other=0.0 ) tmp1 = tmp0.to(tl.float32) tmp3 = tmp2.to(tl.float32) tmp4 = tmp1 * tmp3 tmp5 = tl.broadcast_to(tmp4, [XBLOCK, RBLOCK]) tmp7 = _tmp6 + tmp5 _tmp6 = tl.where(xmask, tmp7, _tmp6) tmp6 = tl.sum(_tmp6, 1 )[:, None ] tmp9 = tl.load(in_ptr2 + (x0), xmask, eviction_policy='evict_last' ).to(tl.float32) tmp11 = tl.load(in_ptr3 + (x0), xmask, eviction_policy='evict_last' ).to(tl.float32) tmp8 = tmp6.to(tl.float32) tmp10 = tmp8 * tmp9 tmp12 = tmp10 + tmp11 tl.store(out_ptr1 + (x0), tmp12, xmask)
这张Slides主要讲述了Int8权重量化(Int8 Weight Only Quantization)的优化过程和结果。
性能问题解决:通过使用torch.compile可以解决之前遇到的性能问题。
int8权重量化优化后:135.01 tokens/s
这显示优化后的int8权重量化性能显著提升,超过了无量化版本。
cublas A16W16 matmul(蓝线)性能最佳
A16W8 fixed matmul(黄线)性能介于两者之间
尽管性能提升明显,但仍未完全匹配默认bf16的性能
这主要是由于torch.compile的开销,在端到端测试中这个差距会减小
在优化过程中遇到了triton的一些限制,通过避免使用tensor cores来绕过这些限制
目前仍然缺乏对批次大小大于1(bsz>1)的高性能内核
这里bsz=1的时候是memory bound的GEMV,如果bsz>1,这个时候就是GEMM Kernel了,很可能就是compute bound了,普通的kernel优化预计很难超越cuBLAS的性能。
从Int4 Weight Only开始,Triton开始力不从心了。要点为:
目前PyTorch没有原生的int4/uint4数据类型(dtype)。
这意味着我们需要将更大尺寸的张量拆解成多个int4类型。
由于Triton在类型转换和乘法操作上的限制,我们在实际操作中会失去更多性能。
图示展示了int4数据(4位整数)如何被打包进更大的数据类型中。
"But we can see how far we can get with just triton"(但我们可以看看仅使用triton能走多远)说明了作者打算在现有Triton框架限制下探索Int4量化的潜力。右上角显示了一个int4x2的基本结构,每个元素包含两个4位整数。下方展示了四种不同的打包/解包布局,展示了如何在更大的数据结构中组织int4数据。
Slides里面的右下角的4张图有拼写错误,注意鉴别。比如最后一张图的第一列应该是ABEF才对。
这张Slides详解了Int4权重量化(Int4 Weight Only Quantization)在矩阵乘法(matmul)中的实现策略,特别是关于数据打包和解包的选择。
在进行矩阵乘法时,由于这是权重,我们希望在int4x2格式中连续的信息在解包后仍然保持连续。
由于矩阵乘法的实现通常让单个线程处理所有的K维度,所以选择了右下角的选项。这种选择可以避免因打包方式导致线程加载不必要的数据。
Slides里面的右下角的4张图有拼写错误,注意鉴别。比如最后一张图的第一列应该是ABEF才对。
这里提供了具体的代码来展示如何打包/解包uint8和int4:
int4[2 *k,n]=(uint4x2[k,n] & 0xF ) - 8 int4[2 *k+1 ,n]=(uint4x2[k,n] >> 4 ) - 8
解释说选择uint8是因为triton框架对int8的位移操作存在问题。这里的uint4x2量化Kernel代码在:https://github.com/pytorch/pytorch/blob/main/torch/_inductor/kernel/unpack_mixed_mm.py
这张Slides主要讨论了Int4权重量化(Int4 Weight Only Quantization)的性能表现和一些相关观察。
int8权重量化优化版:135.01 tokens/s
uint4x2权重量化:43.59 tokens/s
uint4x2量化的性能(Triton实现)只有无量化情况下的1/2,而不是预期的4倍快。作者提到如果现在重新实现,会参考fast int8 kernel的方法,而不是slow int8 kernel。此外,提到Jeff Johnson(PyTorch GPU后端的开发者)使用CUDA开发了一个int4 kernel并集成到了PyTorch中,速度非常快,也就是上面表格的Int4分组量化。代码:https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml
这个是kernel的签名,感兴趣的读者可以自行查看代码。
从这个Int4 Weight Only的cuda量化kernel实现可以看到Triton的局限性。
这张Slides讨论了Triton的一些局限性:
Triton在处理复杂操作和非标准数据类型时会遇到困难。
当批处理大小大于1时,int8/int4权重量化也会遇到问题。
这张Slides介绍了Trito的优势:
提到了两个具体例子:a) Fused_int_mm_mul(融合整数矩阵乘法和乘法操作)
b) SAM flash attention(Segment Anything Model中使用的快速注意力机制)
最重要的是,使用Triton可以达到这种性能水平,而无需直接处理.cu文件(CUDA源代码文件)。
https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L325
https://github.com/pytorch-labs/segment-anything-fast/blob/main/segment_anything_fast/flash_4.py#L13
这里讲的就是SAM里面的Attention操作相比于标准的SelfAttention需要融合两个MASK,这个时候使用Triton实现的FlashAttention就可以非常快的实现这个需求,并且性能很好。
要复现作者的实验或者学习GPU上量化Kernel的实现可以点击这张Slides里的链接。
作者分享的Slides里面还有一些有趣的内容作为附录,这里挑选其中的一些来解读,主要是实验结果和概念部分,对torchao的使用部分的Slides有需要的读者可以自行查看。
这张Slides展示了对SAM(Segment Anything Model)模型进行不同量化和优化技术的实验结果。主要内容如下:
动态量化(Dynamic Quant)相比基准模型获得了约13%的速度提升。
仅权重量化(Weight Only Quant)对性能提升不明显,原因是模型主要受计算限制,且其内核设计并不针对大Batch进行优化。
所有的量化技术都只导致了很小的精度损失。图表详细展示了不同方法的性能对比:
int8 weight only quant(8位整数仅权重量化)
int8 dynamic quant(8位整数动态量化,包括权重和激活)
2:4 pruned cusparselt(一种稀疏化技术)
表格中比较了这些方法在以下几个方面的表现:
相对于SDPA的加速比(speedup over SDPA)
COCO 2017验证集上的准确率(coco 2017 val accuracy)
这张Slides展示了对Llama2 7B模型进行不同量化方法的实验结果。主要内容如下:
使用仅权重int8和int4量化分别实现了45%和86%的加速。
int4仅权重量化导致了小幅度的精度下降,但使用GPTQ(一种量化技术)可以恢复其中一半的精度损失。
动态量化(Dynamic Quantization)虽然测试过,但因为模型受内存限制,其精度和性能都不如仅权重量化,所以未列入表格。
int8 weight only quant(8位整数仅权重量化)
int4g128 weight only groupwise quant(4位整数分组仅权重量化)
每秒处理的token数(bs 1 (tok/s))
wikitext bits_per_byte(困惑度相关指标)
结果显示,int4量化提供了最大的速度提升(1.86倍),但有轻微的精度损失。int8量化在保持精度的同时也提供了显著的速度提升(1.45倍)。
这张Slides展示了对Llama2 7B模型进行模拟低精度量化的实验结果。主要内容如下:
实验目的:了解分组大小(groupsize)、位数(bit number)和GPTQ(量化技术)如何影响模型准确性。实验使用wikitext bits_per_byte困惑度作为评估指标。
GPTQ效果:在大多数情况下,GPTQ能够恢复约一半的性能损失(PPL,困惑度)。特例:在G=64、2位量化的情况下,未使用GPTQ时的PPL异常地高。