专栏名称: 数据派THU
本订阅号是“THU数据派”的姊妹账号,致力于传播大数据价值、培养数据思维。
目录
相关文章推荐
CDA数据分析师  ·  【干货】常用的5种数据分析方法大揭秘 ·  2 天前  
大数据分析和人工智能  ·  陈果:数据归IT,分析归业务 ,重构业务价值视角 ·  3 天前  
玉树芝兰  ·  最强思考 o1 Pro + ... ·  2 天前  
51好读  ›  专栏  ›  数据派THU

清华大学朱军详解珠算:贝叶斯深度学习的GPU库(附视频)

数据派THU  · 公众号  · 大数据  · 2017-06-02 21:23

正文

来源:机器之心

演讲者:朱军

本文长度为5000字,建议阅读10分钟

本文探讨分享了贝叶斯深度学习模型的计算平台:珠算。

[导读]5月27-28日,机器之心在北京 898 创新空间顺利主办了第一届全球机器智能峰会(GMIS 2017)。中国科学院自动化研究所复杂系统管理与控制国家重点实验室主任王飞跃为大会做了开幕式致辞。大会第一天,「LSTM 之父」Jürgen Schmidhuber、Citadel 首席人工智能官邓力、腾讯 AI Lab 副主任俞栋、英特尔 AIPG 数据科学部主任 Yinyin Liu、GE Transportation Digital Solutions CTO Wesly Mukai 等知名人工智能专家参与峰会,并通过主题演讲、圆桌论坛等形式从科学家、企业家、技术专家的视角对人工智能技术前沿和未来发展进行了解读。


大会第一天下午,清华大学智能技术与系统国家重点实验室朱军发表了主题为《珠算:贝叶斯深度学习的 GPU 库》的演讲,他探讨并分享了贝叶斯深度学习模型的计算平台:珠算。该平台由清华大学机器学习组开发,目前已经在 GitHub 上开源,可参阅这篇报道《清华大学发布珠算:一个用于生成模型的 Python 库》http://dwz.cn/63BhzL


珠算项目地址:

https://github.com/thu-ml/zhusuan


在 GMIS 2017 大会上,朱军从深度学习谈起,对该项目进行了更加深入的介绍,同时还在深度生成模型、贝叶斯推理等更广泛方面分享了自己的思考。


以下为演讲视频:



以下是该演讲视频的主要内容:


谢谢机器之心的邀请,很高兴有这个机会和大家分享一下我们实验室做的计算平台,因为我们是实验室,不像公司里有那么多的人,但我们做的东西是属于比较前沿的。



我们研究的是贝叶斯深度学习,首先我跟大家分享一下为什么要关心贝叶斯深度学习。


贝叶斯深度学习


现在深度学习在各个领域里有很多用处。虽然 Deep Learning 非常好,但还不足够好。我们看一下大家都很熟知的 Deep Learning 还存在的两个问题:



一个问题是(深度学习)可能不是很鲁棒。可能会存在这种所谓的对抗样本,这有一个简单的例子,比如你有一个建筑物的图片,你可以用一个训练很好的神经网络分类得很准确。但是,我们如果加一些噪声,这些噪声可能是人检测不到的,合成一个图片之后却可以完全误导这个网络,甚至能够按照你的意愿误导分到某一个类。这是非常不好的性质,尤其是当我们在关键领域用深度学习的时候——一旦遇到这种情况,可能就会有一些比较致命性的错误发生。所以我们就想提出一个问题:机器学习或者深度学习本身能不能像人一样犯错误?人可能更多的时候是更鲁棒的,人可能会犯错误,但是人犯的错误相对都是比较直观、比较合理一点的——可能有某种道理在里面。


另外一个问题是深度学习大部分情况下都被我们当成一个黑箱。所以现在有很多的工作,包括我们自己的工作,都是试图去解释深度学习学到了什么。这里我们列了一个去年做的 CNNVis 的工作,能展示卷积网络每一层是什么、层和层之间是怎么关联的。这个方法非常受欢迎,也从一个侧面说明了大家对这个问题关心的程度。


