TensorFlow Lite 中的设备上训练
2021 年 11 月 9 日

TensorFlow Lite 团队发布

TensorFlow Lite 是 Google 的机器学习框架,用于在多个设备和表面(如移动设备(iOS 和 Android)、台式机和其他边缘设备)上部署机器学习模型。最近,我们还添加了在浏览器中运行 TensorFlow Lite 模型的支持。为了使用 TensorFlow Lite 构建应用程序,您可以使用 TensorFlow Hub 中的现成模型,或者使用 转换器 将现有 TensorFlow 模型转换为 TensorFlow Lite 模型。模型部署到应用程序后,您可以根据输入数据 运行推理

除了运行推理之外,TensorFlow Lite 现在还支持在设备上训练您的模型。设备上训练支持有趣的个性化用例,模型可以在这些用例中根据用户需求进行微调。例如,您可以部署一个图像分类模型,并允许用户微调模型以使用 迁移学习 识别鸟类物种,同时允许其他用户重新训练同一模型以识别水果。此新功能在 TensorFlow 2.7 及更高版本中可用,目前可用于 Android 应用程序。(iOS 支持将在未来添加。)

设备上训练也是 联邦 学习用例在分散数据上训练全局模型的必要基础。本博文不涵盖联邦学习,而是专注于帮助您将设备上训练集成到您的 Android 应用程序中。

在本文的后面,我们将参考一个 Colab 和一个 Android 示例应用程序,当我们指导您完成设备上学习的端到端实现路径以微调图像分类模型时。

与早期方法相比的改进

在我们 2019 年的 博文 中,我们介绍了设备上训练的概念以及 TensorFlow Lite 中设备上训练的示例。但是,存在一些限制。例如,很难自定义模型结构和优化器。您还必须处理多个物理 TensorFlow Lite(.tflite)模型,而不是单个 TensorFlow Lite 模型。同样,也没有简单的方法来存储和更新训练权重。我们最新的 TensorFlow Lite 版本通过提供更方便的设备上训练选项来简化此过程,如下所述。

它是如何工作的?

为了部署内置设备上训练的 TensorFlow Lite 模型,以下是高级步骤

  • 构建用于训练和推理的 TensorFlow 模型
  • 将 TensorFlow 模型转换为 TensorFlow Lite 格式
  • 将模型集成到您的 Android 应用程序中
  • 在应用程序中调用模型训练,类似于您调用模型推理的方式

这些步骤将在下面进行解释。

构建用于训练和推理的 TensorFlow 模型

TensorFlow Lite 模型不仅应支持模型推理,还应支持模型训练,这通常涉及将模型的权重保存到文件系统以及从文件系统恢复权重。这样做是为了在每个训练时期后保存训练权重,以便下一个训练时期可以使用前一个时期的权重,而不是从头开始训练。

我们建议的方法是实现这些 tf.functions 来表示训练、推理、保存权重和加载权重

  • 一个 train 函数,使用训练数据训练模型。以下 train 函数进行预测、计算损失(或误差),并使用 tf.GradientTape() 记录 自动微分 的操作并更新模型的参数。
    # The `train` function takes a batch of input images and labels.
    @tf.function(input_signature=[
         tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
         tf.TensorSpec([None, 10], tf.float32),
     ])
    def train(self, x, y):
       with tf.GradientTape() as tape:
         prediction = self.model(x)
         loss = self._LOSS_FN(prediction, y)
       gradients = tape.gradient(loss, self.model.trainable_variables)
       self._OPTIM.apply_gradients(
           zip(gradients, self.model.trainable_variables))
       result = {"loss": loss}
       for grad in gradients:
         result[grad.name] = grad
       return result
    
  • 一个 inferpredict 函数,调用模型推理。这类似于您目前使用 TensorFlow Lite 进行推理的方式。
    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
     def predict(self, x):
       return {
           "output": self.model(x)
       }
    
  • 一个 save/restore 函数,以 检查点 格式将训练权重(即模型使用的参数)保存到文件系统。以下显示了 save 函数的代码。
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
     def save(self, checkpoint_path):
       tensor_names = [weight.name for weight in self.model.weights]
       tensors_to_save = [weight.read_value() for weight in self.model.weights]
       tf.raw_ops.Save(
           filename=checkpoint_path, tensor_names=tensor_names,
           data=tensors_to_save, name='save')
       return {
           "checkpoint_path": checkpoint_path
       }
    

转换为 TensorFlow Lite 格式

您可能已经熟悉将 TensorFlow 模型 转换 为 TensorFlow Lite 格式的工作流程。设备上训练的一些底层功能(例如存储模型参数的变量)仍然是实验性的,其他功能(例如权重序列化)目前依赖于 TF Select 运算符,因此您需要在转换期间设置这些标志。您可以在 Colab 中找到需要设置的所有标志的示例。

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
   tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
   tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

将模型集成到您的 Android 应用程序中

将模型转换为 TensorFlow Lite 格式后,就可以将其集成到应用程序中了!有关更多详细信息,请参阅 Android 应用程序示例。

在应用程序中调用模型训练和推理

