专栏名称: 码农翻身
工作15年的前IBM架构师分享好玩有趣的编程知识和职场的经验教训, 不容错过。
目录
相关文章推荐
OSC开源社区  ·  前端年度大事件盘点:尤雨溪成立公司、ECMA ... ·  3 天前  
逸言  ·  项目札记008 | 团队成员的能力培养 ·  昨天  
程序猿  ·  程序员的一周,凌晨 3 点才是效率最高的 ·  5 天前  
程序猿  ·  你觉得是上班更苦还是上学更苦? ·  2 天前  
51好读  ›  专栏  ›  码农翻身

从鸡和鸡腿的关系,发现人工智能的秘密......

码农翻身  · 公众号  · 程序员  · 2024-12-27 08:56

主要观点总结

闪客向小宇解释了机器学习的基本原理,通过数鸡腿和预测函数等例子引出损失函数和梯度下降的概念。通过生活中的例子,形象展示了如何通过梯度下降求解损失函数最小值的过程。

关键观点总结

关键观点1: 机器学习符号主义

人工智能(AI)领域的一种方法,主要基于符号和规则来表示知识和推理,与现代机器学习方法(例如深度学习)形成对比。

关键观点2: 损失函数

表示预测结果与真实结果之间差距的函数,用于评估模型的性能。

关键观点3: 梯度下降

一种求解损失函数最小值的优化方法,通过不断沿梯度的反方向调整参数,逐步逼近最优解。

关键观点4: 符号主义的应用

通过数鸡腿和预测函数等例子,引出损失函数和梯度下降的实际应用。

关键观点5: 梯度下降的详细解释

通过生活实例解释梯度下降的思路,损失函数的参数调整过程以及梯度的概念。


正文

小宇:闪客闪客,现在的 AI 好神奇呀,你能给我讲讲它的原理吗?

闪客:你个菜鸡,连最基本的机器学习是什么都不知道,就妄想一下子了解现在 AI 原理?

小宇:额,注意你的态度!那你说怎么办嘛!

闪客:现在你先忘掉 AI,忘掉所有的什么 ChatGPT、大模型、深度学习、机器学习、神经网络这些概念。

小宇:哦好,虽然我本来就没听说过这些。

线性回归




闪客:啊这... 好吧,我们来一个场景。我想研究鸡的数量和腿的数量的关系,于是我列了一个表格。

鸡    腿 

5    10

7    14

8    16

9    18

那我问你,假如鸡的数量是 10,那么腿的数量是多少?
小宇:额,你是不是把我当傻帽呀,我不看你这表也知道,腿的数量就是鸡数量的 2 倍嘛,当然是 20 了!

闪客:没错,你直接找到了鸡和腿数量之间的规律,是严格符合 y = 2x 的函数关系。假如世界上所有的事情都能找到其对应的严格的函数关系,那该多好,这就是早期机器学习符号主义的愿景。

画外音:机器学习的符号主义(Symbolic AI 或 Symbolic Machine Learning) 是人工智能(AI)领域的一种方法,主要基于符号和规则来表示知识和推理。这种方法与现代机器学习方法(例如深度学习)形成了鲜明对比,后者依赖于神经网络和大量数据的模式识别。符号主义在20世纪70-90年代被广泛应用,是人工智能早期的主要研究方向之一。 

小宇:诶?这看起来非常科学严谨呀,为什么这样做不行呢?

闪客:如果能实现这个愿景固然是好的,但人们还是低估了这个世界的复杂程度。想想看,如果让你用一个函数来预测股票是涨还是跌,这可能吗?

小宇:总感觉理论上是可行的,但实际上应该做不到,不然我也不会在这学什么机器学习了哈哈。

闪客:是的,这种看似能够找到规律的事情都做不到,更别提人类智慧这种的复杂问题了。 

小宇:哎,那这可怎么办呢?

闪客:别急,咱先别考虑那么远的问题,我先给你出一个比刚刚数鸡腿更复杂点的问题,你找找看下面 X 和 Y 的关系。

