专栏名称: 机器之心
目录
相关文章推荐
爱可可-爱生活  ·  SWE-RL让AI具备更强大的软件工程推理能 ... ·  23 小时前  
爱可可-爱生活  ·  【一个关于长上下文大语言模型(LLM)的综述 ... ·  昨天  
AI范儿  ·  AI 创业公司估值排行榜:从 ... ·  昨天  
爱可可-爱生活  ·  【[542星]OmniServe:集成了 ... ·  2 天前  
硅星GenAI  ·  DeepSeek开源周Day 2: ... ·  2 天前  
硅星GenAI  ·  DeepSeek开源周Day 2: ... ·  2 天前  
51好读  ›  专栏  ›  机器之心

硬核NeruIPS 2018最佳论文,一个神经了的常微分方程

机器之心  · 掘金  · AI  · 2018-12-24 01:57

正文

阅读 61

硬核NeruIPS 2018最佳论文,一个神经了的常微分方程

机器之心原创,作者:蒋思源。

这是一篇神奇的论文,以前一层一层叠加的神经网络似乎突然变得连续了,反向传播也似乎不再需要一点一点往前传、一层一层更新参数了。

在最近结束的 NeruIPS 2018 中,来自多伦多大学的陈天琦等研究者成为最佳论文的获得者。他们提出了一种名为神经常微分方程的模型,这是新一类的深度神经网络。神经常微分方程不拘于对已有架构的修修补补,它完全从另外一个角度考虑如何以连续的方式借助神经网络对数据建模。在陈天琦的讲解下,机器之心将向各位读者介绍这一令人兴奋的神经网络新家族。

在与机器之心的访谈中,陈天琦的导师 David Duvenaud 教授谈起这位学生也是赞不绝口。Duvenaud 教授认为陈天琦不仅是位理解能力超强的学生,钻研起问题来也相当认真透彻。他说:「天琦很喜欢提出新想法,他有时会在我提出建议一周后再反馈:『老师你之前建议的方法不太合理。但是我研究出另外一套合理的方法,结果我也做出来了。』」Ducenaud 教授评价道,现如今人工智能热度有增无减,教授能找到优秀博士生基本如同「鸡生蛋还是蛋生鸡」的问题,顶尖学校的教授通常能快速地招纳到博士生,「我很幸运地能在事业起步阶段就遇到陈天琦如此优秀的学生。」

本文主要介绍神经常微分方程背后的细想与直观理解,很多延伸的概念并没有详细解释,例如大大降低计算复杂度的连续型流模型和官方 PyTorch 代码实现等。这一篇文章重点对比了神经常微分方程(ODEnet)与残差网络,我们不仅能通过这一部分了解如何从熟悉的 ResNet 演化到 ODEnet,同时还能还有新模型的前向传播过程和特点。

其次文章比较关注 ODEnet 的反向传播过程,即如何通过解常微分方程直接把梯度求出来。这一部分与传统的反向传播有很多不同,因此先理解反向传播再看源码可能是更好的选择。值得注意的是,ODEnet 的反传只有常数级的内存占用成本。

如下展示了文章的主要结构:

  • 常微分方程

  • 从残差网络到微分方程

  • 从微分方程到残差网络

  • 网络对比

  • 神经常微分方程

  • 连续型的归一化流

  • 变量代换定理

常微分方程

其实初读这篇论文,还是有一些疑惑的,因为很多概念都不是我们所熟知的。因此如果想要了解这个模型,那么同学们,你们首先需要回忆高数上的微分方程。有了这样的概念后,我们就能愉快地连续化神经网络层级,并构建完整的神经常微分方程。

常微分方程即只包含单个自变量 x、未知函数 f(x) 和未知函数的导数 f'(x) 的等式,所以说 f'(x) = 2x 也算一个常微分方程。但更常见的可以表示为 df(x)/dx = g(f(x), x),其中 g(f(x), x) 表示由 f(x) 和 x 组成的某个表达式,这个式子是扩展一般神经网络的关键,我们在后面会讨论这个式子怎么就连续化了神经网络层级。

一般对于常微分方程,我们希望解出未知的 f(x),例如 f'(x) = 2x 的通解为 f(x)=x^2 +C,其中 C 表示任意常数。而在工程中更常用数值解,即给定一个初值 f(x_0),我们希望解出末值 f(x_1),这样并不需要解出完整的 f(x),只需要一步步逼近它就行了。

