使用 TensorFlow Lite 创建卡通化器
2020 年 9 月 9 日
ML GDE Margaret Maynard-Reid (Tiny Peppers) 和 Sayak Paul (PyImageSearch) 的客座文章

这是一篇关于如何将 TensorFlow 模型转换为 TensorFlow Lite (TFLite) 并将其部署到 Android 应用中以对相机捕获的图像进行卡通化处理的端到端教程。

我们创建了这个端到端教程来帮助开发者实现以下目标
  • 为希望将用 TensorFlow 1.x 编写的模型转换为其 TFLite 变体(使用最新(v2)转换器的全新功能,例如基于 MLIR 的转换器、更多支持的操作以及改进的内核等)的开发者提供参考。
    (为了将 TensorFlow 2.x 模型转换为 TFLite,请遵循本指南。))
  • 如果您只对使用模型进行部署感兴趣,如何直接从 TensorFlow Hub 下载 .tflite 模型。
  • 了解如何使用 TFLite 工具,例如 Android 基准工具、模型元数据和代码生成。
  • 指导开发者如何使用 Android Studio 中的 ML 模型绑定功能轻松地使用 TFLite 模型创建移动应用程序。
请按照模型保存/转换、填充元数据的笔记本此处以及 GitHub 上的 Android 代码此处进行操作。如果您不熟悉 SavedModel 格式,请参阅TensorFlow 文档以了解更多信息。虽然本教程讨论了如何创建 TFLite 模型的步骤,但您可以随时直接从 TensorFlow Hub此处下载这些模型并在您自己的应用程序中开始使用它们。

白盒 CartoonGAN 是一种生成对抗网络,能够将输入图像(最好是自然图像)转换为卡通化表示。这里的目标是从输入图像生成卡通化图像,该图像在视觉上和语义上都是美观的。有关该模型的更多详细信息,请查看 Xinrui Wang 和 Jinze Yu 的论文使用白盒卡通表示学习卡通化。在本教程中,我们使用了白盒 CartoonGAN 的生成器部分。

创建 TensorFlow Lite 模型

白盒 CartoonGAN 的作者提供了预训练权重,可用于对图像进行推理。但是,如果我们想要开发一个移动应用程序而无需进行 API 调用来获取这些权重,那么这些权重就不是理想的选择。这就是为什么我们首先将这些预训练权重转换为 TFLite,这将更适合放入移动应用程序中。本节中讨论的所有代码都可以在 GitHub此处找到。
以下是本节中将涵盖内容的分步摘要
  • 从预训练模型检查点生成 SavedModel。
  • 使用最新的 TFLiteConverter 转换 SavedModel 并进行训练后量化。
  • 使用转换后的模型在 Python 中运行推理。
  • 添加元数据以轻松集成到移动应用程序中。
  • 运行模型基准测试以确保模型在移动设备上运行良好。

从预训练模型权重生成 SavedModel

白盒 CartoonGAN 的预训练权重以以下格式提供(也称为检查点) -
├── checkpoint
├── model-33999.data-00000-of-00001
└── model-33999.index
由于原始白盒 CartoonGAN 模型是在 TensorFlow 1 中实现的,因此我们首先需要使用 TensorFlow 1.15 生成一个在 SavedModel 格式中的单一自包含模型文件。然后,我们将在稍后切换到 TensorFlow 2 以将其转换为轻量级的 TFLite 格式。为此,我们可以遵循以下工作流程 -
  • 为模型输入创建占位符。
  • 实例化模型实例并将输入占位符通过模型运行以获取模型输出的占位符。
  • 将预训练检查点加载到模型的当前会话中。
  • 最后,导出到 SavedModel。
请注意,上述工作流程将基于 TensorFlow 1.x。

