专栏名称: 机器之心
目录
相关文章推荐
机器之心  ·  DeepSeek一口气开源3个项目,还有梁文 ... ·  昨天  
AI前线  ·  民间大神魔改4090 ... ·  昨天  
财联社AI daily  ·  阿里扔“王炸”! ·  昨天  
财联社AI daily  ·  阿里扔“王炸”! ·  昨天  
爱可可-爱生活  ·  本文创新性地提出了 MinionS ... ·  2 天前  
爱可可-爱生活  ·  突破性的“一步扩散”生成模型 查看图片 ... ·  3 天前  
51好读  ›  专栏  ›  机器之心

生产级深度学习的开发经验分享:数据集的构建和提升是关键

机器之心  · 掘金  · AI  · 2018-06-14 07:04

正文

生产级深度学习的开发经验分享:数据集的构建和提升是关键

选自Pete Warden's Blog,作者:Pete Warden,机器之心编译。

深度学习的研究和生产之间存在较大差异,在学术研究中,人们一般更重视模型架构的设计,并使用较小规模的数据集。本文从生产层面强调了深度学习项目开发中需要更加重视数据集的构建,并以作者本人的亲身开发经验为例子,分享了几个简单实用的建议,涉及了数据集特性、迁移学习、指标以及可视化分析等层面。无论是对于研究者还是开发者,这些建议都有一定的参考价值。

本文还得到了 Andrej Karpathy 的转发:

作者简介:Pete Warden 是 Jetpac Inc 的 CTO,著有《The Public Data Handbook》和《The Big Data Glossary》两本 O'Reilly 出版的书,并参与建立了多个开源项目,例如 OpenHeatMap 和 Data Science Toolkit 等。

图片来源:Lisha Li

Andrej Karpathy 在 Train AI( www.figure-eight.com/train-ai/ )进行演讲时展示了这张幻灯片,我非常喜欢它!它完美地展现了深度学习的研究与实际的生产之间的差异。学术论文大多仅仅使用公开数据中的一小部分作为数据集而关注创造和改进模型。然而据我所知,当人们开始在实际的应用中使用机器学习时,对于训练数据的担忧占据了他们的大部分时间。

有很多很好的理由可以用来解释为什么研究人员如此关注于模型的架构,但这也确实意味着,对那些专注于将机器学习应用于生产环境中的人员来说,他们可以获取到的相关资源是很少的。为了解决这个问题,我在会议上进行了关于「the unreasonable effectiveness of training data」的演讲,而在这篇博客中,我想进一步阐述为什么数据如此重要以及改进它的一些实用技巧。

作为我工作的一部分,我与很多研究人员还有产品团队之间进行了密切的合作。我看到当他们专注于模型构建这一角度时可以获得很好的效果,而这也让我笃信于改进数据的威力。将深度学习应用到大多数应用中的最大障碍是如何在现实世界中获得足够高的准确率,而据我所知,提高准确度的最快途径就是改进训练集。即使你在其他限制(如延迟或存储空间)上遇到了阻碍,在特定的模型上提高准确率也可以帮助你通过使用规模较小的架构来对这些性能指标做出权衡。


语音数据集

我无法将我对于生产性系统的大部分观察分享给大家,但我有一个开源的例子可以用来阐释相同的模式。去年,我为 TensorFlow 创建了一个简单的语音识别示例( www.tensorflow.org/tutorials/a… ),结果表明在现有的数据集中,没有哪一个是可以很容易地被用作训练数据的。多亏了由 AIY 团队帮助我开发的开放式语音记录站点( aiyprojects.withgoogle.com/open_speech… )我才得以在很多志愿者的慷慨帮助下,收集到了 6 万个记录了人们说短单词的一秒钟音频片段。在这一数据训练下的模型虽然可以使用,但仍然没有达到我想要的准确度。为了了解我设计模型时可能存在的局限性,我用相同的数据集发起了一个 Kaggle 比赛( www.kaggle.com/c/tensorflo… )。参赛者的表现比我的简单模型要好得多,但即使有很多不同的方法,多个团队的精确度最终都仅仅达到了 91%左右。对我而言,这意味着数据本身存在着根本性的问题,而实际上参赛者们也的确发现了很多问题,比如不正确的标签或被截断过的音频。这些都激励着我去解决他们发现的问题并且增加这个数据集的样本数量。