现在回过头来讨论我们熟悉的神经网络,本质上不论是全连接、循环还是卷积网络,它们都类似于一个非常复杂的复合函数,复合的次数就等于层级的深度。例如两层全连接网络可以表示为 Y=f(f(X, θ1), θ2),因此每一个神经网络层级都类似于万能函数逼近器。

因为整体是复合函数,所以很容易接受复合函数的求导方法:链式法则,并将梯度从最外一层的函数一点点先向里面层级的函数传递,并且每传到一层函数,就可以更新该层的参数 θ。现在问题是,我们前向传播过后需要保留所有层的激活值,并在沿计算路径反传梯度时利用这些激活值。这对内存的占用非常大,因此也就限制了深度模型的训练过程。

神经常微分方程走了另一条道路,它使用神经网络参数化隐藏状态的导数,而不是如往常那样直接参数化隐藏状态。这里参数化隐藏状态的导数就类似构建了连续性的层级与参数,而不再是离散的层级。因此参数也是一个连续的空间,我们不需要再分层传播梯度与更新参数。总而言之,神经微分方程在前向传播过程中不储存任何中间结果,因此它只需要近似常数级的内存成本。

从残差网络到微分方程

残差网络 是一类特殊的卷积网络,它通过残差连接而解决了梯度反传问题,即当神经网络层级非常深时,梯度仍然能有效传回输入端。下图为原论文中残差模块的结构,残差块的输出结合了输入信息与内部卷积运算的输出信息,这种残差连接或恒等映射表示深层模型至少不能低于浅层网络的准确度。这样的残差模块堆叠几十上百个就是非常深的残差神经网络。


如果我们将上面的残差模块更加形式化地表示为以下方程:

其中 h_t 是第 t 层隐藏单元的输出值,f 为通过θ_t 参数化的神经网络。该方程式表示上图的整个残差模块,如果我们其改写为残差的形式,即 h_t+1 - h_t = f(h_t, θ_t )。那么我们可以看到神经网络 f 参数化的是隐藏层之间的残差,f 同样不是直接参数化隐藏层。

ResNet 假设层级的离散的,第 t 层到第 t+1 层之间是无定义的。那么如果这中间是有定义的呢?残差项 h_t0 - h_t1 是不是就应该非常小,以至于接近无穷小?这里我们少考虑了分母,即残差项应该表示为 (h_t+1 - h_t )/1,分母的 1 表示两个离散的层级之间相差 1。所以再一次考虑层级间有定义,我们会发现残差项最终会收敛到隐藏层对 t 的导数,而神经网络实际上参数化的就是这个导数。

所以若我们在层级间加入更多的层,且最终趋向于添加了无穷层时,神经网络就连续化了。可以说残差网络其实就是连续变换的欧拉离散化,是一个特例,我们可以将这种连续变换形式化地表示为一个常微分方程:

如果从导数定义的角度来看,当 t 的变化趋向于无穷小时,隐藏状态的变化 dh(t) 可以通过神经网络建模。当 t 从初始一点点变化到终止,那么 h(t) 的改变最终就代表着前向传播结果。这样利用神经网络参数化隐藏层的导数,就确确实实连续化了神经网络层级。

现在若能得出该常微分方程的数值解,那么就相当于完成了前向传播。具体而言,若 h(0)=X 为输入图像,那么终止时刻的隐藏层输出 h(T) 就为推断结果。这是一个常微分方程的初值问题,可以直接通过黑箱的常微分方程求解器(ODE Solver)解出来。而这样的求解器又能控制数值误差,因此我们总能在计算力和模型准确度之间做权衡。

形式上来说,现在就需要变换方程 (2) 以求出数值解,即给定初始状态 h(t_0) 和神经网络的情况下求出终止状态 h(t_1):

如上所示,常微分方程的数值解 h(t_1) 需要求神经网络 f 从 t_0 到 t_1 的积分。我们完全可以利用 ODE solver 解出这个值,这在数学物理领域已经有非常成熟的解法,我们只需要将其当作一个黑盒工具使用就行了。

从微分方程到残差网络