X    Y 
1   2.6

2   3.0

3   3.7

4   4.5

5   4.4

6   4.9

7   6.0

8   6.2

9   6.4

10   7.2 


小宇:额,总感觉有点规律,但又不能一下子看出来,有点烧脑。

闪客:确实,不过我们把这些点画在坐标轴上,你再看看呢。

小宇:哇!这么看清晰了好多,不过还是不能一下子看出来什么。

闪客:那我再加一条线呢?

小宇:哎呀!这感觉已经找到规律了,大概就是 y = 0.5x + 2 嘛!

闪客:没错,你居然直接把函数说出来了。

小宇:你图都画成这样了,我还说不出函数,那就太不应该了。不过我猜到你接下来要说什么了,就是如何找到这个函数对吧?

闪客:没错,直觉上,我们是想让这条线尽可能靠近所有点,但怎么用数学或计算机语言表达"靠得近",就是个问题了。

小宇:emmm,好像不太容易想到,没想到这么简单直观的问题,要是用严肃的数学语言描述,还挺难的。

闪客:是的,我给你加几条线,你看看有没有启发。

小宇:啊!我明白了,可以用每个点到这条线的偏离距离的总和,来表示点与线的“贴合程度”,这个数越小越好。

闪客:没错!所以我们就可以定义如下的损失函数,来表示这条线和这些点的偏离程度,只要找到这个函数的最小值即可!

小宇:额,你这太不丝滑了呀,前面还一个公式都没有,怎么突然冒出来这么个东西。


损失函数




闪客:哈哈,本来想给你吓回去的,但既然你没走,那我们就专门来聊聊这个"损失函数"到底是个啥东西,为什么叫它"损失"。

小宇:是因为算出来的数特别让人"损失信心"吗?

闪客:哈哈哈,这个脑洞不错,但其实它的损失更像是"我们和完美结果之间的差距"。差距越大,损失就越大,差距越小,损失就越小。

小宇:哇,这个解释好理解!

闪客:来,我们先从直观的定义开始。假设某个点的真实值是 y,而我们的预测值是 ŷ 。你觉得两者的误差可以怎么表示?

小宇:很简单呀,直接用 y−ŷ 不就行了?

闪客:不错!这叫"误差"或"偏差"。但问题来了,你觉得要是我们把所有点的误差加起来,有啥问题?

小宇:嗯~正的误差和负的误差会互相抵消,最终看起来像没什么偏差?

闪客:没错,像刚刚的那几个 XY 的点,如果按这种算法来评估,就有可能找到一种驴唇不对马嘴的预测,但它的损失却是 0! 

小宇:哈哈,确实离了大谱了,那可咋办呢?

闪客:为了不让误差"藏着掖着",我们可以给它取个绝对值,这样正负误差都成了正的:

小宇:哦,这样挺公平的呀。诶等等,这又有个新的数学公式,你得解释解释。

闪客:额,你是没上过初中么?这个符号就是求和符号,表示把所有的 y - ŷ 的值都累加起来。比如把等差数列写成求和符号的形式就是这样。

小宇:哦懂了,这好像确实是初中就学过的,嘿嘿。

闪客:回过头来看,这样确实很公平,但有个小问题,就是绝对值有"尖点",数学优化的时候不太友好,计算起来跟被卡在牙缝里一样麻烦。

小宇:嗯确实,做题的时候其实最讨厌碰到绝对值符号,还得分段讨论,有一种情况没想全就要扣分,最头疼了。

闪客:所以我们更喜欢"平方误差",就是把误差平方后再加起来:

小宇:哇!这的确是个绝妙的办法呀,平方之后,正负误差都成正的!而且大的误差更显眼,就像班里成绩特别差的同学会被老师特别关照一样。

闪客:哈哈哈,没错!我们再平均一下,去掉样本数量大小因素的影响,这就叫"均方误差"(Mean Squared Error, MSE),看起来是不是又简单又合理?

