专栏名称: TensorFlow
Google官方账号,分享人工智能和TensorFlow相关的最新消息、技术资源、活动和实践案例。联系我们:[email protected]
目录
相关文章推荐
新智元  ·  乙巳蛇年 新智元十年 追梦ASI时代 ·  16 小时前  
爱可可-爱生活  ·  【为什么对我们来说对话更容易】《Why ... ·  3 天前  
人工智能那点事  ·  值得刷屏!无人机“老妈式”喊话,安全感爆棚! ·  3 天前  
机器之心  ·  物理测试暴击AI圈,DeepSeek ... ·  3 天前  
51好读  ›  专栏  ›  TensorFlow

TensorFlow Lite 设备端训练

TensorFlow  · 公众号  · AI  · 2021-12-20 18:59

正文


发布人:TensorFlow Lite 团队


TensorFlow Lite 是 Google 的机器学习框架,用于在多种设备和平台上部署机器学习模型,例如移动设备(iOS 和 Android)、桌面设备和其他边缘设备。

  • TensorFlow Lite

    https://www.tensorflow.google.cn/lite


最近,我们又添加了在浏览器中运行 TensorFlow Lite 模型的支持。要使用 TensorFlow Lite 构建应用,您可以利用 TensorFlow Hub 中的现成模型,或者使用转换器将现有的 TensorFlow 模型转换为 TensorFlow Lite 模型。

  • 构建应用

    https://www.tensorflow.google.cn/lite/guide#development_workflow

  • TensorFlow Hub

    https://hub.tensorflow.google.cn/s?deployment-format=lite

  • 转换器

    https://tensorflow.google.cn/lite/convert/index

模型部署到应用中后,您可以基于输入数据在该模型上运行推理

  • 运行推理

    https://tensorflow.google.cn/lite/guide#2_run_inference


除运行推理外,TensorFlow Lite 现在还支持在设备端训练模型。设备端训练支持有趣的个性化用例,其中模型可以根据用户需求进行微调。例如,您可以部署一个图像分类模型,允许用户使用迁移学习对模型进行微调来识别鸟类,同时允许其他用户重新训练该模型来识别水果。这项新功能在 TensorFlow 2.7 及以上版本中提供,现在可用于 Android 应用,并会在未来增加对 iOS 的支持。

  • 迁移学习

    https://developers.google.com/machine-learning/glossary#transfer-learning


设备端训练也是根据分散式数据训练全局模型的联合学习用例的必要基础。本文文章不会涉及到联合学习,而是侧重帮助您在 Android 应用中集成设备端训练。

  • 联合

    https://ai.googleblog.com/2017/04/federated-learning-collaborative.html


本文后半部分,我们将参考 ColabAndroid 示例应用,向您介绍设备端学习的端到端实现路径,引导您完成图像分类模型的微调。

  • Colab

    https://tensorflow.google.cn/lite/examples/on_device_training/overview

  • Android 示例应用

    https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization


对早期方法的改进


我们在 2019 年的文章中介绍了设备端训练的概念,并展示了一个在 TensorFlow Lite 中进行设备端训练的示例。但是,当时存在几个限制。比如,自定义模型结构和优化器并不容易。您还必须处理多个物理 TensorFlow Lite (.tflite) 模型,而不是单个 TensorFlow Lite 模型。同样,存储和更新训练权重也没有简单的方法。我们最新的 TensorFlow Lite 版本提供更便捷的设备端训练选项,简化了这个过程,接下来就给大家介绍一下。

  • 文章

    https://blog.tensorflow.google.cn/2019/12/example-on-device-model-personalization.html


它是怎样实现的呢?


要部署内置设备端训练的 TensorFlow Lite 模型,简要步骤如下:


  • 构建用于训练和推理的 TensorFlow 模型

  • 将 TensorFlow 模型转换为 TensorFlow Lite 格式

  • 将模型集成到您的 Android 应用中

  • 在应用中调用模型训练,与调用模型推理的方式类似


具体步骤如下。


构建用于训练和推理的 TensorFlow 模型


