2022 年 8 月 31 日 — 发布者 Andreas Steiner 和 Marc van Zee,Google Research,Brain 团队简介 在这篇博文中,我们将展示如何使用 TensorFlow.js 将基于 Python 的 JAX 函数和 Flax 机器学习模型转换为在浏览器中运行。我们制作了三个 JAX 到 TensorFlow.js 转换的例子,每个例子的复杂度递增: 简单的 JAX 函数 在 MNIST 数据集上训练的图像分类 Flax 模型 完整的图像/文本 Vision Transformer (ViT) 演示,它被用于 Google AI 博客文章 Locked-Image Tuning: Adding Language Understanding to Image Models (演示的预览见下图 1)
发布者 Andreas Steiner 和 Marc van Zee,Google Research,Brain 团队
在这篇博文中,我们将展示如何使用 TensorFlow.js 将基于 Python 的 JAX 函数和 Flax 机器学习模型转换为在浏览器中运行。我们制作了三个 JAX 到 TensorFlow.js 转换的例子,每个例子的复杂度递增:
对于每个例子,都有 Google Colab 笔记本供您使用,您可以自行尝试 JAX 到 TensorFlow.js 的转换。
图 1. TensorFlow.js 模型将用户提供的文本提示与预计算的图像嵌入相匹配 (自己试试)。请参阅下文的 示例 3:LiT 演示 以了解实现细节。 |
.json
),以便模型可以在浏览器中与 Tensorflow.js 一起使用。tf.Variable
s) 和计算。图 2. converters.convert_jax() 内部转换步骤的高级可视化,它将 JAX 函数转换为 Tensorflow.js 模型。 |
weight
并实现函数 prod
,该函数将输入与参数相乘(在一个真实的示例中,params
将包含神经网络中使用的所有模块的权重)
def prod(params, xs): return params['weight'] * xs |
params = {'weight': np.array([0.5, 1])} # 这代表 3 个输入的批次,每个输入长度为 2。 xs = np.arange(6).reshape((3, 2)) prod(params, xs) |
[0.5, 1]
[[0. 1.] [1. 3.] [2. 5.]] |
convert_jax
将其转换为 TensorFlow.js,并使用辅助函数 get_tfjs_predict_fn
(可在 Colab 中找到),让我们可以验证 JAX 函数和 Web 模型的输出是否匹配。(注意:此辅助函数仅在 Colab 中有效,因为它使用一些工具使用 Javascript 运行 Web 模型。)
tfjs.converters.convert_jax( prod, params, input_signatures=[tf.TensorSpec((3, 2), tf.float32)], model_dir=model_dir) tfjs_predict_fn = get_tfjs_predict_fn(model_dir) tfjs_predict_fn(xs) # 与 JAX 的输出相同。 |
tfjs.converters.convert_jax( prod, params, input_signatures=[tf.TensorSpec((None, 2), tf.float32)], polymorphic_shapes=['(b, 2)')], model_dir=model_dir) tfjs_predict_fn = get_tfjs_predict_fn(model_dir) tfjs_predict_fn(np.array([[1., 2.]])) # 输出: [[0.5, 2. ]] |
train_ds, test_ds = train.get_datasets() state = train.train_and_evaluate(config, workdir=f'./workdir') |
state.apply_fn
,它可以用来计算输入图像的 logits。请注意,该函数需要第一个参数为模型权重 state.params
。给定一个形状为 [batch_size, 28, 28, 1
] 的输入图像批次,这将生成每个模型(形状为 [batch_size, 10 ])对十个标签的概率分布的 logits。
logits = state.apply_fn({'params': state.params}, imgs) |
tfjs.converters.convert_jax( state.apply_fn, {'params': state.params}, input_signatures=[tf.TensorSpec((1, 28, 28, 1), tf.float32)], model_dir=tfjs_model_dir, ) |
status
文本中显示简单的进度更新,确保在传输模型权重时提供一些反馈。
tf.loadGraphModel(modelDir + '/model.json', { onProgress: p => status.innerText = `loading model: ${Math.round(p*100)}%` }) |
img
是长度为 28*28
的 Uint8Array
,它首先被转换为 TensorFlow.js tf.tensor
,然后计算模型输出,并通过 tf.softmax() 函数将其转换为概率。然后,通过调用 .dataSync()
同步等待计算的输出值,并在显示之前将其转换为 JavaScript 数组。
ui.onUpdate(img => { const imgs = tf.tensor(img).cast('float32').reshape([1, 28, 28, 1]) const logits = model.predict(imgs) const preds = tf.softmax(logits) const { values, indices } = tf.topk(preds, 10) ui.showPreds([...values.dataSync()], [...indices.dataSync()]) }) |
图 3. 我们来自 Colab 的模型在 MNIST 测试数据集上的准确率为 99.1%,但仍令人惊讶地难以识别手写数字。左侧,模型预测了各种数字,而不是“一”。右侧,“一”的绘制方式更像是训练集中的数据。示例 3:LiT 演示使用 TensorFlow.js 模型编写更现实的应用程序则更加复杂。本节将介绍用于创建 Google AI 博客文章 锁定图像微调:为图像模型添加语言理解功能 中的演示应用程序的主要步骤。请参考该文章以了解 ML 模型实现的技术细节。同时,请确保查看最终的 LiT 演示。 调整模型在开始实现 ML 演示之前,认真考虑不同的选项及其各自的优缺点是一个好时机。 总的来说,您有两种选择:在服务器端基础设施上运行 ML 模型,或在边缘(即访问用户的设备上)运行 ML 模型。
您用于演示的模型由两部分组成:图像编码器和文本编码器(见图 4)。 对于计算图像嵌入,您使用一个大型模型,而对于文本嵌入则使用一个小模型。为了使演示运行得更快并产生更好的结果,昂贵的图像嵌入是预先计算的,因此 Tensorflow.js 模型只需要计算文本嵌入,然后比较图像和文本嵌入以计算相似度。
对于演示,我们现在可以免费获得这些强大的 ViT-Large 图像表示,因为我们可以为所有演示图像预先计算它们。这使得我们能够以有限的计算预算制作出令人信服的演示。除了“小型”文本编码器之外,我们还为相同的图像嵌入准备了一个“小型”文本编码器(LiT-L16S),它表现略好,但使用更多带宽来下载模型权重,并且需要更多 GPU 内存才能在设备上运行。我们使用来自 此 Colab 的代码评估了不同的模型
请注意,“零样本性能”仅应视为一个代理。最终,模型的性能需要足以满足演示的需要,在本例中,我们的手动测试表明,即使是最小的文本转换器也能够计算出足以满足演示需求的相似度。接下来,我们使用此 TensorFlow.js 基准测试工具 在不同平台上测试了小型和小型文本编码器的性能(使用“自定义模型”选项,并在 WebGL 后端上对 5x16 个令牌进行基准测试) 请注意,由于模型无法装入内存,上述表格中“Samsung S21 G5”的“小型”文本编码器模型结果缺失。就性能而言,带有“微型”文本编码器的模型产生的结果在大约 0.1-1 秒内,即使在测试过的最小平台上,感觉也相当快。 Lit-LiT 网络应用为该应用准备模型稍微复杂一些,因为我们需要不仅转换文本转换器模型权重,还需要匹配的标记器和预先计算的图像嵌入。Colab 加载 LiT 模型并展示如何使用它,然后准备网络应用所需的內容:
data/ 目录中准备,然后作为 ZIP 文件下载。然后可以将此文件上传到网络托管,网络应用从此处加载它(例如在 GitHub 页面上:vision_transformer/lit/data)。整个客户端应用程序的代码在 GitHub 上提供: https://github.com/google-research/big_vision/tree/main/big_vision/tools/lit_demo/. 该应用程序使用 Lit 网络组件 构建。 main index.html 声明演示应用程序: |
<lit-demo-app></lit-demo-app> |
lit-demo-app.ts
中定义,位于 src/components 子目录中,与所有其他网络组件(图像轮播、模型控件等)并排。image-prompts.ts
调用模块 src/lit_demo/compute.ts
中的函数,该模块包装了所有 TensorFlow.js 特定代码。
export class Model { /** 对文本进行标记。 */ tokenize(texts: string[]): tf.Tensor { /* ... */ } /** 计算文本嵌入。 */ embed(tokens: tf.Tensor): tf.Tensor { return this.model!.execute({inputs: tokens}) as tf.Tensor; } /** 计算文本/预计算图像嵌入的相似度。 */ computeSimilarities(texts: string[], imgidxs: number[]) { const textEmbeddings = this.embed(this.tokenize(texts)); const imageEmbeddingsTransposed = tf.transpose( tf.concat(imgidxs.map(idx => tf.slice(this.zimgs!, idx, 1)))); return tf.matMul(textEmbeddings, imageEmbeddingsTransposed); } /** 对 `computeSimilarities()` 应用 softmax。 */ computeProbabilities(texts: string[], imgidx: number): number[] { const sims = this.computeSimilarities(texts, [imgidx]); const row = tf.squeeze(tf.slice(tf.transpose(sims), 0, 1)); return [...tf.softmax(tf.mul(this.def!.temperature, row)).dataSync()]; } } |
data/
目录的父目录在 src/lit/constants.ts
文件中的 baseUrl 中引用。默认情况下,它引用的是官方演示中的模型。当将 baseUrl
替换为不同的服务器时,请确保启用 跨源资源共享。
<!-- 加载全局符号 `lit`。 --> <script src="exports_bin.js"></script> <script> async function demo() { lit.setBaseUrl('https://google-research.github.io/vision_transformer/lit'); const model = new lit.Model('tiny'); await model.load(); console.log(model.computeProbabilities(['a dog', 'a cat'], /*imgIdx=*/1); } demo(); </script> |
2022 年 8 月 31 日 — 发布者:Andreas Steiner 和 Marc van Zee,Google Research,Brain 团队介绍 本博客文章演示了如何使用 TensorFlow.js 将基于 Python 的 JAX 函数和 Flax 机器学习模型转换为浏览器。我们制作了三个 JAX 到 TensorFlow.js 转换的示例,每个示例的复杂度都在增加:简单 JAX 函数 训练的图像分类 Flax 模型…