https://blog.tensorflowcn.cn/2022/09/optimizing-tf-xla-and-jax-for-llm-training-on-nvidia-gpus.html
https://blogger.googleusercontent.com/img/b/R29vZ2xl/AVvXsEhu7jxiuANJm2mKsE_-8WP2eOKqivo7VW7p_Qv7bhW4UqtRjaKJMx9-VbQUsqrPSI3SSB-GxHw1xb3b3qs9su4TaPgvJ9Y9ki5q3MnTFnB1epDGpnCMjozcWZ-z2kJWz5qc4bnxRZKoYIbBI60hourWcCf3aeYQQJxVsAuzzsoFo6eAkxCrD3u4Tqou/s16000/NVIDIA%20H100.jpeg
作者:Douglas Yarrington (Google TPgM)、James Rubin (Google PM)、Neal Vaidya (NVIDIA TME)、Jay Rodge (NVIDIA PMM)
NVIDIA 和 Google 共同宣布了新的里程碑和计划,通过利用
XLA 的强大功能,优化 TensorFlow 和 JAX 在 Ampere 和最近发布的 Hopper GPU 架构上的性能:XLA 是由 Google 构建的高效、灵活且可扩展的 ML 编译器。我们将加深与专注于在当前可用的 A100 GPU 上提供更高性能的专门工程团队之间的合作。NVIDIA 和 Google 还将共同支持最近发布的 H100 GPU 中的独特功能,包括 Transformer Engine,它支持硬件加速的 8 位浮点 (FP8) 数据类型和 transformer 库。
我们宣布了 TensorFlow 性能的提升,XLA 中新的 NVIDIA GPU 特定功能,以及 JAX 首次发布用于多节点、多 GPU 训练的功能,这将显著提高大型语言模型 (LLM) 的训练速度。我们预计 Hopper 架构将特别适合 LLM。
|
NVIDIA H100 Tensor Core GPU |
用于 GPU 的 XLA
Google 在 NVIDIA GPU 上使用 LLM 能够实现高性能,这得益于一项名为 XLA 的重要技术,它支持所有领先的 ML 框架,例如 TensorFlow、JAX 和 PyTorch。在 Google,超过 90% 的 ML 编译(涵盖研究和生产)都依赖于 XLA。这些应用涵盖了各种 ML 使用场景,从 DeepMind 和 Google Research 中的超大规模模型训练,到我们产品中的优化部署,再到 Waymo 的边缘推理。
XLA 的强大功能集加速了大型语言模型的性能,并
解决了当今行业中大多数大型模型挑战。例如,XLA 独有的 SPMD 功能可以自动化大多数将模型分配到多个核心和设备所需的作业,使大型模型训练的扩展性和性能大幅提高。XLA 还可以自动识别并选择最适合目标后端的最佳手写库实现,例如 cuDNN 用于 CUDA 芯片组。否则,XLA 可以本地生成优化代码以实现高效执行。
我们一直在与 NVIDIA 合作开发一些令人振奋的功能和集成,这些功能和集成将进一步优化 GPU 上的 LLM。我们最近使全约简等集合操作能够与计算并行运行。这显著减少了客户的端到端延迟。此外,我们启用了对 bfloat16 的支持,这使得计算性能比 32 位浮点提高了 4.5 倍,同时保持了相同的值动态范围。
我们共同的努力意味着 XLA 与 NVIDIA 的 AI 工具深度集成,并且可以更好地利用 NVIDIA 的 AI 硬件优化库。在 2023 年第一季度,我们将发布 XLA-cuDNN 图 API 集成,该集成为客户提供卷积/矩阵乘法运算和 transformer 中的多头注意力机制的优化融合,从而改进内存使用并加快 GPU 内核执行速度。结果,开销显著降低,性能显著提升。
用于 GPU 的 TensorFlow
TensorFlow 最近发布了分布式张量 (DTensor),使张量能够跨设备(如 NVIDIA GPU)存储,同时允许程序无缝地操作它们。DTensor 的目标是简化、易于理解并快速地在多个设备上并行化大型 TensorFlow 模型。DTensor 是本地 TensorFlow 张量的直接替代品,并且可以很好地扩展到大型集群。此外,DTensor 项目改进了底层的 TensorFlow 执行和通信原语,并且可以
立即使用!
我们还在与 NVIDIA 合作开发 TensorFlow 中的一些令人振奋的新功能,这些功能利用了 GPU,包括支持新的 FP8 数据类型,这应该在使用 Hopper H100 GPU 时显著提高 transformer 模型的训练速度。
用于 GPU 的 JAX
Google
致力于为每个开发人员提供专用的工具,用于 ML 工作流程的每个步骤。这包括用于健壮的、可投入生产的模型的 TensorFlow 以及具有
高度优化功能 的 JAX,用于尖端的科研工作。我们很高兴地宣布 NVIDIA 和 Google 工程团队之间独特的合作关系,以增强 TensorFlow 和 JAX 对大型深度学习模型(如 LLM)的支持。这两个框架都完全支持 NVIDIA A100 GPU,并且将来将支持最近发布的 H100 GPU。
JAX 的主要优势之一是能够轻松实现优异的硬件利用率,在加速器上实现业界领先的 FLOP 性能。通过与 NVIDIA 的合作,我们将这些优势通过 XLA 编译器的一些魔法转移到 GPU 上。具体来说,我们利用 XLA 进行操作符融合,改进
GSPMD 以支持 GPU 上的通用数据和模型并行,并针对跨主机 NVLink 进行优化。
未来计划
NVIDIA 和 Google 对本文中分享的所有进展感到满意,并热切期待社区成员分享他们使用 TensorFlow 和 JAX 的体验,通过利用 XLA 的强大功能来支持 Ampere (A100) 和 Hopper (H100) GPU。
查看
发布说明 以获取更多信息。要保持最新,您可以阅读 TensorFlow
博客,关注
twitter.com/tensorflow,或订阅
youtube.com/tensorflow。如果您构建了一些想分享的内容,请将其提交到我们的社区亮点
goo.gle/TFCS。如需反馈,请在
GitHub 上提交问题或发布到
TensorFlow 论坛。
TensorFlow 也在
NVIDIA GPU 云 (NGC) 中提供,作为一个 docker
容器,其中包含一组经过验证的库,这些库可以启用和优化 GPU 性能,其中
JAX NGC 容器将于今年晚些时候发布。
谢谢!
贡献者: Frederic Bastien (NVIDIA)、Abhishek Ratna (Google)、Sean Lee (NVIDIA)、Nathan Luehr (NVIDIA)、Ayan Moitra (NVIDIA)、Yash Katariya (Google)、Peter Hawkins (Google)、Skye Wanderman-Milne (Google)、David Majnemer (Google)、Stephan Herhut (Google)、George Karpanov (Google)、Mahmoud Soliman (NVIDIA)、Yuan Lin (NVIDIA)、Vartika Singh (NVIDIA)、Vinod Grover (NVIDIA)、Pooya Jannaty (NVIDIA)、Paresh Kharya (NVIDIA)、Santosh Bhavani (NVIDIA)