我查看了错误度量标准,以了解模型最常遇到的问题,结果发现「其他」类别(当语音被识别出来,但这些单词不在模型有限的词汇表内时)更容易发生错误。为了解决这个问题,我增加了我们捕获的不同单词的数量,以提供更加多样的训练数据。

由于 Kaggle 参赛者报告了标签错误,我通过众包的形式增加了一个额外的验证过程:要求人们倾听每个片段并确保其与预期标签相符。由于 Kaggle 竞赛中还发现了一些几乎无声或被截断的文件,我还编写了一个实用的程序来进行一些简单的音频分析( github.com/petewarden/… ),并自动清除特别糟糕的样本。最后,尽管删除了错误的文件,但由于更多志愿者和一些付费的众包服务人员的努力,我们最终获得了超过 10 万的发言样本。

为了帮助他人使用数据集(并从我的错误中吸取教训!)我将所有相关内容以及最新的结果写入了一篇 arXiv 论文( arxiv.org/abs/1804.03… )。其中最重要的结论是,在不改变模型或测试数据的情况下,(通过改进数据)我们可以将 top-1 准确率从 85.4% 提高到 89.7%。这是一个巨大的提升(超过了 4%),并且当人们在安卓或树莓派的样例程序中使用该模型时,获得了更好的效果。尽管目前我使用的远非最优的模型,但我确信如果我将这些时间花费在调整模型上,我将无法获得这样的性能提升。

在生产的配置过程中,我多次见证了上述这样的性能提升。当你想要做同样的事情的时候,可能很难知道应该从哪里开始。你可以从我处理语音数据的技巧中得到一些灵感,但在接下来的内容中,我将为你介绍一些我认为有用的具体的方法。


首先,观察你的数据

这看起来显而易见,但你首先最应该做的是随机浏览你将要使用的训练数据。将一些文件复制到本地计算机上,然后花几个小时来预览它们。如果您正在处理图片,使用类似于 MacOS 的取景器的功能滚动浏览缩略图,将可以让你快速地浏览数千个图片。对于音频,你可以使用取景器播放预览,或者将文本随机片段转储到终端。正因为我没有花费足够的时间来对第一版语音命令进行上述处理,Kaggle 参赛者们才会在开始处理数据时发现了很多问题。

我总是觉得这个过程有点愚蠢,但我从未后悔过。每当我完成这些工作时,我都可以发现一些对数据来说非常重要的事情,比如不同类别之间样本数量的失衡、数据乱码(例如扩展名标识为 JPG 的 PNG 文件)、错误的标签,或者仅仅是令人惊讶的组合。Tom White 在对 ImageNet 的检查中获得了许多惊人的发现,比如:标签「太阳镜」,实际上是指一种古老的用来放大阳光的设备。Andrej 对 ImageNet 进行手动分类的工作( karpathy.github.io/2014/09/02/… )同样教会了我很多与这个数据集相关的知识,包括如何分辨所有不同的犬种,甚至是人。

你将要采取的行动取决于你的发现,但是在你做任何其他数据清理工作之前,你都应该先进行这种检查,因为对数据集内容的直观了解有助于你在其余步骤中做出更好的决定。


快速地选择一个模型

不要在选择模型上花费太多时间。如果你正在进行图像分类任务,请查看 AutoML,或查看 TensorFlow 的模型存储库( github.com/tensorflow/… )或 Fast.AI 收集的样例( www.fast.ai/ )来找到你产品中面对的类似问题的模型。重要的是尽可能快地开始迭代,这样你就可以尽早且经常性地让实际用户来试用你的模型。你随时都可以上线改进的模型,并且可能会看到更好的结果,但你必须首先对数据进行合适的处理。深度学习仍然遵循「输入决定输出」的基本计算规律,所以即使是最好的模型也会受到训练集中数据缺陷的限制。通过选择模型并对其进行测试,你将能够理解这些缺陷从而开始改进数据。

