专栏名称: 极市平台
极市平台是由深圳极视角推出的专业的视觉算法开发与分发平台,为视觉开发者提供多领域实景训练数据库等开发工具和规模化销售渠道。本公众号将会分享视觉相关的技术资讯,行业动态,在线分享信息,线下活动等。 网站: http://cvmart.net/
目录
相关文章推荐
深夜书屋  ·  真的太怀念,太感动了! ·  13 小时前  
哔哩哔哩  ·  纪录片首曝,《哪吒2》的幕后秘密 ·  昨天  
读书杂志  ·  中读年卡 | 中国纹样,美! ·  2 天前  
冯唐  ·  在我的爱情里,女方掌握主导权 ·  5 天前  
51好读  ›  专栏  ›  极市平台

线性回归的解析解与数值解(含代码)

极市平台  · 公众号  ·  · 2024-08-27 22:00

正文

↑ 点击 蓝字 关注极市平台
作者丨单博
来源丨笑傲算法江湖
编辑丨极市平台

极市导读

机器学习基础夯实篇-详解 线性回归的解析解与数值解。>> 加入极市CV技术交流群,走在计算机视觉的最前沿

很多人在开始学习机器学习的时候都看不上线性回归,觉得这种算法太老太笨,不够fancy,草草学一下就去看随机森林、GBDT、SVM甚至神经网络这些模型去了。但是后来才发现线性回归依然是工业界使用最广泛的模型。而且线性回归细节特别多,技术面时被问到的概率也很大,希望大家能学好线性回归这块机器学习,也可能是一个offer的敲门砖。

学习中,顺着线性回归,可以引申出多项式回归、岭回归、lasso回归,此外还串联了逻辑回归、softmax回归、感知机。通过线性回归,还能巩固和实践机器学习基础,比如损失函数、评价指标、过拟合、正则化等概念。最后,线性回归与后续要学到的神经网络、贝叶斯、SVM、PCA等算法都有一定的关系。

本文将会出现不少数学公式,需要用到线性代数和微积分的一些基本概念。要理解这些方程式,你需要知道什么是向量和矩阵,如何转置向量和矩阵,什么是点积、逆矩阵、偏导数。对于极度讨厌数学的读者,还是需要学习这一章,但是可以跳过那些数学公式,希望文字足以让你了解大多数的概念。

线性模型

线性模型的表达式很简单:

其中:

  • 是线性模型, 输出预测值
  • 的特征数量, 或者说属性个数, 有几个特性属性, 就叫几元。
  • 是第 个特征, 举个例子, 影响房价的因素有地段、房屋面积、房龄等等, 这三个就是特征/属性, 影响糖尿病病情的有年龄、体重BMI、血压等等, 这三个也是特征/属性。
  • 是第 个模型参数, 或者说权重系数, 就是上述属性的参数。
  • 是偏差项, 或者叫偏置项, 也是参数。

上面的式子,用向量形式写就是:

其中 , 注意这里不是逗号是分号, 是表示换行; 学得之后, 模型就可以确定。

线性模型形式简单、易于建模, 但却蕴涵着机器学习中一些重要的基本思想。许多功能更为强大的非线性模型 (nonlinear model) 可在线性模型的基础上通过引入层级结构或高维映射而得。此外, 由于 直观表达了各个特征在预测中的重要性, 因此线性模型有很好的可解释性(comprehensibility)。

为什么需要 (Bias Parameter)?类似于线性函数中的截距,在线性模型中补偿了目标值的平均值(在训练集上的)与基函数值加权平均值之间的差距。即打靶打歪了,但是允许通过平移固定向量的方式移动到目标点上(每个预测点和目标点之间的偏置都必须是固定的)。

这里还需要澄清一下,什么是“线性”?很多初学者都会把“线性”简单的理解成预测的模型是一条线,或者在分类任务中用一条线把数据集分开。初学这么理解没有错。其实,线性是描述自变量之间只存在线性关系,即自变量只能通过相加或者相减进行组合,通俗来说就行没有 这样的高次形式。

线性回归

线性回归(linear regression)属于监督学习, 解决回归问题。目标是, 给定数据集 , 试图从此数据集中学习到一个函数 也就是求得变量参数 , 使得 , 也就是说预测出来的值和真实值无限接近。

这里解释一下 为什么只能无限接近? 因为我们只能拿某一类事件所有数据中抽样出来的部分数据进行学习,抽样出来数据不可能涵盖事件所有的可能性,所以最终只能学习到抽样代表的总体的规律。

如何衡量 之间的差别呢?

还记得之前的文章里说,回归问题的评价指标常用均方误差MSE吗?

MSE物理意义怎么解释?

均方误差有非常好的几何意义,它对应了常用的欧几里得距离或简称“欧氏距离”(Euclidean distance)。基于均方误差最小化来进行模型求解的方法称为“最小二乘法”(least square method)。在线性回归中,最小二乘法就是试图找到一条直线,使所有样本到直线上的欧氏距离之和最小。

最小二乘法:使得所选择的回归模型应该使所有观察值的残差平方和达到最小

如何求解模型参数 呢?

两个方法,一种是解析法,也就是最小二乘。另一个是逼近法,也就是梯度下降。

方法一:解析解法

线性回归模型的最小二乘“参数估计” (parameter estimation) 就是求解 , 使得 最小化的过程。

是关于 的凸函数(意思是可以找到全局最优解)。这里我们试图让均方误差MSE最小。

表示 的解, 是样本个数。这里的 是指后面的表达式值最小时的 取值。

那么上面的公式我们如何求得参数 呢? 这里我们又需要一些微积分 (calculus)的知识, 可以将 分别对 求导, 得到:

E对w求导
E对b求导

令上面两个式子等于零,就可以得到和的最优解的闭式(closed-form)解(闭式解也叫解析解)。

其中, 的平均值。

以上是对于输入属性为 1 个 的讨论, 也就是一元线性回归。

对于多个属性的讨论, 通常这时就引入了矩阵表示, 模型试图学得 , 使得 。这就是“多元线性回归”(multivariate linear regression)。

表示为 的一个参数, 那么:

然后对 求导就可以得到 矩阵的解(忽略了很多推导过程):

这里求解析解存在的问题是 在现实任务中往往不是满秩矩阵, 所以无法求解矩阵的逆, 故无法求得唯一的解。

  • 非满秩矩阵:例如3个变量,但是只有2个方程,故无法求得唯一的解。
  • 矩阵的逆:类似于数字的倒数(5对应1/5)目的是实现矩阵的除法。

解决方法: 引入正则化(regularization)将矩阵补成满秩

代码实战:

我们先生成一些数据,用于后面的实验。生成数据的函数是

# 随机生成一些用于实验的线性数据  
  
import numpy as np  
np.random.seed(42)    
m = 100  # number of instances  
X = 2 * np.random.rand(m, 1)  # column vector  
y = 4 + 3 * X + np.random.randn(m, 1)  # column vector  

画个图来看看生成的数据。

import matplotlib.pyplot as plt  
  
plt.figure(figsize=(6, 4))  
plt.plot(X, y, "b.")  
plt.xlabel("$x_1$")  
plt.ylabel("$y$", rotation=0)  
plt.axis([0, 2, 0, 15])  
plt.grid()  
plt.show()  

开始求解 , 也就是套公式

# add x0 = 1 to each instance  
X_b = np.c_[np.ones((100, 1)), X]   
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)  

公式的结果theta_best的解为array([[4.21509616],[2.77011339]])。我们期待的是 得到的是 。非常接近, 噪声的存在使其不可能完全还原为原本的函数。

现在可以用 来做出预测了。预测结果为

X_new = np.array([[0], [2]])  
# add x0 = 1 to each instance  
X_new_b = np.c_[np.ones((2, 1)), X_new]   
y_predict = X_new_b.dot(theta_best)  

绘制模型的预测结果。

plt.plot(X_new, y_predict, "r-")  
plt.plot(X, y, "b.")  
plt.axis([0, 2, 0, 15])  
plt.show()  

另外也可以直接调用最小二乘函数scipy.linalg.lstsq()进行计算:

theta_best_svd, residuals, rank, s = np.linalg.lstsq(X_b, y, rcond=1e-6)  

theta_best_svd的计算结果为array([[4.21509616],[2.77011339]])。

方法二:数值解法

梯度下降,随机初始化和,通过逼近(沿着梯度下降的方向)的方式来求解(找到一个收敛的参数值)。

损失函数回顾:

公式里的 和解析解部分的 是一样的(只是换了下字母)。

注意到其中的参数 , 这个参数是可以简化部分求导 (消掉 )。除了参数外, 其它部分与解析解部分是完全相同的。

梯度下降参数优化方法:

其中 是学习率(learning rate)(学习率也经常用字母 表示),是用来控制下降每步的 距离(太小收敛会很慢, 太大则可能跳过最优点), 可以按照对数的方法来选择, 例如







请到「今天看啥」查看全文