TensorFlow Lite 模型应当同时支持模型推理和模型训练,训练通常涉及将模型的权重保存到文件系统,并从文件系统中恢复权重。这样做是为了在每个训练周期结束后保存训练权重,以便下个训练周期可以使用前一个周期的权重,而不是从头开始训练。


  • 一个使用训练数据训练模型的 train 函数。如下的 train 函数进行预测,计算损失(或误差),使用 tf.GradientTape() 记录自动微分的操作并更新模型的参数。

  • train

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1

  • 自动微分

    https://tensorflow.google.cn/guide/autodiff#automatic_differentiation_and_gradients


# The `train` function takes a batch of input images and labels.
@tf.function(input_signature=[
tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
tf.TensorSpec([None, 10], tf.float32),
])
def train(self, x, y):
with tf.GradientTape() as tape:
prediction = self.model(x)
loss = self._LOSS_FN(prediction, y)
gradients = tape.gradient(loss, self.model.trainable_variables)
self._OPTIM.apply_gradients(
zip(gradients, self.model.trainable_variables))
result = {"loss": loss}
for grad in gradients:
result[grad.name] = grad
return result


  • 一个调用模型推理的 infer 函数或 predict 函数。这和您目前使用 TensorFlow Lite 进行推理的方法类似。

  • infer

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1


@tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
def predict(self, x):
return {
"output": self.model(x)
}


  • 一个 save/restore 函数,将训练权重(即模型使用的参数)以 Checkpoints 格式保存到文件系统。该 save 函数的代码如下所示。

  • save/restore

    https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb#scrollTo=d8577c80&line=38&uniqifier=1


@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
def save(self, checkpoint_path):
tensor_names = [weight.name for weight in self.model.weights]
tensors_to_save = [weight.read_value() for weight in self.model.weights]
tf.raw_ops.Save(
filename=checkpoint_path, tensor_names=tensor_names,
data=tensors_to_save, name='save')
return {
"checkpoint_path": checkpoint_path
}


转换为 TensorFlow Lite 格式


您可能已经熟悉将 TensorFlow 模型转换为 TensorFlow Lite 格式的工作流。设备端训练的一些低级功能(例如,存储模型参数的变量)仍处于实验阶段,而其他(例如,权重序列化)目前依赖于 TF Select 运算符,因此您需要在转换过程中设置这些标志。您可以在 Colab 中找到所有需要设置标志的示例。

  • 转换

    https://tensorflow.google.cn/lite/convert

  • TF Select

    https://tensorflow.google.cn/lite/guide/ops_select

  • Colab

    https://www.tensorflow.org/lite/examples/on_device_training/overview


# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()


将模型集成到您的 Android 应用中


将模型转换为 TensorFlow Lite 格式后,您就可以将模型集成到应用中了!更多详细信息,请参阅 Android 应用示例。

  • Android

    https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization


在应用中调用模型训练和推理


在 Android 中,可以使用 Java 或 C++ API 执行 TensorFlow Lite 设备端训练。您可以创建一个 TensorFlow Lite Interpreter 的实例来加载模型和驱动模型训练任务。我们先前已经定义了多个 tf.functions:可以使用 TensorFlow Lite 对签名的支持来调用这些函数,签名允许单个 TensorFlow Lite 模型支持多个“入口”点。例如,我们为设备端训练定义了一个 train 函数, 这是模型的其中一个签名。通过指定签名的名称 (“train”)使用 TensorFlow Lite 的 runSignature 方法,即可调用 train 函数:

  • Interpreter

    https://tensorflow.google.cn/lite/guide/inference#load_and_run_a_model_in_java

  • 签名

    https://tensorflow.google.cn/lite/guide/signatures


// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
Mapinputs = new HashMap<>>();
inputs.put("x", trainImageBatches.get(batchIdx));
inputs.put("y", trainLabelBatches.get(batchIdx));

Mapoutputs = new HashMap<>();
FloatBuffer loss = FloatBuffer.allocate(1);
outputs.put("loss", loss);

interpreter.runSignature(inputs, outputs, "train");

// Record the last loss.
if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
}
}


同样,下面的示例展示了如何使用模型的“infer”签名调用推理函数:

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
// Restore the weights from the checkpoint file.