前面提到过残差网络是神经常微分方程的特例,可以说残差网络是欧拉方法的离散化。两三百年前解常微分方程的欧拉法非常直观,即 h(t +Δt) = h(t) + Δt×f(h(t), t)。每当隐藏层沿 t 走一小步Δt,新的隐藏层状态 h(t +Δt) 就应该近似在已有的方向上迈一小步。如果这样一小步一小步从 t_0 走到 t_1,那么就求出了 ODE 的数值解。

如果我们令 Δt 每次都等于 1,那么离散化的欧拉方法就等于残差模块的表达式 h(t+1) = h(t) + f(h(t), t)。但是欧拉法只是解常微分方程最基础的方法,它每走一步都会产生一点误差,且误差会累积起来。近百年来,数学家构建了很多现代 ODE 求解方法,它们不仅能保证收敛到真实解,同时还能控制误差水平。

陈天琦等研究者构建的 ODE 网络就使用了一种适应性的 ODE solver,它不像欧拉法移动固定的步长,相反它会根据给定的误差容忍度选择适当的步长逼近真实解。如下图所示,左边的残差网络定义有限转换的离散序列,它从 0 到 1 再到 5 是离散的层级数,且在每一层通过激活函数做一次非线性转换。此外,黑色的评估位置可以视为神经元,它会对输入做一次转换以修正传递的值。而右侧的 ODE 网络定义了一个向量场,隐藏状态会有一个连续的转换,黑色的评估点也会根据误差容忍度自动调整。

网络对比

在 David 的 Oral 演讲中,他以两段伪代码展示了 ResNet 与 ODEnet 之间的差别。如下展示了 ResNet 的主要过程,其中 f 可以视为卷积层,ResNet 为整个模型架构。在卷积层 f 中,h 为上一层输出的特征图,t 确定目前是第几个卷积层。ResNet 中的循环体为残差连接,因此该网络一共 T 个残差模块,且最终返回第 T 层的输出值。

def f(h, t, θ):
    return nnet(h, θ_t)

def resnet(h):
    for t in [1:T]:
        h = h + f(h, t, θ)
    return h复制代码

相比常见的 ResNet,下面的伪代码就比较新奇了。首先 f 与前面一样定义的是神经网络,不过现在它的参数θ是一个整体,同时 t 作为独立参数也需要馈送到神经网络中,这表明层级之间也是有定义的,它是一种连续的网络。而整个 ODEnet 不需要通过循环搭建离散的层级,它只要通过 ODE solver 求出 t_1 时刻的 h 就行了。

def f(h, t, θ):
    return nnet([h, t], θ)

def ODEnet(h, θ):
    return ODESolver(f, h, t_0, t_1, θ)复制代码

除了计算过程不一样,陈天琦等研究者还在 MNSIT 测试了这两种模型的效果。他们使用带有 6 个残差模块的 ResNet,以及使用一个 ODE Solver 代替这些残差模块的 ODEnet。以下展示了不同网络在 MNSIT 上的效果、参数量、内存占用量和计算复杂度。

其中单个隐藏层的 MLP 引用自 LeCun 在 1998 年的研究,其隐藏层只有 300 个神经元,但是 ODEnet 在有相似参数量的情况下能获得显著更好的结果。上表中 L 表示神经网络的层级数,L tilde 表示 ODE Solver 中的评估次数,它可以近似代表 ODEnet 的「层级深度」。值得注意的是,ODEnet 只有常数级的内存占用,这表示不论层级的深度如何增加,它的内存占用基本不会有太大的变化。

神经常微分方程

在与 ResNet 的类比中,我们基本上已经了解了 ODEnet 的前向传播过程。首先输入数据 Z(t_0),我们可以通过一个连续的转换函数(神经网络)对输入进行非线性变换,从而得到 f。随后 ODESolver 对 f 进行积分,再加上初值就可以得到最后的推断结果。如下所示,残差网络只不过是用一个离散的残差连接代替 ODE Solver。


在前向传播中,ODEnet 还有几个非常重要的性质,即模型的层级数与模型的误差控制。首先因为是连续模型,其并没有明确的层级数,因此我们只能使用相似的度量确定模型的「深度」,作者在这篇论文中采用 ODE Solver 评估的次数作为深度。

其次,深度与误差控制有着直接的联系,ODEnet 通过控制误差容忍度能确定模型的深度。因为 ODE Solver 能确保在误差容忍度之内逼近常微分方程的真实解,改变误差容忍度就能改变神经网络的行为。一般而言,降低 ODE Solver 的误差容忍度将增加函数的评估的次数,因此类似于增加了模型的「深度」。调整误差容忍度能允许我们在准确度与计算成本之间做权衡,因此我们在训练时可以采用高准确率而学习更好的神经网络,在推断时可以根据实际计算环境调整为较低的准确度。


如原论文的上图所示,a 图表示模型能保证在误差范围为内,且随着误差降低,前向传播的函数评估数增加。b 图展示了评估数与相对计算时间的关系。d 图展示了函数评估数会随着训练的增加而自适应地增加,这表明随着训练的进行,模型的复杂度会增加。

c 图比较有意思,它表示前向传播的函数评估数大致是反向传播评估数的一倍,这恰好表示反向传播中的 adjoint sensitivity 方法不仅内存效率高,同时计算效率也比直接通过积分器的反向传播高。这主要是因为 adjoint sensitivity 并不需要依次传递到前向传播中的每一个函数评估,即梯度不通过模型的深度由后向前一层层传。

反向传播

师从同门的 Jesse Bettencourt 向机器之心介绍道,「天琦最擅长的就是耐心讲解。」当他遇到任何无论是代码问题,理论问题还是数学问题,一旦是问了同桌的天琦,对方就一定会慢慢地花时间把问题讲清楚、讲透彻。而 ODEnet 的反向传播,就是这样一种需要耐心讲解的问题。

ODEnet 的反向传播与常见的反向传播有一些不同,我们可能需要仔细查阅原论文与对应的附录证明才能有较深的理解。此外,作者给出了 ODEnet 的 PyTorch 实现,我们也可以通过它了解实现细节。

正如作者而言,训练一个连续层级网络的主要技术难点在于令梯度穿过 ODE Solver 的反向传播。其实如果令梯度沿着前向传播的计算路径反传回去是非常直观的,但是内存占用会比较大而且数值误差也不能控制。作者的解决方案是将前向传播的 ODE Solver 视为一个黑箱操作,梯度很难或根本不需要传递进去,只需要「绕过」就行了。

作者采用了一种名为 adjoint method 的梯度计算方法来「绕过」前向传播中的 ODE Solver,即模型在反传中通过第二个增广 ODE Solver 算出梯度,其可以逼近按计算路径从 ODE Solver 传递回的梯度,因此可用于进一步的参数更新。这种方法如上图 c 所示不仅在计算和内存非常有优势,同时还能精确地控制数值误差。

具体而言,若我们的 损失函数 为 L(),且它的输入为 ODE Solver 的输出:

我们第一步需要求 L 对 z(t) 的导数,或者说模型损失的变化如何取决于隐藏状态 z(t) 的变化。其中损失函数 L 对 z(t_1) 的导数可以为整个模型的梯度计算提供入口。作者将这一个导数称为 adjoint a(t) = -dL/z(t),它其实就相当于隐藏层的梯度。

在基于链式法则的传统反向传播中,我们需要从后一层对前一层求导以传递梯度。而在连续化的 ODEnet 中,我们需要将前面求出的 a(t) 对连续的 t 进行求导,由于 a(t) 是损失 L 对隐藏状态 z(t) 的导数,这就和传统链式法则中的传播概念基本一致。下式展示了 a(t) 的导数,它能将梯度沿着连续的 t 向前传,附录 B.1 介绍了该式具体的推导过程。


在获取每一个隐藏状态的梯度后,我们可以再求它们对参数的导数,并更新参数。同样在 ODEnet 中,获取隐藏状态的梯度后,再对参数求导并积分后就能得到损失对参数的导数,这里之所以需要求积分是因为「层级」t 是连续的。这一个方程式可以表示为:

综上,我们对 ODEnet 的反传过程主要可以直观理解为三步骤,即首先求出梯度入口伴随 a(t_1),再求 a(t) 的变化率 da(t)/dt,这样就能求出不同时刻的 a(t)。最后借助 a(t) 与 z(t),我们可以求出损失对参数的梯度,并更新参数。当然这里只是简要的直观理解,更完整的反传过程展示在原论文的算法 1。







请到「今天看啥」查看全文