专栏名称: ChatAI42技术与产品
智能聊天机器人(Chatbots)是交互的新趋势,Google、Facebook、Microsoft、百度、阿里等众多公司已加入此阵列,就等你了!我们会定期发布聊天机器人的各种信息,其中使用的机器学习/深度学习技术、产品、分享活动等等
目录
相关文章推荐
51好读  ›  专栏  ›  ChatAI42技术与产品

估计KL散度的艺术:平衡偏差与方差的实用指南

ChatAI42技术与产品  · 公众号  · 机器人  · 2025-02-12 12:12

主要观点总结

本文主要探讨了KL散度的三种蒙特卡洛估计方法,包括原始估计量k₁、低方差估计量k₂和突破性改进的无偏低方差估计量k₃。文章详细阐述了这三种估计量的特点,通过理论分析和实验验证,展示了它们在偏差和方差之间的权衡。文章还讨论了不同场景下的性能对比和推荐使用的估计量。

关键观点总结

关键观点1: 三种KL散度的蒙特卡洛估计方法

包括原始估计量k₁、平方对数估计量k₂和控制变量法的妙用估计量k₃。每种估计方法都有其特点和适用场景。

关键观点2: 估计量的偏差与方差权衡

k₁严格无偏但方差极高,k₂在小差异场景中偏差可忽略但方差较低,k₃实现无偏且低方差。

关键观点3: 实验验证与推荐估计量

通过实验对比不同估计量的性能,根据场景特征推荐合适的估计量。


正文

请到「今天看啥」查看全文


Home [1] | GitHub [2] | Twitter [3] | Youtube [4] | Bilibili [5]

问题背景

最近看 DeepSeek 论文和 GRPO 时,发现他们用了一种很有意思的 KL 散度近似预估形式,就深入了解了下其来源。本文对其来源做个简单的说明。

在概率建模和强化学习中, KL散度 Kullback-Leibler Divergence )是衡量两个概率分布差异的常用指标。其定义为:

当解析解难以计算时(如高维空间或复杂分布),我们常借助蒙特卡洛方法对其值进行估计。本文将探讨 的三种不同估计量,揭示它们在偏差与方差间的精妙权衡。

📌 一个好的估计量应该是 无偏(具有正确的均值)且方差低的

不同估计量及其局限

1. 原始估计量(k₁)

直接从定义出发,使用单样本对数比值的期望:

k₁ 特点

  • 无偏性
  • 高方差 :对数比值在 区域会产生极端负值,导致估计震荡剧烈

2. 低方差估计量的秘密:f-散度的启示

平方对数估计量(k₂)

引入新的统计量:

具有低偏差 :其期望是一个 -散度。 -散度定义为 ,其中 是一个凸函数且 。KL 散度以及其他各种著名的概率距离都是 -散度。

现在这里有一个关键的非显而易见的事实:当 接近 时,所有具有可微 -散度在二阶近似下都类似于 KL 散度。具体来说,对于一个参数化分布

其中 的 Fisher 信息矩阵,在 处计算。

下面做个推导。

  • f-散度的泰勒展开 接近 时,比值 接近 1。将 处进行二阶泰勒展开:
  • 代入f-散度定义 将展开式代入 ,得到:
  • 一阶项 。由于 ,一阶项为
  • 二阶项 。要证明当 接近 时,

其中 是 Fisher 信息矩阵,可以通过以下步骤完成:

(1). 参数化分布与对数展开

假设 是参数化分布,且当 。将 处进行泰勒展开:

其中:

  • 得分函数 (记作 ),满足:
  • Fisher 信息矩阵

(2). 近似分布比值

通过指数化对数展开式,得到 的近似:

取倒数并展开到二阶:

(3). 计算

将上述近似代入平方差:

保留至二阶项(忽略高阶小量):

(4). 计算期望值

对近似后的平方差取期望 。由于 ,可用 的期望近似:

逐项分析:

(a) 第一项

(b) 第二项 :由于 ,交叉项

因此,主要贡献来自第一项:

k₂ 特点

  • 有偏性 :有偏估计,但偏差较小
  • 低方差优势 :始终非负,排除极端样本干扰

3. 突破性改进:无偏低方差估计量

控制变量法的妙用(k₃)

是否有可能写出一个无偏且方差低的 KL 散度估计量呢?降低方差的一般方法是使用控制变量。即, 基础上加上一个期望为零但与 负相关的量 。保证期望为零的量是 。因此,对于任何 ,表达式 的无偏估计量。我们可以进行计算以最小化这个估计量的方差并求解 。但不幸的是,我们得到一个依赖于 的表达式,很难进行解析计算。然而,我们可以使用一个更简单的策略选择一个好的

注意,由于对数函数是凹函数, 。因此,如果我们 ,上述表达式保证为正 。它测量 与其切线之间的垂直距离。这就给我们留下了估计量

通过引入期望为零的调节项 ,构造了新估计量:

k₃ 特点

  1. 1. 无偏性
  2. 2. 方差压缩 :利用 的凸性,降低方差

实验验证:不同场景下的性能对比

实验代码:

import torch.distributions as dis

# 设置分布
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)

# 生成样本
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(p, q).item()

# 计算各估计量
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr**2 / 2
k3 = (logr.exp() - 1) - logr

# 输出结果
print(f"真实KL值: {truekl:.4f}")
for name, k inzip(["k1""k2""k3"], [k1, k2, k3]):
    bias_ratio = (k.mean().item() - truekl)/truekl
    std_ratio = k.std().item()/truekl
    print(f"{name} : 偏差={bias_ratio:.3%} 标准差={std_ratio:.2f}")

案例一:微小差异(KL=0.005)

假设 。则真实的 KL 值为 0.005 。

估计量
相对偏差
相对标准差
k₁
0
20
k₂
0.002
1.42
k₃
0
1.42

案例二:显著差异(KL=0.5)

假设 。则真实的 KL 值为 0.5 。

估计量
相对偏差
相对标准差
k₁
0
2
k₂
0.25
1.73
k₃
0
1.7

关键发现

  • • k₂ 在差异较小时偏差可忽略,但显著差异时偏差增大。
  • • k₃ 始终保持无偏,且方差与 k₂ 相当甚至更低。

总结与展望

通过理论分析与实验验证,我们展示了KL散度估计中的权衡艺术。 k₃ 估计量的提出,为需要精确评估概率分布差异的场景提供了可靠的工具。在实际应用中,建议根据具体需求选择合适的估计量,在计算效率与估计精度间取得最佳平衡。

核心洞见

本文系统探讨了三种KL散度的蒙特卡洛估计方法,揭示了 KL 估计量设计中的核心矛盾—— 偏差与方差的权衡

  • k₁(原始估计量) :严格无偏但方差极高,适用于理论验证,实际应用受限
  • k₂(平方对数估计量) :通过f散度框架实现低方差,在小差异场景中偏差可忽略,是快速诊断的理想选择
  • k₃ :融合控制变量法与凸优化思想,实现 无偏+低方差 的突破
场景特征
推荐估计量
关键优势
精确度量分布差异
k₃
无偏且稳定
实时训练监控
k₂
计算高效,偏差可接受
理论验证/敏感检测
k₁
严格无偏,基准参考

References

  • • 英文原文: Approximating KL Divergence from John Schulman [6]


引用链接

[1] Home : https://www.breezedeus.com
[2] GitHub : https://github.com/breezedeus
[3] Twitter : https://twitter.com/breezedeus
[4] Youtube : https://www.youtube.com/@breezedeus
[5] Bilibili : https://space.bilibili.com/509307267
[6] Approximating KL Divergence from John Schulman: http://joschu.net/blog/kl-approx.html


更多推荐阅读








请到「今天看啥」查看全文