DeepSeek-V3 横空出世,训练和推理成本极低,一个重要的原因就是采用了 FP8 进行训练和推理,今天结合最近的实践来分析一下其中的原理:
Group/Block wise 量化
分块量化(Block-wise Quantization),也称为分组量化(Per-group Quantization),是一种细粒度量化方法。
特征异常值是指在特征分布中远离大部分数据的极端值。这些异常值对量化尤其具有挑战性,因为如果使用全局的量化参数(例如最大值),则这些异常值可能会导致大部分数据的量化精度下降。
细粒度量化的核心思想是使用更精细的量化粒度,即对输入和权重的不同部分使用不同的缩放因子。这样可以更好地适应数据的局部特征,减少异常值的影响。
分块量化将张量分割成更小的块或组,并为每个块分配独立的量化参数(缩放因子
s
和零点
z
)。
如上图所示,矩阵被分割成多个小块,每个小块使用不同的颜色进行标注,对应不同的量化参数。
-
优点
:提供了对量化过程更精细的控制,通常会在模型精度和计算效率方面带来更好的性能。通过调整块的大小,可以在精度和效率之间进行灵活的权衡。相比逐张量量化,分块量化能够更好地适应张量内部数据分布的变化,减少量化误差;相比逐通道量化,分块量化可以减少需要存储的量化参数数量,从而降低存储开销。
-
缺点
:需要合理划分组别,增加了量化策略的设计复杂性,而且分块量化一般对硬件不友好,计算效率低。
总之 Block-wise 量化是对矩阵分组,每一组有独立的量化参数,可以更好的控制精度损失。
DeepSeek-V3 量化配置
首先看 DeepSeek-V3 FP8 版本的模型配置:
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [
128,
128
]
}
量化精度:FP8
量化粒度:
-
权重:block-wise 量化, 每个 block 的 shape 是[128,128],
静态离线
量化
-
激活:per-token-group 量化,
动态在线
量化
(1) 对于激活值,我们以 1x128 的 组 为基础对元素进行分组和缩放(每个 token 每 128 个通道);
(2) 对于权重,我们以 128x128 的 块 为基础对元素进行分组和缩放(每 128 个输入通道每 128 个输出通道)。
结合上图我们来看下如何对权重和激活值进行量化。
权重量化(block-wise)
假设
权重 B
的shape为:
[hidden_dim, out_dim]
1.分块方式:
-
在
hidden_dim
维度上每 128 个输入特征一组
-
在
out_dim
维度上每 128 个输出特征一组
2.量化缩放因子(scales):
-
Bs
的shape:
[hidden_dim//128, out_dim//128]
-
-
激活量化(per-token-group)
假设
激活A
输入的shape为:
[batch_size x seq_len, hidden_dim]
1.分块方式:
-
对于每一个 token,在
hidden_dim
维度上每 128 个通道的激活值分为一组,并为这一组计算一个单独的缩放因子。
2.量化缩放因子:
-
As
的 shape:
[batch_size x seq_len, hidden_dim//128]
-
-
FP8-GEMM 工程实现
下面主要针对 FP8 GEMM 的工程实现讨论。
理解了上面的权重和激活量化原理,那么下面来看如何进行两个FP8量化矩阵的乘法
运算。
经过量化,我们得到了下面参数:
// inputs
// A [M, K] fp8 (按行分组量化,每组对应一个 As 元素)
// B [N, K] fp8 (按块量化,块大小为 [block_k, block_n],每个块对应一个 Bs 元素)
// As [M, K/block_k] fp32 (A 的每行(或每组)的量化比例因子)
// Bs [K/block_k, N/block_n] fp32 (B 的每个块的量化比例因子)
// outputs
// mat [M, N] fp32
下面来看一下 DeepSeek-V3 报告里对 FP8-GEMM 的 CUDA 层面计算流程的解释:
GPU计算流程
背景:
-
下溢和精度损失:
使用 FP8 等低精度格式进行 GEMM 运算时,中间结果的累加容易出现下溢,导致精度损失。传统的做法是使用 FP32 进行累加,以保持精度。
下溢指的是
计算结果的绝对值非常小,小于浮点数所能表示的最小正数(非零)
。换句话说,计算结果太接近于零,以至于计算机无法用当前的浮点数格式精确地表示它,通常会被近似为零。
DeepSeek-V3 的方案:
所有FP8张量都采用E4M3格式(4位指数和3位尾数),以获得更高的精度.
FP8表示
计算过程:
以
𝑁𝐶 = 128
个元素 MMA 的间隔转移到 CUDA Cores 进行高精度累加。
计算流程
每当 Tensor Core 累加了 128 个 FP8 结果后,就会将这些结果转换(或缩放)到 FP32 精度,然后在 CUDA Cores 的 FP32 寄存器中进行累加。
计算流程:
-
Tensor Core 以 FP8 精度高效地执行大量的矩阵乘法和累加(MMA)操作。使用低精度累加器存储中间结果
-
每累加 128 个元素(Nc = 128),就将这些 FP8 累加结果转换到 FP32 精度。
-
在 CUDA Cores 的 FP32 寄存器中进行高精度的累加,最终结果经过Scaling Factor缩放,也就是反量化。
-
重复步骤 1-3,直到完成所有的矩阵乘法和累加操作。
Python native实现
核心代码:
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i] # [M, 128]
b = B_tiles[j][i]. #[128, 128]
c = C_tiles[j] # [M, 128]
s = As_tiles[i] * Bs[j][i] #[M, 1]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
可以结合上面对矩阵乘法的注释来理解分块矩阵乘法的过程:
进行矩阵乘法的时候,先对矩阵 A 和 B 依照各自的量化粒度分块,在分块的粒度上进行矩阵乘法运算,然后再乘以量化因子进行反量化,得到分块的FP32浮点结果。
Trition 实现
代码参考 sglang 中的实现:
1.函数接口:
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
2.Triton 算子配置
# 尝试加载之前通过 tuning 方式获得的最佳配置信息。
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
可以通过对 Triton 算子进行 tuning 来得到最优的 kernel 配置,接着调用 Triton 算子。
3.Triton算子实现
我觉得Triton 的代码介于 PyTorch 和 CUDA 代码之间,它提供了一种比手写 CUDA 算子更高层次的抽象,方便开发。
核心计算流程如下,注意累加器 accumulator 是 float32 精度的。
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
):
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
Cutlass 实现
先了解一下几种量化缩放的术语(和量化粒度有关):
-
张量级缩放(Tensor-wise Scaling):
每个张量使用单个缩放因子,在尾声(epilogue)中应用。
-
行级缩放(Row-wise Scaling):
使用一个行向量进行缩放,对于操作数 A 的维度为 Mx1,对于操作数 B 的维度为 1xN,避免沿归约维度进行缩放。这也可以在尾声中使用 EpilogueVisitorTree 来处理。
-
分块缩放(Block-wise Scaling):
引入一个 2D 缩放张量,每个 CTA 块分配一个缩放值。由于此缩放涉及归约维度 (M, N, K),因此必须在主循环中应用,这会影响性能。
-
分组缩放(Group-wise Scaling):
使用一个 2D 缩放张量,每个 CTA 块有多个缩放值。缩放粒度独立于 CTA 块配置,为将来的实现提供了更大的灵活性。
关于 FP8-block-wise 量化有先后两个 PR,第一个 PR 先支持了
Blockwise Scaling
,第二个 PR 在第一个的基础上支持了
Groupwise Scaling,
下面依次介绍。