以下是在 TensorFlow 1.x 中所有这些代码的外观
with tf.Session() as sess:
   input_photo = tf.placeholder(tf.float32, [1, None, None, 3], name='input_photo')
 
   network_out = network.unet_generator(input_photo)
   final_out = guided_filter.guided_filter(input_photo, network_out,           r=1, eps=5e-3)
   final_out = tf.identity(final_out, name='final_output') 
  
   all_vars = tf.trainable_variables()
   gene_vars = [var for var in all_vars if 'generator' in var.name]
   saver = tf.train.Saver(var_list=gene_vars)
   sess.run(tf.global_variables_initializer())
   saver.restore(sess, tf.train.latest_checkpoint(model_path))
  
   # Export to SavedModel
   tf.saved_model.simple_save(
  sess,
      saved_model_directory,
      inputs={input_photo.name: input_photo},
      outputs={final_out.name: final_out}
   )
现在我们已经将原始模型以 SavedModel 格式保存了,我们可以切换到 TensorFlow 2 并继续将其转换为 TFLite。

将 SavedModel 转换为 TFLite

TFLite 支持三种不同的训练后量化策略 -
  • 动态范围
  • Float16
  • 整数
根据用例来确定特定策略。但是,在本教程中,我们将涵盖所有这些不同的量化策略,以使您有一个公平的认识。

具有动态范围和 Float16 量化的 TFLite 模型

使用这两种量化策略将模型转换为 TFLite 的步骤几乎相同,只是在 Float16 量化期间,您需要指定一个额外的选项。模型转换步骤在下面的代码中演示 -
# Create a concrete function from the SavedModel 
model = tf.saved_model.load(saved_model_dir)
concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Specify the input shape
concrete_func.inputs[0].set_shape([1, IMG_SHAPE, IMG_SHAPE, 3])

# Convert the model and export 
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16] # Only for float16
tflite_model = converter.convert()
open(tflite_model_path, 'wb').write(tflite_model)
从上面的代码中需要注意几件事 -
  • 在这里,我们指定了将转换为 TFLite 的模型的输入形状。但是,请注意,TFLite 支持来自 TensorFlow 2.3 的动态形状模型。我们使用固定形状的输入来限制在移动设备上运行的模型的内存使用量。
  • 为了使用动态范围量化转换模型,只需注释掉这一行 converter.target_spec.supported_types = [tf.float16]

具有整数量化的 TFLite 模型

为了使用整数量化转换模型,我们需要将一个代表性数据集传递给转换器,以便相应地校准激活范围。使用此策略生成的 TFLite 模型通常比我们刚刚看到的其他两种模型效果更好。整数量化模型通常也更小。

为了简洁起见,我们将跳过代表性数据集生成部分,但您可以在此笔记本中参考它。

为了让 TFLiteConverter 利用此策略,我们只需传递 converter.representative_dataset = representative_dataset_gen 并删除 converter.target_spec.supported_types = [tf.float16]

因此,在生成这些不同的模型后,我们从模型大小方面来看,现状如下 -
您可能很想直接使用用整数量化量化的模型,但在做出最终决定之前,您还应该考虑以下因素 -
  • 模型的最终结果的质量。
  • 推理时间(越低越好)。
  • 硬件加速器兼容性。
  • 内存使用量。
我们将在稍后讨论这些内容。如果您想更深入地了解这些不同的量化策略,请参阅官方指南此处

这些模型可在TensorFlow Hub 上找到,您可以在此处找到它们。

在 Python 中运行推理

在生成 TFLite 模型后,务必确保模型按预期执行。确保这一点的一个好方法是在将模型集成到移动应用程序中之前,使用 Python 对模型进行推理。

在将图像馈送到我们的白盒 CartoonGAN TFLite 模型之前,务必确保图像已正确预处理。否则,模型可能会意外执行。原始模型是在使用 BGR 图像训练的,因此我们需要在预处理步骤中考虑这一事实。您可以在此笔记本中找到所有预处理步骤。

以下代码使用 TFLite 模型对预处理的输入图像进行推理 -
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
input_details = interpreter.get_input_details()

interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'],                
preprocessed_source_image)
interpreter.invoke()

raw_prediction = interpreter.tensor(
    interpreter.get_output_details()[0]['index'])()
如上所述,输出将是图像,但具有 BGR 通道排序,这在视觉上可能不合适。因此,我们需要在后处理步骤中考虑这一事实。

在合并后处理步骤后,以下是最终图像与原始输入图像的对比 - 再次,您可以在此笔记本中找到所有后处理步骤。

添加元数据以轻松集成到移动应用程序中

TFLite 中的模型元数据可以让移动应用程序开发人员的生活更轻松。如果您的 TFLite 模型填充了正确的元数据,那么将该模型集成到移动应用程序中就只需几个按键。讨论填充 TFLite 模型元数据的代码超出了本教程的范围,请参阅 元数据指南。但在本节中,我们将为您提供一些关于为我们生成的 TFLite 模型填充元数据的关键要点。您可以按照 此笔记本 引用所有代码。我们在元数据填充期间发现的两个最重要的参数是均值和标准差,结果应该用它们来处理。在我们的案例中,均值和标准差需要用于预处理后处理。为了规范化输入图像,元数据配置应该类似于以下内容 -
input_image_normalization.options.mean = [127.5]
input_image_normalization.options.std = [127.5]
这将使输入图像中的像素范围变为 [-1, 1]。现在,在后处理过程中,需要将像素缩放到 [0, 255] 的范围内。为此,配置将如下所示 -
output_image_normalization.options.mean = [-1]
output_image_normalization.options.std = [0.00784313] # 1/127.5
从“添加元数据过程”创建了两个文件
  • 一个与原始模型同名的 .tflite 文件,其中添加了元数据,包括模型名称、描述、版本、输入和输出张量等。
  • 为了帮助显示元数据,我们还将元数据导出到一个 .json 文件中,以便您可以打印出来。当您将模型导入 Android Studio 时,元数据也可以显示。
填充了元数据的模型可以很容易地导入 Android Studio,我们将在后面的“模型部署到 Android”部分讨论。

在 Android 上基准测试模型(可选)

作为一个可选步骤,我们使用 TFLite Android 模型基准测试工具来了解部署之前在 Android 上的运行时性能。

使用基准测试工具有两种选择,一种是使用 C++ 二进制文件 在后台运行,另一种是使用 Android APK 在前台运行。

以下是使用基准测试 C++ 二进制文件的高级摘要

1. 配置 Android SDK/NDK 先决条件

2. 使用 bazel 构建基准测试 C++ 二进制文件
bazel build -c opt \
      --config=android_arm64 \
      tensorflow/lite/tools/benchmark:benchmark_model
3. 使用 adb(Android 调试桥)将基准测试工具二进制文件推送到设备并使其可执行
adb push benchmark_model /data/local tmp
    adb shell chmod +x /data/local/tmp/benchmark_model
4. 将 whitebox_cartoon_gan_dr.tflite 模型推送到设备
adb push whitebox_cartoon_gan_dr.tflite /data/local/tmp
5. 运行基准测试工具
adb shell /data/local/tmp/android_aarch64_benchmark_model \       
      --graph=/data/local/tmp/whitebox_cartoon_gan_dr.tflite \
      --num_threads=4
您将在终端中看到类似以下结果: 对另外两个 tflite 模型重复上述步骤:float16int8 变体。

总之,以下是我们在 Pixel 4 上运行基准测试工具获得的平均推理时间
有关详细信息和附加选项(例如如何减少运行之间的差异以及如何分析运算符等),请参阅基准测试工具的文档(C++ 二进制文件 | Android APK)。您还可以在 TensorFlow 官方文档 这里 查看一些流行的 ML 模型的性能值。

模型部署到 Android