小宇:嗯,这次终于没有突然甩出高大上的东西,我的信心回来了一点。

闪客:好了,找到了损失函数,还记得我们要干啥不?

小宇:记得,让损失函数最小!

闪客:不错,这时候我们得把 ŷ 表示出来,我们可以假设预测的直线的方程是 y = wx + b,像下面这样。

不过我们可以先简化一点,认为这条直线穿过原点,这样就可以少个 b。

这个时候带入 MSE 中,就是

我们想要计算的就是,w 为多少的时候,这个损失函数的值最小。

小宇:完了完了,我已经头疼了,这里咋这么多字母,我已经晕了。

闪客:别急,这些字母里其实只有 w 是未知的,其他的都是已知数。我们举个简单的例子就明白了了。我们先不看上面那个复杂的例子,假设 x=[1,2,3,4] y=[1,2,3,4] 这样傻子都能看出来规律对吧,我们就用这个来举例。

小宇:哈哈这个简单,不用算也知道就是 y = x

闪客:没错,我们就用这个算一下,把这里的 x 和 y 的值都代入到刚刚的损失函数中。

接下来就是一个标准的求函数 L 的极小值点的过程,这种苦力活我怎么可能自己做呢,交给 AI 吧。

小宇:哈哈,你可真懒,不过这过程解释得真细致呀,要让你讲肯定不能这么有耐心。我再补个图吧,刚刚 w = 1 就表示预测直线的方程是 y = x,就像这样,确实损失最小呢!

闪客:没错,实际上刚刚的

画成图就是个抛物线,寻找最小值点就是寻找抛物线的最低点。

小宇:原来如此!诶?那如果回到最初,我们不简化预测函数的直线方程,直接是 y = wx + b 呢?这要怎么办?

闪客:一样的,这样最终代入到损失函数后,就是关于 w 和 b 两个未知变量的函数,求极值点如果画成图的话,就不再是抛物线了,而是三维坐标中的曲面。

这时候就得用偏导数来计算了,具体太数学了就不展开了,偏导数我做了两个动图,你可以感悟一下。

小宇:哎呀虽然这动画很丝滑,但想起来是真烧脑呀,更何况这还是最简单的形状了,如果七扭八歪或者维度更高就...

闪客:是的,所以这个时候我们就不能直接硬求解了,得累死你,而且也利用不了计算机的优势。这时候我们可以用另一种更适合计算机一步一步逼近答案的求解方法 -- 梯度下降。


梯度下降




小宇:啊,这么神奇!那快告诉我什么是梯度下降呢?

闪客:别急,直接告诉你可不是我的风格,我们先不要管什么梯度下不下降的,先来想想我们的目的是什么。

小宇:嗯目的我还是清晰的,就是我们想求解一个叫损失函数的最小值,比如 L(w, b) 甚至更多维的 L(w₁, w₂, ..., b)。

闪客:没错,但最终目标可不是知道这个最小值是多少。

小宇:哦哦对,我说得不给力,是求解使得这个损失函数最小的 w 和 b 都是多少。

闪客:没错,那你想想看,直接一步到位求出 w 和 b 的值太难了,那我们是不是可以一点一点调整它们,分多次求解呢?

小宇:一点一点调整?听起来好像是个思路,但还是没太明白怎么调整。

闪客:没关系,我们假设个生活中的场景,你现在有一杯咖啡和糖,你怎么调出符合你口味的甜度呢?

小宇:哦这个我深有感悟,一步到位很难。比如我想要微微甜,那就得先加一点点糖,然后尝一尝,然后再加点,再尝一尝,直到刚好到我满意为止。

闪客:没错,没想到你还挺精致的,这就是梯度下降的精髓!

小宇:啊,这和梯度下降有什么关系呢?

闪客:你可以把符合你的口味这个目标当做一个损失函数,糖的量就是损失函数中的参数,你不能一下子就确定糖这个值是多少,于是只能从一个初始状态开始,比如先加一勺糖,然后一点一点变化糖的量。每次加完糖后你品尝咖啡就是你在计算这次的损失函数,也就是你对口味的喜欢程度。

