如何使用 TensorFlow Lite 在 Android 上生成超分辨率图像
2020 年 12 月 18 日

由 TensorFlow 开发倡导者 Wei Wei 发布


从低分辨率图像恢复高分辨率(HR)图像的任务通常被称为单图像超分辨率(SISR)。虽然双线性或三次插值等插值方法可以用来对低分辨率图像进行上采样,但所得图像的质量通常不太理想。深度学习,特别是生成对抗网络,已成功应用于生成更逼真的图像,例如 SRGANESRGAN。在本篇博文中,我们将使用来自 TensorFlow Hub 的预训练 ESRGAN 模型,并在 Android 应用中使用 TensorFlow Lite 生成超分辨率图像。最终的应用程序看起来如下所示,完整的代码已在 TensorFlow 示例库 中发布,供参考。

Screencap of TensorFlow Lite

首先,我们可以方便地从 TFHub 加载 ESRGAN 模型,并轻松将其转换为 TFLite 模型。请注意,这里我们使用的是 动态范围量化,并将输入图像尺寸固定为 50x50。已将转换后的模型上传到 TFHub,但我们要演示如何做到这一点,以防您想自己转换它(例如,在您自己的应用程序中尝试不同的输入大小)

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 50, 50, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
  f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'

您也可以在转换时不硬编码输入维度,并在运行时调整输入张量的大小,从而转换模型,因为 TFLite 现在支持动态形状的输入。请参阅 此示例,以了解更多信息。

模型转换完成后,我们可以快速验证 ESRGAN TFLite 模型是否确实生成了比三次插值更好的图像。如果您想更好地了解模型,我们还有另一个关于 ESRGAN 的 教程

lr = cv2.imread(test_img_path)
lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
ESRGAN model with low res and high res image
LR:DIV2K 数据集 中的蝴蝶图像中裁剪出的低分辨率输入图像。ESRGAN (x4): 使用 ESRGAN 模型生成的上采样比例为 4 的超分辨率输出图像。双三次插值: 使用双三次插值生成的输出图像。如这里所示,双三次插值生成的图像比 ESRGAN 生成的图像模糊得多。ESRGAN 生成的图像上的 PSNR 也更高。

您可能已经知道,TensorFlow Lite 是在边缘设备上运行 TensorFlow 模型进行推理的官方框架,它已部署在全球超过 40 亿台边缘设备上,支持 Android、iOS、基于 Linux 的物联网设备和微控制器。您可以在 Java、C/C++ 或其他语言中使用 TFLite 来构建 Android 应用程序。在本篇博文中,我们将使用 TFLite C API,因为许多开发人员都要求提供这样的示例。

我们将 预构建的 AAR 文件中分发 TFLite C 头文件和库(核心库和 GPU 库)。为了正确设置 Android 构建,我们首先需要做的是下载 AAR 文件并提取头文件和共享库。看看如何在 download.gradle 文件中完成此操作。

由于我们使用 Android NDK 来构建应用程序(NDK r20 已确认可以正常工作),我们需要让 Android Studio 知道如何处理本机文件。这是在 CMakeList.txt 文件中完成的

我们在应用程序中包含了 3 个示例图像,因此用户可以轻松地多次运行同一个模型,这意味着我们需要缓存解释器以提高效率。这是通过在解释器成功创建后,将解释器指针从 C++ 代码传递到 Java 代码中完成的

extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_examples_superresolution_MainActivity_initWithByteBufferFromJNI(JNIEnv *env, jobject thiz, jobject model_buffer, jboolean use_gpu) {
  const void *model_data = static_cast<void *>(env->GetDirectBufferAddress(model_buffer));
  jlong model_size_bytes = env->GetDirectBufferCapacity(model_buffer);
  SuperResolution *super_resolution = new SuperResolution(model_data, static_cast<size_t>(model_size_bytes), use_gpu);
  if (super_resolution->IsInterpreterCreated()) {
    LOGI("Interpreter is created successfully");
    return reinterpret_cast<jlong>(super_resolution);
   } else {
    delete super_resolution;
    return 0;
  }
}

解释器创建完成后,运行模型非常简单,因为我们可以按照 TFLite C API 文档进行操作。我们首先需要小心地 从每个像素中提取 RGB 值。现在我们可以运行解释器了

// Feed input into model
status = TfLiteTensorCopyFromBuffer(input_tensor, input_buffer, kNumberOfInputPixels * kImageChannels * sizeof(float));
…...

// Run the interpreter
status = TfLiteInterpreterInvoke(interpreter_);
…...

// Extract the output tensor data
const TfLiteTensor* output_tensor = TfLiteInterpreterGetOutputTensor(interpreter_, 0);
float output_buffer[kNumberOfOutputPixels * kImageChannels];
status = TfLiteTensorCopyToBuffer(output_tensor, output_buffer, kNumberOfOutputPixels * kImageChannels * sizeof(float));
…...

有了模型结果,我们可以 将 RGB 值打包回每个像素中。

就是这样。一个使用 TFLite 在设备上生成超分辨率图像的参考 Android 应用程序。更多详细信息可以在 代码库 中找到。希望这对开始使用 TFLite 在 C/C++ 中构建令人惊叹的 ML 应用程序的 Android 开发人员有所帮助。

反馈

我们期待着看到您使用 TensorFlow Lite 构建的内容,以及您的反馈。请通过 直接联系我们或在 Twitter 上使用主题标签 #TFLite 和 #PoweredByTF 与我们分享您的用例。要报告错误和问题,请在 GitHub 上联系我们。

致谢

作者要感谢 @captain__pool 将 ESRGAN 模型上传到 TFHub,还要感谢 Tian Lin 和 Jared Duke 提供的宝贵反馈。

[1] Christian Ledig、Lucas Theis、Ferenc Huszar、Jose Caballero、Andrew Cunningham、Alejandro Acosta、Andrew Aitken、Alykhan Tejani、Johannes Totz、Zehan Wang、Wenzhe Shi。2016 年。使用生成对抗网络进行逼真的单图像超分辨率。

[2] Xintao Wang、Ke Yu、Shixiang Wu、Jinjin Gu、Yihao Liu、Chao Dong、Chen Change Loy、Yu Qiao、Xiaoou Tang。2018 年。ESRGAN:增强型超分辨率生成对抗网络。

[3] 基于 Tensorflow 2.x 的 EDSR、WDSR 和 SRGAN 实现,用于单图像超分辨率

[4] @captain__pool 的 ESGRAN 代码实现

[5] Eirikur Agustsson、Radu Timofte。2017 年。NTIRE 2017 单图像超分辨率挑战:数据集和研究。

下一篇文章
How to generate super resolution images using TensorFlow Lite on Android

由 TensorFlow 开发倡导者 Wei Wei 发布
从低分辨率图像恢复高分辨率(HR)图像的任务通常被称为单图像超分辨率(SISR)。虽然双线性或三次插值等插值方法可以用来对低分辨率图像进行上采样,但所得图像的质量通常不太理想。深度学习,特别是生成对抗网络,已成功应用于生成更逼真的图像,例如 SRGANESRGAN。在本篇博文中,我们将使用来自 TensorFlow Hub 的预训练 ESRGAN 模型,并在 Android 应用中使用 TensorFlow Lite 生成超分辨率图像。最终的应用程序看起来如下所示,完整的代码已在 TensorFlow 示例库 中发布,供参考。