在我看来,Deep Learning 本身属于机器学习的一个极端,它用了大量的训练样本,用了大量的计算资源。结果是我们在很多任务下,在特定环境、特定数据集上可以得到非常高的准确度,当然背后也有我们对网络结构的人为调整。


另外一极端是贝叶斯的学习方法,大家可能知道,2015 年的时候,在 AlphaGo 火之前,Science 有一篇文章就说怎么设计贝叶斯程序,在这种情况下可以用少量的训练样本帮助我们学非常精确的模型,当时展示的成果是这个贝叶斯程序可以(在手写体数字生成和识别任务上)通过视觉图灵测试。这从一个方面告诉我们:我们做学习的时候可以有不同的思路。



这是学习范式的两个极端,两者之间就有很多的事情可以做。我们把中间称之为「贝叶斯深度学习(Bayesian Deep Learning)」。它既有贝叶斯本身的可解释性,可以从少量的数据里边来学习;另外又有 Deep Learning 非常强大的拟合能力。



给大家看一个最近非常火的例子,叫深度生成模型(Deep Generative Models),这是典型的融合了深度学习和贝叶斯方法的模型。这里做了一个抽象:上面有一个隐含的变量,用 Z 表示;中间会经过一个深度神经网络,你可以根据你的任务选择不同的神经网络、不同的深度、不同的结构;下面是我们观察到的数据 X。这个场景有很多,比如对抗生成网络,可以生成高维的自然图片。实际上,Z 可以是非常随机的噪声,通过神经网络可以生成非常高质量的图片。


在这种框架下,我们可以做很多。比如可以给隐含变量设定某些结构信息,比如生成人脸时,有一些变量指代人的姿态,另外一些变量可能描述其他的特征,这两个放在一起我们就可以构建这样一个深度生成模型。


同一列的变量有同一姿态,可以变化其它变量来生成不同的图片。现在是非常受欢迎、非常强大的一种模型了。


下面用更形式化的方式进行描述。我们用概率模型来描述,比如对 Z 变量(隐含变量),我们会用 P(Z) 来描述它的先验分布;中间有一个参数化的神经网络做变换;最后生成我们想要的数据 X。在不同场景下,这个 Z 的含义可能不一样。比如:如果要生成医学图片,我们通常希望 Z 能够表达造成疾病的原因;而对于文本图片,我们可能希望理解背后的主题等等。



这个模型其实非常直观,但是它的难点在于我们所谓的 Inference(推断),这个过程是反向过来的——在 Inference 过程中,观察一些 X,然后我们用一些推导工具推导出我们观察到的 Z 到底是什么。在这个过程中,我们要用到一个主要的公式——贝叶斯公式。


珠算



那么珠算平台到底是起到什么作用呢?


我们都知道有很多公开的框架可以支持深度学习进行非常迅速的开发和原型设计,但目前还并没有很好的平台能支持贝叶斯深度学习。所以,我们构建了称之为珠算的平台。珠算平台可以支持我们进行深度学习,也可以支持贝叶斯推断,当然还可以是两者之间有机的融合。


大家知道,珠算或算盘是最古老的计算机器(calculating machine),被认为是中国历史上第五大发明。我们之所以取名为「珠算」,就是希望这个平台能够从某种意义上给传统算盘一种新的解释,同时还希望这个平台能够进行高效的计算。



珠算是一个生成模型的 Python 库,构建于 TensorFlow 之上。珠算不像现有的主要是为监督学习而设计的深度学习库,它是一种扎根于贝叶斯推断并支持多种生成模型的软件库。珠算区别于其他平台的一个很大的特点,即可以深度地做贝叶斯推断,因此,也就可以很有效地支持深度生成模型。珠算平台可以在 GPU 上训练神经网络,同时我们可以在上面做概率建模和概率推断,带来好处有:可以利用无监督数据、可以做小样本学习、可以做不确定性的推理和决策、可以生成新的样本等等。



为了做珠算平台,第一步是个抽象过程,需要把一类的模型能够抽象表达出来,在这里我们用贝叶斯网络。贝叶斯网络是在深度学习流行之前非常主流的方法,它是一种非常好的形式化方式,能非常直观地刻画模型。但是,与传统的贝叶斯网络不同,我们是深度融合了贝叶斯方法和深度神经网络的优点,因此,我们的贝叶斯网络有两类节点:随机的节点和确定性的节点。


确定性的节点基本上对应了深度神经网络的非线性变换,而随机节点可以描述不确定性。珠算是完全支持这两种节点的。在确定性的节点上我们把 TensorFlow 的所有操作都继承了下来。我们可以像在 TensorFlow 上构建神经网络一样构建中间的一些模块。如上图所示,构建一个模型很直观。我们首先只需要初始化 BayesianNet 环境,然后按照直观写模型。



这是一个具体的例子,如上图所示,我们需要生成手写体字符,这种情况下因为数据不是很高维,用简单的生成模型就够了,比如有一个 Z 变量,Z 是随机的,经过两层的全连接的神经网络,最后生成我们的 X,这种模型在珠算里面非常容易写。


可以在初始化BayesianNet环境之后,就沿着箭头的方向来写。比如:我们说 Z 变量服从一个高斯分布(z = zs.Normal()),珠算平台中有正态分布函数可以刻画该分布。接下来是两层的全连接层(layers.fully_connected()),最下面是数据的生成,比如我们数据是二值的,那么可以用伯努力随机分布来刻画它,这是非常直观地写模型的框架。你可以根据自己的需要书写其他的生成模型。



对于这种模型最难的实际上是推断部分,在机器学习里有两类的推断方法,一种是变分(Variational)方法,一种是蒙特卡罗模拟方法。对于变分方法来说,红色的点是我们的目标,在某个概率分布空间里面,但我们并不能直接计算。所以,变分方法主要是希望在某个简化的子集里找一个蓝色的点去逼近它,我们希望这个逼近是最优的,所以通常情况下要解决最优化问题。这里边有很多推导公并没有提到。对于 MCMC 方法来说,现在主流的解决方法是构造一些动力学方程,以达到模拟的效果,这里也隐含了很多技术细节。


因此,即使是非常简单的模型,如果要做推断都可能需要很多的数学推导,我们需要算梯度、调步长参数等等。而且很多步骤可能都会使我们犯错误,所以这是一个复杂的过程。而珠算要做的就是简化推导实现的过程,并用一个非常简洁的(概率)编程方式写出来,编程对计算机来说是最容易理解的。



给大家两个例子看我们怎么通过珠算实现推断的。首先,比如我们要做一个变分推断,在珠算上变分推断只需要三步:第一步,我们要构造一个变分分布,这个变分分布就像我前面讲的生成模型一样,可以通过初始化一个 BayesianNet,然后非常直观地写每部分是确定性的还是随机的等等。第二步,可以调用一下变分目标(variational objective),比如 z.sgvb,珠算上实现了不同的变分目标。剩下的事情,就是使用梯度下降进行迭代,就像我们实现深度神经网络一样,不断地使用随机梯度下降进行迭代而达到优化,这是典型变分推断的实现。



如果我们要做的是 HMC,HMC 是一个混合的蒙特卡罗方法或者哈密尔顿蒙特卡罗方法,这属于机器学习里面的一种十分优秀的算法,它可以处理高维空间里面的采样,该算法在珠算上也非常容易来实现。我们首先需要构建变量以储存样本,然后就可以初始化 HMC 采样器。接下来调用 sample() 函数就可以得到一个采样算子,随后的在不断运行样本迭代时,就像求解一个最优化算法一样。如果大家熟悉深度神经网络过程的话,基本上我们对这种贝叶斯神经网络可以完全对等地去实现。


贝叶斯深度学习怎么用?



贝叶斯深度学习在什么地方可以用?我给大家看一些例子。在我们课题组里主要强调如何用非常少的标注数据进行有效的学习。在机器学习里边有一个大家研究很多的叫半监督学习(Semi—supervised Learning)的场景,它可以利用少量标注数据帮助大量的未标注数据从中学习分类器。技术细节我就不说了,来看看结果。这个红色框里面是我们做出的结果,比如说在 SVHN 的数据集上,我们大概用 1% 的训练数据就可以达到 5% 的错误率,这个是目前最佳的结果。


