专栏名称: 机器之心
目录
相关文章推荐
机器之心  ·  DeepSeek一口气开源3个项目,还有梁文 ... ·  昨天  
AI前线  ·  民间大神魔改4090 ... ·  昨天  
财联社AI daily  ·  阿里扔“王炸”! ·  昨天  
财联社AI daily  ·  阿里扔“王炸”! ·  昨天  
爱可可-爱生活  ·  本文创新性地提出了 MinionS ... ·  2 天前  
爱可可-爱生活  ·  突破性的“一步扩散”生成模型 查看图片 ... ·  3 天前  
51好读  ›  专栏  ›  机器之心

如何在Tensorflow.js中处理MNIST图像数据

机器之心  · 掘金  · AI  · 2018-06-26 06:07

正文

如何在Tensorflow.js中处理MNIST图像数据

选自freeCodeCamp

作者:Kevin Scott

机器之心编译

参与:李诗萌、路

数据清理是数据科学和机器学习中的重要组成部分,本文介绍了如何在 Tensorflow.js(0.11.1)中处理 MNIST 图像数据,并逐行解释代码。


有人开玩笑说有 80% 的数据科学家在清理数据,剩下的 20% 在抱怨清理数据……在数据科学工作中,清理数据所占比例比外人想象的要多得多。一般而言,训练模型通常只占机器学习或数据科学家工作的一小部分(少于 10%)。
——Kaggle CEO Antony Goldbloom

对任何一个机器学习问题而言,数据处理都是很重要的一步。本文将采用 Tensorflow.js(0.11.1)的 MNIST 样例( github.com/tensorflow/… ),逐行运行数据处理的代码。

MNIST 样例

18 import * as tf from '@tensorflow/tfjs';
19
20 const IMAGE_SIZE = 784;
21 const NUM_CLASSES = 10;
22 const NUM_DATASET_ELEMENTS = 65000;
23
24 const NUM_TRAIN_ELEMENTS = 55000;
25 const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
26
27 const MNIST_IMAGES_SPRITE_PATH =
28 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
29 const MNIST_LABELS_PATH =
30 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';`

首先,导入 TensorFlow(确保你在转译代码)并建立一些常量,包括:

  • IMAGE_SIZE:图像尺寸(28*28=784)
  • NUM_CLASSES:标签类别的数量(这个数字可以是 0~9,所以这里有 10 类)
  • NUM_DATASET_ELEMENTS:图像总数量(65000)
  • NUM_TRAIN_ELEMENTS:训练集中图像的数量(55000)
  • NUM_TEST_ELEMENTS:测试集中图像的数量(10000,亦称余数)
  • MNIST_IMAGES_SPRITE_PATH&MNIST_LABELS_PATH:图像和标签的路径

将这些图像级联为一个巨大的图像,如下图所示:

MNISTData

接下来,从第 38 行开始是 MnistData,该类别使用以下函数:

  • load:负责异步加载图像和标注数据;
  • nextTrainBatch:加载下一个训练批;
  • nextTestBatch:加载下一个测试批;
  • nextBatch:返回下一个批的通用函数,该函数的使用取决于是在训练集还是测试集。

本文属于入门文章,因此只采用 load 函数。

load

async load() {
 // Make a request for the MNIST sprited image.
 const img = new Image();
 const canvas = document.createElement('canvas');
 const ctx = canvas.getContext('2d');

异步函数(async)是 Javascript 中相对较新的语言功能,因此你需要一个转译器。

Image 对象是表示内存中图像的本地 DOM 函数,在图像加载时提供可访问图像属性的回调。canvas 是 DOM 的另一个元素,该元素可以提供访问像素数组的简单方式,还可以通过上下文对其进行处理。

因为这两个都是 DOM 元素,所以如果用 Node.js(或 Web Worker)则无需访问这些元素。有关其他可替代的方法,请参见下文。

imgRequest

const imgRequest = new Promise((resolve, reject) => {
 img.crossOrigin = '';
 img.onload = () => {
 img.width = img.naturalWidth;
 img.height = img.naturalHeight;

该代码初始化了一个 new promise,图像加载成功后该 promise 结束。该示例没有明确处理误差状态。

crossOrigin 是一个允许跨域加载图像并可以在与 DOM 交互时解决 CORS(跨源资源共享,cross-origin resource sharing)问题的图像属性。naturalWidth 和 naturalHeight 指加载图像的原始维度,在计算时可以强制校正图像尺寸。

 const datasetBytesBuffer =
 new ArrayBuffer(NUMDATASETELEMENTS * IMAGESIZE * 4);
57
58 const chunkSize = 5000;
59 canvas.width = img.width;
60 canvas.height = chunkSize;

该代码初始化了一个新的 buffer,包含每一张图的每一个像素。它将图像总数和每张图像的尺寸和通道数量相乘。

我认为 chunkSize 的用处在于防止 UI 一次将太多数据加载到内存中,但并不能 100% 确定。

62 for (let i = 0; i < NUMDATASETELEMENTS / chunkSize; i++) {
63 const datasetBytesView = new Float32Array(
64 datasetBytesBuffer, i * IMAGESIZE * chunkSize * 4,
 IMAGESIZE * chunkSize);
66 ctx.drawImage(
67 img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
68 chunkSize);
69
70 const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

该代码遍历了每一张 sprite 图像,并为该迭代初始化了一个新的 TypedArray。接下来,上下文图像获取了一个绘制出来的图像块。最终,使用上下文的 getImageData 函数将绘制出来的图像转换为图像数据,返回的是一个表示底层像素数据的对象。







请到「今天看啥」查看全文