2023 年 5 月 10 日 — TensorFlow 团队的 Ayush Jain、Carlos Araya 和 Mani Varadarajan 撰写欢迎来到 Google I/O 的 TensorFlow 和 Keras!机器学习领域正在以前所未有的速度发生着变化。大型语言模型 (LLM) 的兴起激发了全球开发者的想象力,新的生成式 AI 应用正在触及全球数亿人。这些模型经过……
欢迎来到 Google I/O 的 TensorFlow 和 Keras!
机器学习领域正在以前所未有的速度发生着变化。大型语言模型 (LLM) 的兴起激发了全球开发者的想象力,新的生成式 AI 应用正在触及全球数亿人。这些模型经过海量数据集的训练,并用于解决各种任务,从自然语言处理到图像生成。
为所有这些新功能提供支持需要新的模型效率和性能水平,以及对跨越来越多的设备(无论是服务器、网络、移动设备还是其他设备)进行无缝部署的支持。作为全球最大的机器学习社区之一的守护者,TensorFlow 团队一直在思考如何才能更好地为您服务。
为此,本文介绍了今年 TensorFlow 生态系统即将推出的众多改进和新增功能中的几个。让我们深入探讨吧!
我们今天将介绍的新功能
KerasCV 和 KerasNLP 使您能够在短短几行代码中访问经过预训练的、最先进的模型。
DTensor 通过结合不同的并行技术,帮助您扩展模型并高效地训练它们。
借助 JAX2TF,使用 JAX 数值库 编写的模型可以在 TensorFlow 生态系统中使用。
我们还预览了 TF 量化 API,它使您能够在不影响准确性的情况下使模型更具成本效益和资源效率。
KerasCV 和 KerasNLP 是功能强大、模块化的库,使您能够直接访问计算机视觉和自然语言处理领域的最先进技术。
KerasCV + KerasNLP 套件一览。 |
无论您是想对图像进行分类,还是像 Bard 一样从提示自动生成文本,还是其他任何操作,KerasCV 和 KerasNLP 都可以通过短短几行代码轻松实现。而且,由于它是 Keras 的一部分,因此与 TensorFlow 生态系统完全集成。
让我们来看一些图像生成代码。KerasCV 旨在支持多种模型,在本例中,我们将使用扩散模型。尽管底层架构很复杂,但您只需几行代码即可运行它。
from keras_cv.models import ( StableDiffusion, ) model = StableDiffusion( img_width=512, img_height=512, ) |
只需一行导入代码和一行初始化模型代码,您就可以生成全新的图像
images = model.text_to_image( "宇航员的照片" "骑着马", batch_size=3, ) |
KerasCV 生成的骑马的宇航员图像! |
这只是众多示例中的一个。要了解更多信息,请查看我们的 关于 KerasCV 和 KerasNLP 的完整演讲 或在 keras.io/keras_cv 和 keras.io/keras_nlp 查看深度工具包指南。
DTensor 通过为开发人员提供灵活地组合和微调多种并行技术,使模型训练规模更大、性能更高。
传统上,机器学习开发人员通过数据并行来扩展模型,这会将您的数据拆分并将其馈送到水平扩展的模型实例。这会扩展训练,但有一个重要的限制:它要求模型适合单个硬件设备。
随着模型越来越大,适合单个设备不再是保证 - 开发人员需要能够跨硬件设备扩展模型。这就是模型并行发挥作用的地方,它允许将模型拆分为可以在并行训练的碎片。
使用 DTensor,不仅支持数据并行和模型并行,还可以直接将它们组合起来,从而更有效地扩展模型。而且它完全不受加速器限制 - 无论您使用的是 TPU、GPU 还是其他设备。
混合(数据 + 模型)并行,使用 DTensor。 |
让我们来看一个示例。假设您正在构建一个变换器模型,例如 KerasNLP 中提供的开放预训练变换器 (OPT),并使用一些输入数据集进行训练
opt_lm = keras_nlp.models.OPTCasualLM.from_preset("opt_6.7b_en") opt_lm.compile(...) opt_lm.fit(wiki_text_dataset) |
但 OPT 的问题是 - 它非常大。它的变体参数多达 1750 亿个,如果我们尝试使用传统的数据并行,它会直接出错 - 在单个硬件设备中复制太多权重是不合理的。这就是 DTensor 发挥作用的地方。
要使用 DTensor,我们需要定义两件事
首先是网格,您在其中定义 (a) 一组硬件设备和 (b) 一种拓扑,这里指的是批处理和模型维度。
mesh_dims = [("batch", 2), ("model", 4)] mesh = dtensor.create_distributed_mesh(mesh_dims, device_type="GPU") dtensor.initialize_accelerator_system("GPU") |
其次是**布局**,它定义了如何在定义的网格上对张量维度进行分片。通过我们与 Keras 域包的集成,您只需一行代码就可以完成此操作。
layout_map = keras_nlp.models.OPTCausalLM.create_layout_map(mesh)
with layout_map.scope(): opt_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_6.7b_en") opt_lm.compile(...) opt_lm.fit(wiki_text_dataset) |
目前 DTensor 的性能已经与行业基准相当,几乎与 NVIDIA 为 GPU 提供的模型并行性的黄金标准实现 Megatron 相媲美。我们正在努力进一步提高性能,涵盖各种硬件设备。
将来,DTensor 将与关键接口(如tf.distribute
和 Keras)完全集成,无论硬件如何,都只有一个入口点,以及许多其他提高生活质量的功能。如果您想了解更多信息,请查看DTensor 概述或Keras 集成指南!
许多如今家喻户晓的机器学习进步都起源于研究。例如,由 Google AI 创建和发布的 Transformer 架构,支撑了语言模型的巨大进步。
JAX 已成为进行这类探索的可靠工具,但将其转化为生产环境却很困难。为此,我们一直在思考如何更轻松地将研究成果引入 TensorFlow,让基于 JAX 的创新充分利用 TensorFlow 独特的、强大且多样化的生产生态系统。
这就是我们构建 JAX2TF 的原因,这是一个轻量级 API,为从 JAX 生态系统到 TensorFlow 生态系统提供了一条途径。这有很多用途,以下只是一些例子:
- **推理:** 将为 JAX 编写的模型部署到服务器上(使用 TF Serving)或设备上(使用 TFLite)。
- **微调:** 我们可以使用 JAX2TF 将使用 JAX 训练的模型的组件带到 TF 中,并使用您现有的训练数据和设置在 TensorFlow 中继续训练它。
- **融合:** 将使用 JAX 训练的模型部分与使用 TensorFlow 训练的模型部分组合在一起,以实现最大的灵活性。
实现 JAX 和 TensorFlow 之间这种互操作性的关键在于 jax2tf.convert
,它接收在 JAX 之上创建的模型组件(例如您的损失函数、预测函数等),并创建它们作为TensorFlow 函数的等效表示,然后可以导出为TensorFlow SavedModel。
我们为上述示例之一创建了一个代码演练:一个快速的微调设置,使用 JAX 生态系统中的建模库(如Flax和Optax)创建一个简单的模型,并将其带入 TF 以完成训练。请查看这里。
JAX2TF 已经作为 TensorFlow 生态系统中的各种工具的一部分内置在后台。例如,以下代码指南展示了如何从 JAX 转换为 TFLite以用于移动设备,以及从 JAX 转换为 TF.js以用于 Web 部署!
如今,机器学习开发人员在工作中面临着各种现实世界的约束,例如模型的大小或部署的位置。
我们希望开发人员能够使用 TensorFlow 快速调整和适应这些约束,并且在不牺牲模型质量的情况下做到这一点。为此,我们正在构建 TF 量化 API,这是一个用于 TF2 的原生量化工具包,将于 2023 年晚些时候公开发布。
简而言之,量化是一组旨在使模型更快、更小、总体上减少训练和服务所需的资源和基础设施的技术。
量化通过降低模型参数的精度来实现这一点,就像降低像下面爱因斯坦照片这样的图像的像素深度一样。请注意,即使精度降低了,我们仍然可以辨认出关键的细节。
一张爱因斯坦照片的渲染,其比特精度逐渐降低。 |
从高层次上看,这是通过获取起始精度的值范围,并将该范围映射到最终精度的单个桶来实现的。让我们用一个例子来说明这一点。
将浮点数表示量化为 4 位整数。 |
请看 x 轴上的 [0.27, 0.49] 范围:对于 float32,蓝线实际上代表了 7381976 个唯一数字!红线代表该范围的 int4 量化,将所有这些数字压缩到一个桶中:1001(十进制数为 9)。
通过量化来降低精度,我们可以以更高效、压缩的形式存储模型权重。
量化有几种不同的方法。
- **训练后量化 (PTQ):** 在训练后将模型转换为量化模型。这很简单,并且最容易使用,但可能会导致质量略有下降。
- **量化感知训练 (QAT):** 只在正向传递中模拟量化,从而提供最大的灵活性,同时质量损失最小。
- **量化训练:** 在训练过程中量化所有计算。这仍然处于起步阶段,需要更多的测试,但这是一个强大的工具,我们希望确保 TensorFlow 用户能够使用它。
以前,TensorFlow 提供了一些工具供开发人员量化其模型,例如PTQ 的指南和QAT 的指南。然而,这些工具有限——PTQ 依赖于转换为 TFLite 以进行移动部署,而 QAT 则要求您重写模型。
TF 量化 API 与众不同 - 它旨在无论您在何处部署都能正常工作,而且无需您重写一行现有的模型代码。我们以灵活性与保真度为设计理念,让您能够从更小的量化模型中获得好处,同时拥有新的细粒度控制级别,而无需担心它如何与您的堆栈完美契合。
既然您已经看到了这篇博文,那么现在就让我们一睹其风采吧。我们将从 TensorFlow 模型的典型设置开始,只是在 Keras 中添加了一些层。从那里,我们可以加载预定义的量化模式,将其作为配置映射应用于我们的模型。
# 第 1 步:像往常一样定义您的模型。 model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(32, 3, strides=1, padding='same', activation='relu'), … …]) # 第 2 步:使用预定义模式设置量化配置。 scheme = scheme_registry.get_scheme('pixel_8bit_qat') config_map = QuantizationConfigurationMap(scheme) |
但是,如果您需要更多灵活性,TF 量化 API 也将允许您完全自定义量化方式。它内置支持,让您能够管理模式以对每个层、操作或张量应用不同的行为!
# ...或者同样轻松地配置您自己的模式,无论是每层: layer_config = LayerConfiguration( weight_config=..., activation_config=...) config_map.set_config(model.layers[0], layer_config) # 每操作: config_map.set_config(model.layers[0], op_type='matmul', config={ 'a': ..., 'b': ..., 'return': ... }) # 甚至每个张量: _8bit_moving_average = QuantizationConfiguration(...) per_tensor_config = LayerConfiguration( weight_config=..., activation_config=_8bit_moving_average) config_map.set_config(model.layers[0], per_tensor_config) |
有了这些,我们就可以直接应用量化并在量化环境中进行训练或保存。我们的模型仍然与 TF 生态系统的其他部分具有自然兼容性,量化在其中真正发挥作用。
# 现在你可以生成一个量化感知模型了! tf.quantization.apply_quantization_on_model(model, config_map, …) # 从这里,你可以像往常一样进行训练和保存。 with tf.quantization.scope(config_map): model.fit() model.save() # 你也可以导出到 TFLite,无需任何更改! converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() |
我们在 Pixel 7 上使用 MobileNetV2 模型进行了一些测试,发现与非量化基线相比,服务吞吐量提高了 16.7 倍。 这种增益不会对质量造成任何明显的损害:无论是 float32 基线还是 int8 量化模型都报告了 73% 的准确率。
TF 量化 API 尚未公开,但很快就会推出,并且将继续发展以提供更多优势。
今天,我们向您展示了我们一直在努力的一些关键内容,还有更多内容即将推出。
我们迫不及待地想看看您将构建什么,我们始终受到社区持久热情和持续合作的启发。 谢谢大家!
特别感谢 George Necula、Francois Chollet、Jonathan Bischof、Scott Zhu、Martin Gorner、Dong Li、Adam Koch、Bruce Fontaine、Laurence Moroney、Josh Gordon、Lauren Usui 以及众多其他为本文做出贡献的人员。