专栏名称: Python开发者
人生苦短,我用 Python。伯乐在线旗下账号「Python开发者」分享 Python 相关的技术文章、工具资源、精选课程、热点资讯等。
目录
相关文章推荐
Python爱好者社区  ·  KAN教程PDF(附代码) ·  2 天前  
Python爱好者社区  ·  曝京东发出整体退租邮件,将搬离华南最大办公场 ... ·  6 天前  
Python爱好者社区  ·  强烈建议程序员们搞个香港身份,再不冲就晚了! ·  6 天前  
Python爱好者社区  ·  《黑神话 . 悟空》研发公司的薪资水平 ·  1 周前  
51好读  ›  专栏  ›  Python开发者

机器学习中的样本重要性权重 (Importance Weight)

Python开发者  · 公众号  · Python  · 2024-09-02 08:30

正文

样本重要性权重(Importance Weighting, IW)是一种在机器学习中应对「训练-测试数据分布不一致」问题的经典方法,通过对样本给予合适的权重,理论上我们可以在分布不一致的情况下,学出在目标分布上的无偏估计。

unsetunset简单理论推导unsetunset

假设训练集样本 来自于分布 ,我们称该分布为原始分布(Source Distribution),在该分布上我们要学习某个函数,从而在某个目标分布(Target Distribution)上进行预测。在目标分布 上对的估计为:

在原始分布上对的估计则为:

要想在原始分布上能够拟合出在目标分布 上的无偏估计,我们可以通过以下的变换来得到:

可见,对函数乘上一个权重 ,然后再在原始分布上进行估计,就可以得到在目标分布上的无偏估计,这里的就是所谓的样本重要性权重(Importance Weight,IW)

也就是说,我们只要把我们的学习目标从变成,这样学习出来的就相当于在目标分布Q上拟合出来的一样。

unsetunset加权的损失函数unsetunset

在训练机器学习模型时,我们常使用对数最大似然函数 来进行估计,所以我们可以设 ,可知在原始分布上 的期望就等于在目标分布上的期望,注意到这个对数概率取负号就正好就是样本的交叉熵损失,所以等同于在计算损失函数的时候,对对应的样本加上权重

unsetunset数据漂移(Covariant Shift)和标签漂移(Label Shift)unsetunset

原始分布和目标分布不同,可能有两个方面:

  • 数据漂移,又称协变量漂移(covariant shift),即 X 存在不同分布。比如训练猫狗图像分类的时候,都是真实照片,但是测试的时候却是卡通图片。
  • 标签漂移,即 Y 存在不同的分布。最常见的就是类别分布不平衡,训练时某些类别明显偏多,而其他类别因为资源问题或者采样偏差导致数量匮乏。

举例:Y的漂移——样本不平衡的情况

类别不平衡是一种常见的问题。假设两个类别c1和c2在训练集中的比例为3:1,而测试情况下我们同样关注c1和c2,所以可以假设测试环境中二者比例为1:1。这是一种label分布不一致的情况,在这种情况下,我们可以用label的分布,来计算IW。

  • 当样本标签为c1时,IW=1/3
  • 当样本标签为c2时,IW=1/1

因此,在训练时,我们可以给两个类别的样本,来计算loss的时候分别配以以上权重。

举例:回归问题中X的漂移

在论文《Improving predictive inference under covariate shift by weighting the log-likelihood function》中,有一个形象的例子:

左图所示,一批样本是由 产生,这是来自上帝视角的一个函数,是不可知的,用于产生样本点。

图片的小圈圈就是可观察到的一批样本,是由一个原始分布产生,可见 X 是分布不均匀的,在 处分布最集中。可以理解为,这就是我们采集到的训练数据的分布。在这批样本上直接进行OLS回归,拟合出来的函数是左图中黑色的实线,很明显跟真实曲线误差很大。

假设在测试场景中,样本的X来自于另外一个分布 ,就如右图所示的那些小黑点,很明显,用训练集拟合出来的黑色实线在这些样本上测试的话,效果就很差,二者大概的方向都是反的,一个跟X正相关一个则是负相关。

假设我们提前预知了测试的分布,然后使用 Weighted OLS (WOLS)进行回归,即在进行最大似然估计的时候,使用 进行计算,其中,那么拟合结果就是左图中的黑色虚线,很明显这个虚线可以很好地表达测试样本的趋势,这就是样本权重的作用,可以在仅使用原始分布的样本的情况下,拟合出对目标分布上的几乎无偏的估计。

当然,这里的前提是我们需要“提前预知”目标分布,这几乎是不可能的,或者很难获取的。但通过一些先验,我们也许可以找到一些近似的分布,只要能够比原始分布更好,那通过这样的样本权重,就可以训练出更好的模型。


参考文献:

  • Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 2020

  • What is the effect of importance weighting in deep learning? ICML, 2019


推荐阅读  点击标题可跳转

1、详解常用机器学习算法优缺点

2、一文解决样本不均衡(全)

3、两行代码,应用 40 个机器学习模型!