为了进一步加快模型的迭代速度,你可以尝试从一个已经在大型现有数据集上预训练过的模型开始,使用迁移学习来利用你收集到的(可能小得多的)一组数据对它进行微调。这通常比仅在较小的数据集上进行训练的结果要好得多,而且速度更快,这样一来你就可以快速地了解到应该如何调整数据收集策略。最重要的是,你可以根据结果中的反馈调整数据收集(和处理)流程,以便适应你的学习策略,而不是仅仅在训练之前将数据收集作为单独的阶段进行。


在做到之前先假装做到(人工标注数据)

建立研究和生产模型最大的不同之处在于,研究通常在开始时就有了明确的问题定义,而实际应用的需求潜藏在用户的头脑中,并且只能随着时间的推移而逐渐获知。例如,对于 Jetpac,我们希望找到好的照片并展示在城市的自动旅行指南中。刚开始我们要求评分者给他们认为好的照片打上标签,但我们最终却得到了很多张笑脸,因为这就是他们对这个问题的理解。我们将这些内容放入产品的展示模型中,来测试用户的反应,结果发现这并没有给他们留下什么深刻的印象。为了解决这个问题,我们将问题修改为「这张照片是否让你想要前往它所展示的地方?」。这很大程度上提高了我们结果的质量,然而事实表明,来自东南亚的工作人员,更倾向于认为充满了在大型酒店中穿西装的人和酒杯的会议照片看起来令人惊叹。这种不匹配是对我们生活的泡沫的一个提醒,但它同时也是一个实际问题,因为我们产品的目标受众是美国人,他们看到会议照片会感到压抑和沮丧。最终,我们六个 Jetpac 团队的成员自己为超过 200 万张照片进行了评分,因为我们比任何可以被训练去做这件事的人都更清楚标准。

这是一个极端的例子,但它表明标注过程在很大程度上依赖于应用程序的要求。对于大多数生产用例来说,找出模型正确问题的正确答案需要花费很长的一段时间,而这对于正确地解决问题至关重要。如果你正在试图让模型回答错误的问题,那么将永远无法在这个不可靠的基础上建立可靠的用户体验。

图片来自 Thomas Hawk

我发现能够判断你所问的问题是否正确的唯一方法是对你的应用程序进行模拟,而不是使用有人参与迭代的机器学习模型。因为在背后有人类的参与,这种方法有时被称为「Wizard-of-Oz-ing」。在 Jetpac 的案例中,我们让人们为一些旅行指南样例手动选择照片,而不是训练一个通过测试用户的反馈来调整挑选图片的标准的模型。一旦我们可以很可靠地从测试中获得正面反馈,我们接下来就可以将我们设计的照片选择规则转化为标注指导手册,以便用这样的方法获得数百万个图像用作训练集。然后,我们使用这些数据训练出了能够预测数十亿张照片质量的模型,但它的 DNA 来自我们设计的原始的人工规则。


在真实数据上进行训练

在 Jetpac 案例中,我们用于训练模型的图像和我们希望应用模型的图像来源相同(主要是 Facebook 和 Instagram),但是我发现的一个常见问题是,训练数据集与模型最终输入数据的一些关键差异最终会体现在生产中。例如,我经常会看到基于 ImageNet 训练的模型在被尝试应用到无人机或机器人中时会遇到问题。这是因为 ImageNet 大多为人们拍摄的照片,而这些照片存在着很多共性,比如:用手机或照相机拍摄,使用中性镜头,大致在头部高度,在白天或人造光线下拍摄,标记的物体居中并位于前景中等等。而机器人和无人机使用视频摄像机,通常配有高视野镜头,拍摄位置要么是在地面上要么是在高空中,同时缺乏光照条件,并且由于没有对于物体轮廓的智能判定,通常只能进行裁剪。这些差异意味着,如果你只是在 ImageNet 上训练模型并将其部署到某一台设备上,那么将无法获得较好的准确率。

训练数据和最终模型输入数据的差异还可能体现在很多细微的地方。想象一下,你正在使用世界各地的动物数据集来训练一个识别野生动物的相机。如果你只打算将它部署在婆罗洲的丛林中,那么企鹅标签被选中的概率会特别低。如果训练数据中包含有南极的照片,那么模型将会有很大的机会将其他动物误认为是企鹅,因而模型整体的准确率会远比你不使用这部分训练数据时低。

