专栏名称: Python中文社区
致力于成为国内最好的Python开发者学习交流平台,这里有关于Python的国内外最新消息,每日推送有趣有料的技术干货和社区动态。 官方网站:www.python-cn.com
目录
相关文章推荐
Python爱好者社区  ·  太强了!35个python案例.pdf ·  1 周前  
Python爱好者社区  ·  为什么感觉中国人月薪过万很普遍了? ·  4 天前  
Python爱好者社区  ·  700篇大模型论文 ·  5 天前  
Python爱好者社区  ·  80w,我入局了! ·  1 周前  
Python爱好者社区  ·  史上最强!LSTM杀疯了 ·  1 周前  
51好读  ›  专栏  ›  Python中文社区

Python机器学习算法入门之梯度下降法实现线性回归

Python中文社区  · 公众号  · Python  · 2016-12-01 08:20

正文



專 欄


ZZR,Python中文社区专栏作者,OpenStack工程师,曾经的NLP研究者。主要兴趣方向:OpenStack、Python爬虫、Python数据分析。

Blog:http://skydream.me/

CSDN:http://blog.csdn.net/titan0427/article/details/50365480


1. 背景

       文章的背景取自An Introduction to Gradient Descent and Linear Regression,本文想在该文章的基础上,完整地描述线性回归算法。部分数据和图片取自该文章。没有太多时间抠细节,所以难免有什么缺漏错误之处,望指正。

        线性回归的目标很简单,就是用一条线,来拟合这些点,并且使得点集与拟合函数间的误差最小。如果这个函数曲线是一条直线,那就被称为线性回归,如果曲线是一条二次曲线,就被称为二次回归。数据来自于GradientDescentExample中的data.csv文件,共100个数据点,如下图所示:

        我们的目标是用一条直线来拟合这些点。既然是二维,那y=b+mx">么y=b+mx这个公式相信对于中国学生都很熟悉。其中


Error(b,m)=1N1N((b+mxi)yi)2

        计算损失函数的python代码如下:

  1. # y = b + mx

  2. def compute_error_for_line_given_points(b, m, points):

  3.    totalError = sum((((b + m * point[0]) - point[1]) ** 2 for point in points))

  4.    return totalError / float(len(points))

        现在问题被转化为,寻找参数



2. 多元线性回归模型

        从机器学习的角度来说,以上的数据只有一个feature,所以用一元线性回归模型即可。这里我们将一元线性模型的结论一般化,即推广到多元线性回归模型。这部分内部参考了机器学习中的数学(1)-回归(regression)、梯度下降(gradient descent)。假设有


hθ(x)=θ0+θ1x1+...+θnxn=θTxx0=1

J(θ)=12i=1m(hθ(x(i))y(i))2mm

        更一般地,我们可以得到广义线性回归。

广线hθ(x)=θTx=θ0+i=1nθiϕi(xi)

2.1 误差函数的进一步思考

        这里有一个有意思的东西,就是误差函数为什么要写成这样的形式。首先是误差函数最前面的系数

y(i)=θTx(i)+ε(i)

        假定误差


        所以求


3 最小二乘法求误差函数最优解

        最小二乘法(normal equation)相信大家都很熟悉,这里简单进行解释并提供python实现。首先,我们进一步把

 J(θ)=12i=1m(hθ(x(i))y(i))2=12(XθY)T(XθY)


       

       所以

        当然这里可能遇到一些问题,比如


3.1 python实现最小二乘法

        这里的代码仅仅针对背景里的这个问题。部分参考了回归方法及其python实现

  1. # 通过最小二乘法直接得到最优系数,返回计算出来的系数b, m

  2. def least_square_regress(points):

  3.    x_mat = np.mat(np.array([np.ones([len(points)]), points[:, 0]]).T)  # 转为100行2列的矩阵,2列其实只有一个feature,其中x0恒为1

  4.    y_mat = points[:, 1].reshape(len(points), 1)  # 转为100行1列的矩阵

  5.    xT_x = x_mat.T * x_mat

  6.    if np.linalg.det(xT_x) == 0.0:

  7.        print('this matrix is singular,cannot inverse')  # 奇异矩阵,不存在逆矩阵

  8.        return

  9.    coefficient_mat = xT_x.I * (x_mat.T * y_mat)

  10.    return coefficient_mat[0, 0], coefficient_mat[1, 0] # 即系数b和m

        程序执行结果如下:
b = 7.99102098227, m = 1.32243102276, error = 110.257383466, 相关系数 = 0.773728499888

        拟合结果如下图:


4. 梯度下降法求误差函数最优解

        有了最小二乘法以后,我们已经可以对数据点进行拟合。但由于最小二乘法需要计算最小二乘法和梯度下降法有哪些区别?


4.1. 梯度

        首先,我们简单回顾一下微积分中梯度的概念。这里参考了方向导数与梯度,具体的证明请务必看一下这份材料,很短很简单的。

        讨论函数


        定义函数

fl=limρ0f(x+Δx,y+Δy)f(x,y)ρρ=(Δx)2+(Δy)2

        方向导数可以理解为,函数

gradf(x,y)=fxi+fyj

        从几何角度来理解,函数


        函数


4.2 梯度方向计算

        理解了梯度的概念之后,我们重新回到1. 背景中提到的例子。1. 背景提到,梯度下降法所做的是从图中的任意一点开始,逐步找到图的最低点。那么现在问题来了,从任意一点开始,