现在,我们已经通过以下步骤(或直接从 TensorFlow Hub 下载模型 这里)获得了具有元数据的量化 TensorFlow Lite 模型,我们就可以将它们部署到 Android 了。请参考 GitHub 上的 Android 代码 这里

Android 应用程序使用 Jetpack Navigation Component 进行 UI 导航,使用 CameraX 进行图像捕获。我们使用新的 ML 模型绑定功能导入 tflite 模型,然后使用 Kotlin 协程进行模型推理的异步处理,以便在等待结果时不会阻塞 UI。

让我们一步一步深入了解细节
  • 下载 Android Studio 4.1 预览版。
  • 创建一个新的 Android 项目并设置 UI 导航。
  • 为图像捕获设置 CameraX API。
  • 使用 ML 模型绑定导入 .tflite 模型。
  • 将所有内容整合在一起。

下载 Android Studio 4.1 预览版

我们首先需要安装 Android Studio 预览版(4.1 Beta 1),以便使用新的 ML 模型绑定功能导入 .tflite 模型和自动代码生成。然后,您可以以可视化方式浏览 tfllite 模型,最重要的是在您的 Android 项目中直接使用生成的类。

这里 下载 Android Studio 预览版。您应该能够将预览版与 Android Studio 的稳定版并排运行。确保将 Gradle 插件更新至至少 4.1.0-alpha10;否则 ML 模型绑定菜单可能无法访问。

创建一个新的 Android 项目

首先,让我们创建一个新的 Android 项目,其中包含一个名为 MainActivity.kt 的空 Activity,该活动包含一个伴随对象,定义了将存储捕获图像的输出目录。

使用 Jetpack Navigation Component 导航应用程序的 UI。请参阅 此教程 以详细了解此支持库。

此示例应用程序中有 3 个屏幕
  • PermissionsFragment.kt 处理检查相机权限。
  • CameraFragment.kt 处理相机设置、图像捕获和保存。
  • CartoonFragment.kt 处理在 UI 中显示输入图像和卡通图像。
nav_graph.xml 中的导航图定义了三个屏幕的导航以及 CameraFragmentCartoonFragment 之间的数据传递。

为图像捕获设置 CameraX

CameraX 是一个 Jetpack 支持库,它使相机应用程序开发变得更加容易。

Camera1 API 使用简单,但缺少很多功能。Camera2 API 比 Camera1 提供更精细的控制,但它非常复杂,在一个非常基本的示例中几乎包含 1000 行代码。

另一方面,CameraX 设置起来要容易得多,代码量减少了 10 倍。此外,它还具有生命周期感知,因此您不需要编写额外的代码来处理 Android 生命周期。

以下是为本示例应用程序设置 CameraX 的步骤
  • 更新 build.gradle 依赖项
  • 使用 CameraFragment.kt 保存 CameraX 代码
  • 请求相机权限
  • 更新 AndroidManifest.ml
  • MainActivity.kt 中检查权限
  • 使用 CameraX Preview 类实现一个取景器
  • 实现图像捕获
  • 捕获图像并将其转换为 Bitmap
CameraSelector 被配置为能够使用前置摄像头和后置摄像头,因为该模型可以对任何类型的面部或物体进行风格化,而不仅仅是自拍照。

捕获图像后,我们将它转换为 Bitmap,并将其传递给 TFLite 模型进行推理。导航到一个新的屏幕 CartoonFragment.kt,在这里将显示原始图像和卡通化图像。

导入 TensorFlow Lite 模型

现在,UI 代码已经完成。现在该导入 TensorFlow Lite 模型进行推理了。ML 模型绑定可以轻松地完成此操作。在 Android Studio 中,转到文件 > 新建 > 其他 > TensorFlow Lite 模型:
  • 指定 .tflite 文件位置。
  • “自动将构建功能和必需的依赖项添加到 gradle”默认情况下处于选中状态。
  • 还要确保选中“自动将 TensorFlow Lite gpu 依赖项添加到 gradle”,因为 GAN 模型很复杂且速度很慢,因此我们需要启用 GPU 代理。