int NUM_TESTS = 10;
FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

// Fill the test data.

// Run the inference.
Mapinputs = new HashMap<>>();
inputs.put("x", testImages.rewind());
Mapoutputs = new HashMap<>();
outputs.put("output", output);
anotherInterpreter.runSignature(inputs, outputs, "infer");
output.rewind();

// Process the result to get the final category values.
int[] testLabels = new int[NUM_TESTS];
for (int i = 0; i < NUM_TESTS; ++i) {
int index = 0;
for (int j = 1; j < 10; ++j) {
if (output.get(i * 10 + index) < output.get(i * 10 + j))
index = testLabels[j];
}
testLabels[i] = index;
}
}


就这么简单!现在您拥有了一个可以使用设备端训练的 TensorFlow Lite 模型。我们希望此代码演示能让您充分了解如何在 TensorFlow Lite 中运行设备端训练,我们很期待看到您的实际成果。


实际使用注意事项


理论上,您应该能将 TensorFlow Lite 中的设备端训练应用于 TensorFlow 支持的任何用例。但实际上,在应用中部署设备端训练前,您需要牢记一些实际使用注意事项:


  • 用例:Colab 示例展示了视觉用例的设备端训练示例。如果您在特定模型或用例方面遇到问题,请在 GitHub 上告诉我们。

  • GitHub

    https://github.com/tensorflow/tensorflow/issues?q=is:open+is:issue+label:comp:lite


  • 性能:根据用例的不同,设备端训练可能需要几秒钟或更长时间。如果运行的设备端训练属于面向用户的功能(例如,您的最终用户正在与该功能互动),您应该计算应用中各种可能的训练输入所花费的时间,以限制训练时间。如果您的用例需要的设备端训练时间很长,请考虑先使用桌面设备或在云端训练模型,然后在设备端进行微调。


  • 电池用量:就像模型推理一样,在设备上调用模型训练可能会导致电池耗尽。如果模型训练属于不面向用户的功能,我们建议遵循 Android 的指南,在后台执行任务。

  • 指南

    https://developer.android.com/guide/background#recommended-solutions


  • 从头开始训练对比再训练:理论上,可以使用上述功能在设备上从头开始训练模型。但实际上,从头开始训练需要大量训练数据,而且即便使用处理器强大的服务器,也要花费几天时间。因此,对于设备端应用,我们建议在已经训练过的模型上再训练(即迁移学习),如 Colab 示例所示。

  • 迁移学习

    https://developers.google.com/machine-learning/glossary#t


路线图


后续工作包括(但不限于)iOS 的设备端训练支持,改进性能以利用设备端加速器(例如 GPU)进行设备端训练,通过在 TensorFlow Lite 中原生实现更多训练算子来降低二进制文件大小,实现更高级别的 API 支持(例如通过 TensorFlow Lite Task Library),以抽象出涵盖其他设备端训练用例(例如 NLP)的实现细节和示例。我们的长期路线图可能涉及提供设备端端到端联合学习解决方案。

  • Task Library

    https://tensorflow.google.cn/lite/inference_with_metadata/task_library/overview


未来计划


感谢您的阅读!我们十分期待看到您使用设备端学习构建的内容。再次提醒,此处是示例应用和 Colab 的链接。如果您有任何反馈,请在 TensorFlow 论坛 或 GitHub 上告诉我们。

  • 示例

    https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization

  • Colab

    https://tensorflow.google.cn/lite/examples/on_device_training/overview#train_the_tensorflow_lite_model

  • TensorFlow 论坛

    https://discuss.tensorflow.google.cn/


致谢


这篇文章包含 Google TensorFlow Lite 团队众多成员(包括 Michelle Carney、Lawrence Chan、Jaesung Chung、Jared Duke、Terry Heo、Jared Lim、Yu-Cheng Ling、Thai Nguyen、Karim Nosseir、Arun Venkatesan、Haoliang Zhang)、其他 TensorFlow Lite 团队成员,以及我们 Google Research 协作者的重要贡献。


点击“阅读原文”访问 TensorFlow 官网



不要忘记“一键三连”哦~

分享

点赞

在看