小宇:啊,我明白了!没想到生活中的例子这么有启发作用!

闪客:对!生活中的很多事都是这样的,比如做饭调味、调音响音质,甚至选衣服搭配颜色,都是通过不断尝试和调整来找到最优解。机器学习的梯度下降,也是用这种思路来优化参数的。

小宇:这个思路我明白了。不过你之前说的“梯度”具体是啥呢?

闪客:假如损失函数只有一个参数,像之前的 L(w),那么梯度就和导数是一个意思。

如果损失函数有多个参数,像之前的 L(w,b),那么梯度就是各个参数的偏导数。

在这种情况下,梯度是个向量,是所有参数的偏导数累加起来的综合结果。

小宇:额,你这一大堆输出差点又给我整懵了,向量这个概念确实学过,但总感觉还不直观,你能形象地给我展示下么?

闪客:没问题,我们就拿之前三维坐标系下的那个带两个参数 w 和 b 的损失函数来说,对应图中的这个点,它的梯度是多少呢?

小宇:对 w 和 b 分别求偏导?

闪客:没错,在图中,对 w 求偏导就是把 b = 0 这个平面和曲面的交线求导数。

把视角转一下就清晰了。

小宇:原来如此!那对 b 求偏导呢?

闪客:也是一样,线画 w = -1 这个平面和曲面的交线。

从侧面看,这条交线已经在最低点了,所以 b 的偏导数就是 0。

所以把这两股偏导数的力量合在一起,就是最终的向量,也就是梯度。

小宇:我明白了!其实就是找个坡度最大的方向往下滑,直到滑到最低点。

闪客:没错,不过这里的图只是为了让你形象理解梯度的意思,实际计算的时候不用考虑那么多,直接求各参数的偏导数就行了。

小宇:诶,那算出偏导数之后,要怎么样呢?

闪客:简单!每次都沿着梯度的反方向,走一小步,也就是你说的往下滑。公式写出来是这样的:

小宇:哇,这么简单呀,其实就是每个参数每次都变化自己偏导数那么大的值就好了。

闪客:没错!不过这样的话有个小问题,就是每次变化的这个量,太大了容易走过了错过最低点,太小了又太磨叽,所以我们乘以一个学习率 η 来调整一下速度。

小宇:哦还真是,人类真是好聪明呀!

闪客:哈哈是呀。咱们找到了梯度下降的求解方法,你来实践一下吧。回到那个最简单的题目,假设 x 和 y 的数据如下:x=[1,2,3,4] y=[1,2,3,4] ,求一下 y = wx 中的 w 是多少。

虽然傻子也能直接看出 y = x 是最终的解,不过我们就用这个来举例实战一下,你来用梯度下降的方法求一下 w 的值。

小宇:好的,不过我学你,这种小事儿我也懒得自己算了,交给 AI 吧!

闪客:哈哈真不赖,活学活用呀,这 AI 直接把图都帮我们画出来了,图里可以看到损失函数的值 Loss 再逐渐降低为 0,而我们要计算的权重 w 的值在不断接近 1。之后你看到再复杂的机器学习或者深度学习等过程的展示,最核心的其实就是这两个东西的变化罢了。

小宇:哇,似乎有点 GET 到 AI 的核心逻辑了!我理解更高维度也就是更多参数的梯度下降求解,和这个步骤基本的思路是一致的。

闪客:没错,至于梯度下降的改进版本,比如动量法、Adam 优化器等,以及更多计算模型,比如神经网络、卷积神经网络等,都是在这个核心思路的基础上迭代出来的。

小宇:厉害了,这次讲得还挺耐心!

闪客:哎呀,不知不觉又到饭点了,今天讲的给你画了这么多图很累的,请我吃个饭吧。

小宇:哦才想起来我家里洗的衣服还在洗衣机里呢,我得回去晾衣服啦,下次吧。

闪客:哦~