专栏名称: GiantPandaCV
专注于机器学习、深度学习、计算机视觉、图像处理等多个方向技术分享。团队由一群热爱技术且热衷于分享的小伙伴组成。我们坚持原创,每天一到两篇原创技术分享。希望在传播知识、分享知识的同时能够启发你,大家一起共同进步(・ω<)☆
目录
相关文章推荐
51好读  ›  专栏  ›  GiantPandaCV

详解vLLM和SGLang awq dequantize kernel的魔法

GiantPandaCV  · 公众号  · 3D  · 2025-03-16 22:48

正文

0x0. 前言

本片文章解析一下vLLM/SGLang中 awq int4的反量化kernel,这个kernel触发条件为当输入x的shape的tokens<256时,这个时候会先把int4的awq权重使用 awq_dequantize 反量化回float16,然后调用PyTorch Matmul执行float16的乘法,代码位置见: https://github.com/vllm-project/vllm/blob/b82662d9523d9aa1386d8d1de410426781a1fa3b/vllm/model_executor/layers/quantization/awq.py#L162-L184

def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None)
 -> torch.Tensor:

    qweight = layer.qweight
    scales = layer.scales
    qzeros = layer.qzeros
    pack_factor = self.quant_config.pack_factor
    out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
    reshaped_x = x.reshape(-1, x.shape[-1])

    # num_tokens >= threshold
    FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

    if FP16_MATMUL_HEURISTIC_CONDITION:
        out = ops.awq_dequantize(qweight, scales, qzeros, 000)
        out = torch.matmul(reshaped_x, out)
    else:
        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                           pack_factor)
    if bias isnotNone:
        out.add_(bias)
    return out.reshape(out_shape)

本文要解析的就是这里的 vllm ops.awq_dequantize 这个kernel,这个kernel的代码单独抽出来只有几十行代码,但是代码中涉及到的魔法和数学有点多,如果不了解这里的原理就会很痛苦,所以我这里来详细解析一下。vllm ops.awq_dequantize 这个算子的原始来源是FasterTransformer仓库,然后sglang的sgl-kernel也有一份针对这个算子的干净实现,并通过调整线程块有更快的速度,我这里直接针对这份代码来解析,链接见:https://github.com/sgl-project/sglang/blob/main/sgl-kernel/csrc/gemm/awq_kernel.cu#L7-L127

还需要说明一下,对于AWQ/GPTQ来说,权重的量化不是PerChannel的而是GroupWise的,也就是在K方向会有GS组Scales和Zeros,例如假设K/GS=128,那就是在K方向有128行的Weight共享一个Scales和Zeros。因此,它和PerChannel的差异就是需要在反量化的时候乘以Scales并加上Zeros。除此之外,AWQ本身需要在Activation计算之前乘以它自己的ActScale。在下面的Kernel中,针对的是weight,K方向就是行(row)方向。

0x1. 接口函数

// PyTorch接口函数,用于AWQ权重反量化
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
// 获取输入张量的维度信息
int qweight_rows = qweight.size(0);
int qweight_cols = qweight.size(1);
int group_size = qweight_rows / scales.size(0); // 计算量化组大小

// 设置CUDA网格和块的维度
int x_num_threads = 16;
int y_num_threads = 16;
int x_blocks = qweight_cols / x_num_threads;
int y_blocks = qweight_rows / y_num_threads;

// 确保在正确的CUDA设备上执行
const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

// 创建输出张量,与scales具有相同的数据类型和设备
auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

// 获取各个张量的数据指针
auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
auto _scales = reinterpret_cast(scales.data_ptr<:half>());
auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
auto  _output = reinterpret_cast(output.data_ptr<:half>());

// 配置CUDA核函数的执行参数
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_num_threads, y_num_threads);

// 获取当前CUDA流并启动核函数
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dequantize_weights<<0, stream>>>(
      _qweight, _scales, _zeros, _output, group_size, qweight_cols);

// 返回反量化后的权重张量
return output;
}

需要注意的点是,kernel的输入是 int4 类型的,输出是 float16 类型的,然后输入的shape是 [qweight_rows, qweight_cols] ,输出的shape是 [qweight_rows, qweight_cols * 8] 。由此,我们也可以看出输入数据的元素是一个32位整数 source ,它包含了8个4位整数(每个4位可以表示0-15的值)。这8个4位整数被紧密地打包在一起,如下图所示:

[4bit][4bit][4bit][4bit][4bit][4bit][4bit][4bit]

接下来,在kernel launch配置方面,使用二维的线程网格和线程块,并且每个线程处理输入Tensor中的一个元素,非常直观:

int x_num_threads = 16;
int y_num_threads = 16;
int x_blocks = qweight_cols / x_num_threads;
int y_blocks = qweight_rows / y_num_threads;
dim3 num_blocks(x_blocks, y_blocks);
dim3 threads_per_block(x_num_threads, y_num_threads);

0x2. dequantize_weights kernel 流程

// 权重反量化的CUDA kernel,最大线程数为256
__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,    // 量化后的权重
    half* __restrict__ scales,    // 量化比例因子
    int* __restrict__ qzeros,     // 量化零点
    half* __restrict__ output,    // 输出的反量化权重
    int group_size,               // 量化组大小
    int qweight_cols) {           // 量化权重的列数
// 计算当前线程处理的列和行索引
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y * blockDim.y + threadIdx.y;

// 获取当前处理位置的零点,并反量化为fp16x2格式
  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
// 加载对应的缩放因子
  uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);

// 将量化权重反量化为fp16x2格式
  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);

