由于我们想要使用自定义模型进行设备上的机器学习,因此 TensorFlow Lite 是一个显而易见的选择,因为它可以轻松地将服务器上训练的模型转换为与手机兼容的模型(.tflite 格式),方法是使用 TFLiteConverter。
此外,我们已经在生产系统中体验过 TensorFlow 和 TensorFlow Serving 在服务器端机器学习方面的成功。TensorFlow 的库的设计初衷就是将运行机器学习作为主要目标,因此我们认为 TensorFlow Lite 也不会有所不同。
我们使用 ML Kit 在 TensorFlow Lite 模型上直接运行推理,并将其无缝地集成到我们的应用程序中。这使我们能够快速将该功能从原型阶段发展到生产就绪阶段。ML Kit 为我们提供了更高级别的 API,以便我们可以处理模型的初始化和加载,以及在图像上运行推理,而无需直接处理更低级别的 TensorFlow Lite C++ 库,从而使开发过程更快,并让我们有更多时间来完善模型而不是处理其他事情。
使用我们与人工审核员一起创建的分类数据集,我们基于 SqueezeNet 架构在 TensorFlow 中训练了一个 CNN 模型。我们选择这个架构是因为它体积更小,同时精度损失不大。我们使用 TFLiteConverter 将这个训练过的模型从 TensorFlow 的 Saved Model 格式转换为 TensorFlow Lite(.tflite)格式,以便在 Android 上使用。此阶段初始错误的部分原因是,我们使用的 TFLiteConverter 版本与 ML Kit 通过 Maven 引用到的 TensorFlow Lite 库版本不匹配。ML Kit 团队在解决这些问题方面提供了很大的帮助。
graph_def_file = “model_name.pb”
input_arrays = [“input_tensor_name”] # this array can have more than one input name if the model requires multiple inputs
output_arrays = [“output_tensor_name”] # this array can have more than one input name if the model has multiple outputs
converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
在能够为图像分配类别的模型完成后,我们就可以将其打包到我们的应用程序中,并使用 ML Kit 对图像进行推理。由于我们使用的是自己的自定义训练模型,因此我们使用了 ML Kit 中的 Custom Model API。为了获得更高的精度,我们决定在模型转换过程中放弃量化步骤,并决定在 ML Kit 中使用浮点模型。这里遇到了一些挑战,因为 ML Kit 默认情况下会假设使用量化模型。但是,我们没有花费太多精力就能够更改模型初始化中的某些步骤以支持浮点模型。
// create a model interpreter for local model (bundled with app)
FirebaseModelOptions modelOptions = new FirebaseModelOptions.Builder()
.setLocalModelName(“model_name”)
.build();
modelInterpreter = FirebaseModelInterpreter.getInstance(modelOptions);
// specify input output details for the model
// SqueezeNet architecture uses 227 x 227 image as input
modelInputOutputOptions = new FirebaseModelInputOutputOptions.Builder()
.setInputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, 227, 227, 3})
.setOutputFormat(0, FirebaseModelDataType.FLOAT32, new int[]{1, numLabels})
.build();
// create input data
FirebaseModelInputs input = new FirebaseModelInputs.Builder().add(imgDataArray).build(); // imgDataArray is a float[][][][] array of (1, 227, 227, 3)
// run inference
modelInterpreter.run(input, modelInputOutputOptions);
我们不断更新这些精选目录,因为我们会在会员服务中添加新的预设。为了方便我们在用户无需更新应用的情况下动态更新这些列表,我们决定将这些目录存储在服务器上,并使用 API(一个用 Go 编写的微服务)提供服务,移动客户端可以定期向 API 查询以确保他们拥有最新版本的目录。移动客户端会缓存这个目录,只有在有新版本可用时才会获取。然而,这种方法会导致“冷启动”问题,对于第一次使用此功能且尚未连接互联网的用户来说,应用程序无法与 API 通信并下载这些目录。为了解决这个问题,我们决定在应用程序中附带这些目录的默认版本。这使得所有用户无论其互联网连接状况如何都可以使用此功能,这也是该功能最初的目标之一。