有许多方法可以根据已知的先验知识(例如,在丛林环境中大幅度降低企鹅的概率)来校准结果,但使用能够反映产品真实场景的训练集会更加方便和有效。我发现最好的方法是始终使用从实际应用程序中直接捕获到的数据,这与我上面提到的「Wizard of Oz」方法之间存在很好的联系。这样一来,在训练过程中使用人来进行反馈的部分可以被数据的预先标注所替代,即使收集到的标签数量非常少,它们也可以反映真实的使用情况,并且也基本足够被用于进行迁移学习的一些初始实验了。


混淆矩阵

当我研究语音指令的例子时,我看到的最常见的报告之一是训练期间的混淆矩阵。这是一个显示在控制台中的例子:

[[258 0 0 0 0 0 0 0 0 0 0 0]
 [ 7 6 26 94 7 49 1 15 40 2 0 11]
 [ 10 1 107 80 13 22 0 13 10 1 0 4]
 [ 1 3 16 163 6 48 0 5 10 1 0 17]
 [ 15 1 17 114 55 13 0 9 22 5 0 9]
 [ 1 1 6 97 3 87 1 12 46 0 0 10]
 [ 8 6 86 84 13 24 1 9 9 1 0 6]
 [ 9 3 32 112 9 26 1 36 19 0 0 9]
 [ 8 2 12 94 9 52 0 6 72 0 0 2]
 [ 16 1 39 74 29 42 0 6 37 9 0 3]
 [ 15 6 17 71 50 37 0 6 32 2 1 9]
 [ 11 1 6 151 5 42 0 8 16 0 0 20]]

这可能看起来很吓人,但它实际上只是一个表格,显示网络出错的详细信息。这里有一个更加美观的带标签版本:

表中的每一行代表一组与真实标签相同的样本,每列显示标签预测结果的数量。例如,高亮显示的行表示所有无声的音频样本,如果你从左至右阅读,则可以发现标签预测的结果是正确的,因为每个标签都落在」Silence」一栏中。这表明,该模型可以很好地识无声的音频片段,不存在任何一个误判的情况。从列的角度来看,第一列显示有多少音频片段被预测为无声,我们可以看到一些实际上是单词的音频片段被误认为是无声的,这其中有很多误判。这些知识对我来说非常有用,因为它让我更加仔细地观察那些被误认为是无声的音频片段,而这些片段事实上并不总是安静的。这帮助我通过删除音量较低的音频片段来提高数据的质量,而如果没有混淆矩阵的线索,我将无从下手。

几乎所有对结果的总结都可能是有用的,但是我发现混淆矩阵是一个很好的折衷方案,它提供的信息比单个的准确率更多,同时也不会涵盖太多我无法处理的细节。在训练过程中观察数字变化也很有用,因为它可以告诉你模型正在努力学习什么类别,并可以让你在清理和扩充数据集时专注于某些方面。


可视化模型

可视化聚类是我最喜欢的用来理解我的网络如何解读训练数据的方式之一。TensorBoard 为这种探索提供了很好的支持,尽管它经常被用于查看词嵌入,但我发现它几乎适用于与任何嵌入有类似的工作方式的网络层。例如,图像分类网络在最后的全连接或 softmax 单元之前通常具有的倒数第二层,可以被用作嵌入(这就是简单的迁移学习示例的工作原理,如 TensorFlow for Poets( codelabs.developers.google.com/codelabs/te… ))。严格意义上来说,这些并不是嵌入,因为我们并没有在训练过程中努力确保在真正的嵌入具有希望的空间属性,但对它们的向量进行聚类确实会产生一些有趣的结果。

举例来说,之前一个同我合作过的团队对图像分类模型中某些动物的高错误率感到困惑。他们使用聚类可视化来查看他们的训练数据是如何分布到各种类别的,当他们看到「捷豹」时,他们清楚地发现数据被分成两个彼此之间存在一定间隔的不同的组。







请到「今天看啥」查看全文