Error(b,m)m=i=1Nxi((b+mxi)yi)

Error(b,m)b=i=1N((b+mxi)yi)x01

        有了这两个结果,我们就可以开始使用梯度下降法来寻找误差函数

        回到更一般的情况,对于每一个向量

θjJ(θ)=θj12i=1m(hθ(x(i))y(i))2=i=1m(hθ(x(i))y(i))xj(i)


4.3 批量梯度下降法

        从上面的公式中,我们进一步得到特征的参数

θj=θjαJ(θ)θj=θjαi=1m(hθ(x(i))y(i))xj(i)

        针对此例,梯度下降法一次迭代过程的python代码如下:

  1. def step_gradient(b_current, m_current, points, learningRate):

  2.    b_gradient = 0

  3.    m_gradient = 0

  4.    N = float(len(points))

  5.    for i in range(0, len(points)):

  6.        x = points[i, 0]

  7.        y = points[i, 1]

  8.        m_gradient += (2 / N) * x * ((b_current + m_current * x) - y)

  9.        b_gradient += (2 / N) * ((b_current + m_current * x) - y)

  10.    new_b = b_current - (learningRate * b_gradient)  # 沿梯度负方向

  11.    new_m = m_current - (learningRate * m_gradient)  # 沿梯度负方向

  12.    return [new_b, new_m]

       

        其中learningRate是学习速率,它决定了逼近最低点的速率。可以想到的是,如果learningRate太大,则可能导致我们不断地最低点附近来回震荡;而learningRate太小,则会导致逼近的速度太慢。An Introduction to Gradient Descent and Linear Regression提供了完整的实现代码GradientDescentExample

        这里多插入一句,如何在python中生成GIF动图。配置的过程参考了使用Matplotlib和Imagemagick实现算法可视化与GIF导出。需要安装ImageMagick,使用到的python库是Wand: a ctypes-based simple ImageMagick binding for Python。然后修改C:\Python27\Lib\site-packages\matplotlib__init__.py文件,在

  1. # this is the instance used by the matplotlib classes

  2. rcParams = rc_params()

        后面加上:

  1. # fix a bug by ZZR

  2. rcParams['animation.convert_path'] = 'C:\Program Files\ImageMagick-6.9.2-Q16\convert.exe'

        即可在python中调用ImageMagick。如何画动图参见Matplotlib动画指南,不再赘述。learningRate=0.0001,迭代100轮的结果如下图:

After {100} iterations b = 0.0350749705923, m = 1.47880271753, error = 112.647056643, 相关系数 = 0.773728499888
After {1000} iterations b = 0.0889365199374, m = 1.47774408519, error = 112.614810116, 相关系数 = 0.773728499888
After {1w} iterations b = 0.607898599705, m = 1.46754404363, error = 112.315334271, 相关系数 = 0.773728499888
After {10w} iterations b = 4.24798444022, m = 1.39599926553, error = 110.786319297, 相关系数 = 0.773728499888


4.4 随机梯度下降法

        批量梯度下降法每次迭代都要用到训练集的所有数据,计算量很大,针对这种不足,引入了随机梯度下降法。随机梯度下降法每次迭代只使用单个样本,迭代公式如下:

θj=θjα(hθ(x(i))y(i))xj(i)

        可以看出,随机梯度下降法是减小单个样本的错误函数,每次迭代不一定都是向着全局最优方向,但大方向是朝着全局最优的。

        这里还有一些重要的细节没有提及,比如如何确实learningRate,如果判断何时递归可以结束等等。


参考文献

  1. An Introduction to Gradient Descent and Linear Regression

  2. 方向导数与梯度

  3. 最小二乘法和梯度下降法有哪些区别?

  4. GradientDescentExample

  5. 机器学习中的数学(1)-回归(regression)、梯度下降(gradient descent)

  6. @邹博_机器学习

  7. 回归方法及其python实现

  8. 使用Matplotlib和Imagemagick实现算法可视化与GIF导出

  9. Wand: a ctypes-based simple ImageMagick binding for Python

  10. Matplotlib动画指南

ARTICLES
近期热门文章



生成器

关于生成器的那些事儿

爬虫代理

如何构建爬虫代理服务

地理编码

怎样用Python实现地理编码

nginx日志

使用Python分析nginx日志


淘宝女郎

一个批量抓取淘女郎写真图片的爬虫

IP代理池

突破反爬虫的利器——开源IP代理池

布隆去重

基于Redis的Bloomfilter去重(附代码)

QQ空间爬虫

QQ空间爬虫最新分享,一天 400 万条数据

匿名代理池

进击的爬虫:用Python搭建匿名代理池

在公众号底部回复上述关键词可直接打开相应文章


我 们 终 将 改 变 潮 水 的 方 向

§§

Python中文社区
www.python-cn.com

致力于成为

国内最好的Python社区


QQ群:152745094

专栏作者申请邮箱

[email protected]

— Life is short,we use Python —

 


点击阅读原文可直接进入作者博客


推荐文章
Python爱好者社区  ·  太强了!35个python案例.pdf
1 周前
Python爱好者社区  ·  为什么感觉中国人月薪过万很普遍了?
4 天前
Python爱好者社区  ·  700篇大模型论文
5 天前
Python爱好者社区  ·  80w,我入局了!
1 周前
Python爱好者社区  ·  史上最强!LSTM杀疯了
1 周前
medworld器械世界  ·  国税总局出手,医械个代危险!
7 年前