因为我们是一个生成模型,所以我们还可以去生成新的样本,比如说我们可以生成二维的手写体字符。在一维上固定一个变量,调另外一个变量,生成你想要的某个类别或者某种风格的字符。



这是更新的工作,我们是在生成对抗网络(GAN)上做的。大家知道 GAN,它的生成效果很好了。我们在小样本的学习下面也可以做非常好的效果,我们提出了一个 Triple GAN 的工作。在这个自然图片的数据集上,比之前大家做的各种 GAN 变种的结果显著要好(错误率更低)。大家同时可以看出来,这个生成结果和自然图片也非常接近了。



下面一个例子是我前面提到过的——用贝叶斯方法做小样本学习。这是一个极端的例子,就是在训练的时候给它看一些基本的数据,将来在测试的时候会遇到新的类别(或概念),我们只给它看一个训练样例,然后希望它能够从中学出来一个贝叶斯程序,可以生成同一类的数据或者做识别。我们现在有一些在汉字上做的初步结果。给大家看一些例子,比如最上面给出了某一种字的一个样例,下面是生成出来的;基本上,大家能看出来和原始给的那个字的风格还是非常一致的,所以这个效果还是非常好的。一些技术细节我在这里就不详细说了。



最后一个例子也是我前面讲的鲁棒的 Deep Learning。Deep Learning 有很多潜在攻击样本,我怎么让它变得更鲁棒?实际上,最近有一些工作显示使用贝叶斯推理可以让深度神经网络变得更鲁棒,比如:剑桥做的一个工作,这是我们复现出来的在一个数据集上的比较。这个测试数据集有一半是攻击样本、一半是正常样本。这个黑色的线是一个标准的神经网络,不用贝叶斯推理,它的正确率从 0.9 几(可能 0.97、0.98)一下子降到 0.6 几,降得非常严重。蓝色的线是贝叶斯神经网络,它可以做到更好,可以达到 75%、80% 左右的正确率,已经是非常不错的。


右边的图是说你可以过滤掉多少对抗样本。大家可以看出来,这个蓝色的线,用贝叶斯网络可以帮助我们更好地识别对抗样本,提升鲁棒性。我们最近做了一个工作,结果是红色的线,能够更显著地识别 adversarial sample 和 normal sample,两个混在一起的时候,测试准确度能够显著地提升,实际上我们可以在一定条件下 达到图中的 Normal Accuracy。



我们已经开源了珠算平台,现在我们把它当作是一个研究平台,也欢迎大家去尝试。我们在上面也开发了很多当前最佳的模型,包括经典的贝叶斯 logistic 回归、最新的贝叶斯神经网络、变分自编码器、GAN、主题模型 等等,我们自己也在不断做一些新模型。下面是开源的页面,大家可以在 GitHub 上找到。我们也写了一些 Online Documents,解释 API 怎么定义的,另外还有教程可以指导大家很快来实现比如我前面举例的网络模型。



特别感谢我们组的学生,这个项目主要是我的两个博士生 Jiaxin Shi(石佳欣)和 Jianfei Chen(陈键飞)主导的,贡献者还包括一些博士后和博士生以及本科生。这个项目也得到一些国家经费的支持,我们的合作者还有天工研究院、英伟达等等。



我的报告就到这里。谢谢大家!


为保证发文质量、树立口碑,数据派现设立“错别字基金”,鼓励读者积极纠错

若您在阅读文章过程中发现任何错误,请在文末留言,或到后台反馈,经小编确认后,数据派将向检举读者发8.8元红包

同一位读者指出同一篇文章多处错误,奖金不变。不同读者指出同一处错误,奖励第一位读者。

感谢一直以来您的关注和支持,希望您能够监督数据派产出更加高质的内容。

公众号底部菜单有惊喜哦!

企业,个人加入组织请查看“联合会”

往期精彩内容请查看“号内搜”

加入志愿者或联系我们请查看“关于我们”