在 Android 上,可以使用 Java 或 C++ API 执行 TensorFlow Lite 设备上训练。您可以创建 TensorFlow Lite 解释器 的实例来加载模型并驱动模型训练任务。我们之前定义了多个 tf.functions:这些函数可以使用 TensorFlow Lite 对 签名 的支持进行调用,这允许单个 TensorFlow Lite 模型支持多个“入口点”。例如,我们定义了一个用于设备上训练的 train 函数,它是模型的其中一个签名。可以通过指定签名的名称(“train”)使用 TensorFlow Lite 的 runSignature 方法调用 train 函数

 // Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
    for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
        Map<String, Object> inputs = new HashMap<>>();
        inputs.put("x", trainImageBatches.get(batchIdx));
        inputs.put("y", trainLabelBatches.get(batchIdx));

        Map<String, Object> outputs = new HashMap<>();
        FloatBuffer loss = FloatBuffer.allocate(1);
        outputs.put("loss", loss);

        interpreter.runSignature(inputs, outputs, "train");

        // Record the last loss.
        if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
    }
}


类似地,以下示例显示了如何使用模型的“infer”签名调用推理

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
    // Restore the weights from the checkpoint file.

    int NUM_TESTS = 10;
    FloatBuffer testImages = FloatBuffer.allocateDirect(NUM_TESTS * 28 * 28).order(ByteOrder.nativeOrder());
    FloatBuffer output = FloatBuffer.allocateDirect(NUM_TESTS * 10).order(ByteOrder.nativeOrder());

    // Fill the test data.

    // Run the inference.
    Map<String, Object> inputs = new HashMap<>>();
    inputs.put("x", testImages.rewind());
    Map<String, Object> outputs = new HashMap<>();
    outputs.put("output", output);
    anotherInterpreter.runSignature(inputs, outputs, "infer");
    output.rewind();

    // Process the result to get the final category values.
    int[] testLabels = new int[NUM_TESTS];
    for (int i = 0; i < NUM_TESTS; ++i) {
        int index = 0;
        for (int j = 1; j < 10; ++j) {
            if (output.get(i * 10 + index) < output.get(i * 10 + j))
                index = testLabels[j];
        }
        testLabels[i] = index;
    }
}

就是这样!您现在拥有一个能够使用设备上训练的 TensorFlow Lite 模型。我们希望此代码演练能使您对如何在 TensorFlow Lite 中运行设备上训练有一个很好的了解,我们很高兴看到您将它用于何处。

实际注意事项

从理论上讲,您应该能够将 TensorFlow Lite 中的设备上训练应用于 TensorFlow 支持的任何用例。但是,在实际应用中,在应用程序中部署设备上训练之前,您需要牢记一些实际注意事项

  • 用例:Colab 示例展示了视觉用例的设备上训练示例。如果您在特定模型或用例中遇到问题,请在 GitHub 上告知我们。
  • 性能:根据用例的不同,设备上训练可能需要几秒钟到更长时间。如果将设备上训练作为用户界面功能的一部分运行(例如,您的最终用户正在与该功能交互),则应测量应用程序中各种可能的训练输入所花费的时间,以限制训练时间。如果您的用例需要非常长的设备上训练时间,请考虑首先使用台式机或云进行模型训练,然后在设备上进行微调。
  • 电池使用情况:与模型推理一样,在设备上调用模型训练可能会导致电池电量消耗。如果模型训练是用户界面以外的功能的一部分,我们建议遵循 Android 的 指南 来实现后台任务。
  • 从头开始训练与重新训练:在理论上,应该可以使用上述功能从头开始在设备上训练模型。然而,在现实中,从头开始训练涉及大量的训练数据,即使在配备强大处理器的服务器上,也可能需要几天的时间。因此,对于设备上应用程序,我们建议在已训练的模型上重新训练(即 迁移学习),如 Colab 示例所示。

路线图

未来的工作包括(但不限于)iOS 上的设备上训练支持、利用设备上加速器(例如 GPU)进行设备上训练以提高性能、通过在 TensorFlow Lite 中本地实现更多训练操作来减小二进制文件大小、更高级别的 API 支持(例如通过 TensorFlow Lite 任务库)来抽象实现细节以及涵盖其他设备上训练用例(例如 NLP)的示例。我们的长期路线图可能涉及提供设备上端到端联邦学习解决方案。

下一步

感谢您的阅读!我们很高兴看到您使用设备上学习构建的内容。再次提醒您,以下是 示例 应用程序和 Colab 的链接。如果您有任何反馈,请在 TensorFlow 论坛GitHub 上告知我们。

鸣谢

这篇文章反映了 Google TensorFlow Lite 团队中许多人的重要贡献,包括 Michelle Carney、Lawrence Chan、Jaesung Chung、Jared Duke、Terry Heo、Jared Lim、Yu-Cheng Ling、Thai Nguyen、Karim Nosseir、Arun Venkatesan、Haoliang Zhang、其他 TensorFlow Lite 团队成员以及我们在 Google Research 的合作者。

下一篇文章
On-device training in TensorFlow Lite

TensorFlow Lite 团队发布 TensorFlow Lite 是 Google 的机器学习框架,用于在多种设备和界面上部署机器学习模型,例如移动设备(iOS 和 Android)、桌面和其他边缘设备。最近,我们还添加了在浏览器中运行 TensorFlow Lite 模型的支持。为了 构建应用程序 使用 TensorFlow Lite,您可以使用来自 Tens… 的现成模型