// 对每个fp16x2元素执行(weight - zero) * scale操作
// 处理第一对fp16值
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n"  : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
// 处理第二对fp16值
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
// 处理第三对fp16值
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
// 处理第四对fp16值
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

// 计算输出指针位置并存储结果
  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
  *(uint4*)output_ptr = weight_fp16;
}

这里整体是非常好理解的,我们根据线程id定位到当前线程处理的列和行索引之后分别加载零点zeros,缩放系数loaded_scale和权重weight_fp16并对zeros/weight_fp16应用 dequantize_s4_to_fp16x2 反量化kernel把当前行列所在的int32类型的值(8个int4)反量化为8个half类型的输出值,注意这里是用4个half2来存储的。然后使用 (weight - zero) * scale 操作来完成反量化的过程。

这里解析一个 asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x)); 指令:

这行代码使用了CUDA PTX,用于执行半精度浮点数(fp16)的减法操作。它的基本语法为:

asm [volatile] ("汇编指令" : 输出操作数 : 输入操作数 : 可能被修改的寄存器);

下面是详细解析:

  • asm volatile
    • asm 关键字表示这是内联汇编代码
    • volatile 修饰符告诉编译器不要优化或重排这段汇编代码,确保它按照指定的顺序执行
  • sub.f16x2 %0, %1, %2;\n
    • 这是实际的CUDA PTX汇编指令
    • sub.f16x2 是CUDA的指令,表示对两个并排的fp16值(packed half2)执行减法操作
    • %0, %1, %2 是占位符,分别对应后面定义的输出和输入操作数
    • \n 是换行符,用于格式化汇编代码
  • : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
    • 第一个冒号后的 "=r"(weight_fp16.x) 是输出操作数,=r 表示这是一个输出到通用寄存器的值
    • 第二个冒号后的 "r"(weight_fp16.x) "r"(zeros.x)) 是两个输入操作数,r 表示它们来自通用寄存器

通过这个指令就实现了反量化中的减零点的功能,kernel中其它的ptx指令类推。

0x3. dequantize_s4_to_fp16x2 kernel(魔法发生的地方)

这段代码对应的原理在nvidia 2023年夏日专场其实简单讲了一下,我这里结合当时的PPT复述一下这里的原理,通过这个复述读者稍后就可以知道代码中的那一堆魔术和用于计算的PTX指令是做了什么了。注意下面引用的图来BiliBili NVIDIA英伟达频道 上传的《TensorRT-LLM中的 Quantization GEMM(Ampere Mixed GEMM)的 CUTLASS 2.x 实现讲解》。

FasterTransformer 高效的Int8/Int4 快速Convert为FP16

这张slides展示了FP16的IEEE 754标准,一个16bit的数里面包含1个符号位,5个基码位,10个尾数。

假设我们有一个uint8的数143,如果我们把它放到实际的FP16的尾数位里面去,那么我们是否有办法通过合理的设置基码位把143表达出来呢?那我们按照已知的FP16的数值计算方法,拿基码位的二进制前面加上一个1.x,然后去乘以2的(基码位的值-15)次方,我们已知143对应的实际上对应的是下面的值。假设我们想用这个FP16的值来表达Int8,我们可以发现如果x=25的话,我们把上面的FP16的值减去1024就是下面的143了。因此,我们只需要把int8的值放到尾数位,然后把它的基码位设置成25,然后再把FP16的数值结果减去1024就可以得到UINT8转换到FP16的值。

总结一下就是直接把UINT8的数值放在FP16的尾数位,

然后再把FP16的基码位设置成25,这个25对应的十六进制表示就是0x64,

随后再把最终的这个值减去FP16形式的1024,就完成了从UINT8到FP16的转换。

如果是Int8的话,应该怎么做呢?可以注意到UINT8和INT8只是数值范围的区别,那么我们需要把INT8的数据加上128,就能把它转换成UINT8的形式。这样转换出来的FP16的结果,只需要在减去1024的时候多减去128,就恢复到了对应的原始INT8的数值。

那么我们怎么实际的去用指令完成上面描述的这个操作呢?可以注意到有一种叫作prmt的PTX指令,这个指令做的事情就是从2个32bit的寄存器A,B中抽出4个8bit组成最终的d。而这4个8bit怎么抽取,就是每个8bit对应到c寄存器里面的低4bit,就是说c寄存器的低4bit每个bit都是一个索引,假设A,B两个32位寄存器里面存放的是上方左图这样的数据形式,即ABCDEFGH。那么在c寄存器中,索引的4个数字分别是1,3,5,7,那么最终这个D寄存器里面的4个8bit数据就是GECA。通过这种指令就可以实现从32bit寄存器里面抽取对应想要的一个字节出来的效果。

对应到TRT-LLM的转换代码就是这样的形式,我们可以注意到它用permute指令从输入的UINT8数据和magic number组成的这两个32位寄存器中去抽取4个8bit,抽取的索引放在这个mask_for_elt_01/23中。这里的两个掩码值 mask_for_elt_01 = 0x5250 mask_for_elt_23 = 0x5351 是用于CUDA的PRMT(Permute)指令的控制参数,它们决定了如何重排字节。

--------------------分割线---------------------

这里我感觉比较难理解,所以下面详细拆解一下:

PRMT指令基础

首先,PRMT指令的格式是:

prmt.b32 d, a, b, c;

其中, d 是目标寄存器; a b 是源寄存器; c 是控制码(即我们讨论的掩码)。然后PRMT指令将 a







请到「今天看啥」查看全文