此导入完成了两个操作
  • 自动创建一个 ml 文件夹并将模型文件 .tflite 文件放在其中。
  • 自动在文件夹下生成一个 Java 类:app/build/generated/ml_source_out/debug/[package-name]/ml,该类处理所有任务,例如模型加载、图像预处理和后处理,并运行模型推理以对输入图像进行风格化。
导入完成后,我们看到 *.tflite 显示模型元数据信息以及 Kotlin 和 Java 中的代码片段,这些代码片段可以复制/粘贴以使用模型: 重复上述步骤以导入其他两个 .tflite 模型变体。

将所有内容整合在一起

现在,我们已经设置了 UI 导航,配置了 CameraX 以进行图像捕获,并且 tflite 模型已导入,让我们将所有部分整合在一起!
  • 模型输入:使用 CameraX 捕获照片并保存
  • 对输入图像运行推理并创建卡通化版本
  • 在 UI 中显示原始照片和卡通化照片
  • 使用 Kotlin 协程防止模型推理阻塞 UI 主线程
首先,我们在 CameraFragment.ktimageCaptue?.takePicture() 中使用 CameraX 捕获照片,然后在 ImageCapture.OnImageSavedCallback{}onImageSaved() 中将 .jpg 图像转换为 Bitmap,在必要时进行旋转,然后将其保存到之前在 MainActivity 中定义的输出目录中。

使用 JetPack Nav Component,我们可以轻松地导航到 CartoonFragment.kt,并将图像目录位置作为字符串参数传递,并将 tflite 模型类型作为整数传递。然后在 CartoonFragment.kt 中,检索照片存储的文件目录字符串,创建一个图像文件,然后将其转换为 Bitmap,可将其用作 tflite 模型的输入。

CartoonFragment.kt 中,还要检索用于推理的 tflite 模型类型。对输入图像运行模型推理并创建卡通图像。我们在 UI 中显示原始图像和卡通化图像。

注意:推理需要时间,因此我们使用 Kotlin 协程来防止模型推理阻塞 UI 主线程。显示一个 ProgressBar,直到模型推理完成。

以下是将所有部分整合在一起后的情况,以及模型创建的卡通图像: 这就结束了本教程。希望您喜欢阅读它,并将学到的知识应用到您使用 TensorFlow Lite 的实际应用程序中。如果您使用您在这里学到的知识创建了任何很酷的示例,请记住将其添加到 awesome-tflite 中,这是一个包含 TensorFlow Lite 示例、教程、工具和学习资源的仓库。

鸣谢

使用 TensorFlow Lite 的卡通化器 项目,包含完整的教程,由 ML GDEs 和 TensorFlow Lite 团队共同完成。这是 TensorFlow Lite 端到端教程 系列的一部分。我们要感谢 Khanh LeViet 和 Lu Wang (TensorFlow Lite)、Hoi Lam (Android ML)、Trevor McGuire (CameraX) 和 Soonson Kwon (ML GDEs Google Developers Experts Program),感谢他们的合作和持续支持。

也要感谢论文 基于白盒卡通表示的卡通化学习 的作者:王欣蕊和于晋泽。

在开发应用程序时,重要的是要考虑 负责任的创新最佳实践;请查看 TensorFlow 负责任 AI 以获取您可以使用的资源和工具。
下一篇文章
How to Create a Cartoonizer with TensorFlow Lite

由 ML GDEs Margaret Maynard-Reid (Tiny Peppers) 和 Sayak Paul (PyImageSearch) 撰写

本文介绍了如何将 TensorFlow 模型转换为 TensorFlow Lite (TFLite) 并将其部署到 Android 应用程序,以卡通化相机拍摄的图像。

我们创建了此端到端教程,以帮助开发人员实现以下目标:
为希望... 的开发人员提供参考: