MoveNet 和 TensorFlow.js 的下一代姿态检测
2021 年 5 月 17 日

Ronny VotelNa Li,Google Research 发表

今天我们很高兴发布我们最新的姿态检测模型 MoveNet,以及我们新的姿态检测 API 在 TensorFlow.js 中。MoveNet 是一种超快且准确的模型,可以检测人体 17 个关键点。该模型在 TF Hub 上提供,有两种变体,分别称为 闪电雷电。闪电适用于对延迟要求严格的应用程序,而雷电适用于需要高精度的应用程序。这两个模型在大多数现代台式机、笔记本电脑和手机上都比实时运行得更快(30+ FPS),这对于健身、运动和健康方面的实时应用程序至关重要。这是通过在浏览器中使用 TensorFlow.js 在客户端完全运行模型来实现的,初始页面加载后无需进行服务器调用,也不需要安装任何依赖项。

试用实时演示!

MoveNet can track keypoints through fast motions and atypical poses.
MoveNet 可以通过快速运动和非典型姿势跟踪关键点。

人类姿态估计在过去五年中取得了长足的进步,但令人惊讶的是,它尚未在许多应用程序中出现。这是因为人们更多地关注使姿态模型更大更准确,而不是进行工程工作以使其变得快速并在任何地方部署。对于 MoveNet,我们的目标是设计和优化一个模型,它利用了最先进架构的最佳方面,同时将推理时间尽可能降低。结果是,该模型可以跨各种姿势、环境和硬件设置提供准确的关键点。

使用 MoveNet 打破实时健康应用程序的局限

我们与 IncludeHealth 合作,这是一家数字健康和性能公司,以了解 MoveNet 是否可以帮助为患者解锁远程护理。IncludeHealth 开发了一个交互式网络应用程序,引导患者在家中舒适的环境中(使用手机、平板电脑或笔记本电脑)完成各种例行程序。这些例行程序由物理治疗师数字构建和开具,以测试平衡、力量和活动范围。

该服务需要基于网络并在本地运行的姿态模型,以实现隐私,这些模型可以以高帧速率提供精确的关键点,然后用来量化和定性人类姿势和运动。虽然典型的现成检测器足以用于简单的运动,例如 肩部外展和内收全身深蹲,但对于更复杂的姿势,例如 坐姿膝盖伸展 或仰卧姿势(仰卧),即使是针对错误数据进行训练的最先进检测器也会感到头疼。

Comparison of a traditional detector (top) vs MoveNet (bottom) on difficult poses.
传统检测器(顶部)与 MoveNet(底部)在困难姿势上的比较。

我们向 IncludeHealth 提供了 MoveNet 的早期版本,可通过新的 姿态检测 API 访问。该模型是在健身、舞蹈和瑜伽姿势上进行训练的(有关训练数据集的更多详细信息,请见下文)。IncludeHealth 将该模型集成到他们的应用程序中,并对 MoveNet 相对于其他可用姿态检测器进行了基准测试

“MoveNet 模型融合了提供处方护理所需的强大速度和精度。虽然其他模型牺牲了其中一个来换取另一个,但这种独特的平衡开启了下一代护理交付。谷歌团队一直是这项工作的出色合作者。” - Ryan Eder,IncludeHealth 创始人兼首席执行官。

下一步,IncludeHealth 将与医院系统、保险计划和军队合作,使传统护理和培训能够扩展到实体店之外。

IncludeHealth demo application running in browser that quantifies balance and motion using keypoint estimation powered by MoveNet and TensorFlow.js
IncludeHealth 演示应用程序在浏览器中运行,使用由 MoveNet 和 TensorFlow.js 提供支持的关键点估计来量化平衡和运动

安装

有两种方法可以使用新的姿态检测 api 与 MoveNet

  1. 通过 NPM
    import * as poseDetection from '@tensorflow-models/pose-detection';
  2. 通过脚本标签
    <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs-core"></script>
    <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs-converter"></script>
    <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow/tfjs-backend-webgl"></script>
    <script src="https://cdn.jsdelivr.net.cn/npm/@tensorflow-models/pose-detection"></script>

自己试试吧!

安装完包后,您只需要按照以下几个步骤就可以开始使用它了

// Create a detector.
const detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet);

检测器默认使用闪电版;要选择雷电版,请按如下方式创建检测器

// Create a detector.
const detector = await poseDetection.createDetector(poseDetection.SupportedModels.MoveNet, {modelType: poseDetection.movenet.modelType.SINGLEPOSE_THUNDER});
// Pass in a video stream to the model to detect poses.
const video = document.getElementById('video');
const poses = await detector.estimatePoses(video);

每个姿势包含 17 个关键点,具有绝对 x、y 坐标、置信度分数和名称

console.log(poses[0].keypoints);
// Outputs:
// [
//    {x: 230, y: 220, score: 0.9, name: "nose"},
//    {x: 212, y: 190, score: 0.8, name: "left_eye"},
//    ...
// ]

有关 API 的更多详细信息,请参阅我们的 README

当您开始使用 MoveNet 进行游戏和开发时,我们感谢您的 反馈贡献。如果您使用此模型制作了一些东西,请在社交媒体上使用 #MadeWithTFJS 对其进行标记,以便我们找到您的作品,我们很乐意看到您创作的作品。

MoveNet 深入研究

MoveNet 架构

MoveNet 是一种 自下而上的估计模型,使用热图来准确地定位人体关键点。该架构包括两个部分:一个 特征提取器 和一组 预测头。预测方案大体上遵循 CenterNet,但进行了一些显著的更改,既提高了速度也提高了精度。所有模型都是使用 TensorFlow 对象检测 API 训练的。

MoveNet 中的特征提取器是 MobileNetV2,附加了一个 特征金字塔网络 (FPN),它允许输出高分辨率(输出步幅 4)、语义丰富的特征图。特征提取器附加了四个预测头,负责密集地预测

  • 人中心热图:预测人实例的几何中心
  • 关键点回归场:预测人的完整关键点集,用于将关键点分组为实例
  • 人关键点热图:预测所有关键点的坐标,独立于人实例
  • 每个关键点的 2D 偏移场:预测从每个输出特征图像素到每个关键点的精确亚像素位置的局部偏移
MoveNet architecture
MoveNet 架构

虽然这些预测是并行计算的,但可以通过考虑以下操作顺序来深入了解模型的操作

步骤 1:人中心热图用于识别帧中所有个人的中心,定义为属于该人的所有关键点的算术平均值。选择得分最高(按到帧中心的距离的倒数加权)的位置。

步骤 2:通过从对应于对象中心的像素切片关键点回归输出,生成该人的初始关键点集。由于这是一个从中心向外的预测——必须在不同的尺度上进行操作——因此回归的关键点的质量不会非常准确。

步骤 3:关键点热图中的每个像素都乘以一个权重,该权重与从对应回归的关键点到该像素的距离成反比。这确保我们不接受来自背景人员的关键点,因为它们通常不会在回归的关键点附近,因此得分会很低。

步骤 4:通过检索每个关键点通道中最大热图值对应的坐标,选择最终的关键点预测集。然后将局部 2D 偏移预测添加到这些坐标中,以获得细化的估计值。请参阅下图,它说明了这四个步骤。

MoveNet post-processing steps
MoveNet 后处理步骤。

训练数据集

MoveNet 在两个数据集上进行了训练:COCO 和一个称为 Active 的内部 Google 数据集。虽然 COCO 是检测的标准基准数据集——由于其场景和比例的多样性——但它不适用于健身和舞蹈应用,因为这些应用表现出具有挑战性的姿势和明显的运动模糊。Active 是通过对来自 YouTube 的瑜伽、健身和舞蹈视频中的关键点进行标记(采用 COCO 标准的 17 个身体关键点)来生成的。从每个视频中选择不超过三个帧进行训练,以促进场景和个人的多样性。

在 Active 验证数据集上的评估表明,相对于仅使用 COCO 训练的相同架构,性能有了显著提高。这并不奇怪,因为 COCO 很少表现出具有极端姿势(例如瑜伽、俯卧撑、倒立等等)的个人。

要详细了解数据集以及 MoveNet 在不同类别中的表现,请参阅 模型卡

Images from Active keypoint dataset.
来自 Active 关键点数据集的图像。

优化

虽然在架构设计、后处理逻辑和数据选择方面付出了很多努力,使 MoveNet 成为一个高质量的检测器,但同样重视推理速度。首先,MobileNetV2 中的瓶颈层被选中用于 FPN 中的横向连接。同样,每个预测头中的卷积滤波器的数量也大大减少,以加快对输出特征图的执行速度。除了第一个 MobileNetV2 层之外,整个网络都使用深度可分离卷积。

MoveNet 反复进行分析,发现并删除了特别慢的操作。例如,我们用 tf.math.argmax 替换了 tf.math.top_k,因为它执行速度明显更快,并且适用于单人设置。

为了确保 TensorFlow.js 的快速执行,所有模型输出都打包到单个输出张量中,因此从 GPU 到 CPU 只有一个下载操作。

也许最显著的加速是模型使用 192x192 的输入(Thunder 使用 256x256)。为了弥补较低的分辨率,我们根据上一帧的检测结果应用智能裁剪。这使得模型能够将注意力和资源集中在主要主体上,而不是背景上。

时间滤波

在高 FPS 相机流上操作提供了对关键点估计应用平滑的便利。Lightning 和 Thunder 都对传入的关键点预测流应用了强大的非线性滤波器。该滤波器经过调整,可以同时抑制高频噪声(即抖动)和模型中的异常值,同时在快速运动期间保持高带宽吞吐量。这导致在所有情况下都具有最小延迟的平滑关键点可视化。

MoveNet 浏览器性能

为了量化 MoveNet 的推理速度,模型在多个设备上进行了基准测试。模型延迟(以 FPS 表示)在具有 WebGL 的 GPU 上以及 WebAssembly (WASM) 上进行了测量,WebAssembly (WASM) 是低端或没有 GPU 的设备的典型后端。


MacBook Pro 15” 2019. 

Intel 酷睿 i9. 

AMD Radeon Pro Vega 20 图形。

(FPS)

iPhone 12

(FPS)

Pixel 5

(FPS)

台式机 

Intel i9-10900K. Nvidia GTX 1070 GPU。

(FPS)

WebGL

104  |  77

51  |  43

34  |  12

87  |  82

WASM 

使用 SIMD + 多线程

42  |  21

N/A

N/A

71  |  30

MoveNet 在不同设备和 TF.js 后端上的推理速度。每个单元格中的第一个数字代表 Lightning,第二个数字代表 Thunder。

TF.js 不断优化其后端,以加速所有支持设备上的模型执行。我们在这里应用了一些技术来帮助模型实现这种性能,例如为深度可分离卷积实现 打包 WebGL 内核 以及改进移动 Chrome 的 GL 调度。

要查看模型在您设备上的 FPS,请 尝试我们的演示。您可以在演示 UI 中实时切换模型类型和后端,以查看最适合您设备的内容。

展望未来

下一步是将 Lightning 和 Thunder 模型扩展到多人领域,以便开发人员可以支持在相机视野中有多个人的应用程序。

我们还计划加快 TensorFlow.js 后端,使模型执行更快。这可以通过重复的基准测试和后端优化来实现。

鸣谢

我们要感谢 MoveNet 的其他贡献者:陈宇辉Ard OerlemansFrancois BellettiAndrew BunnerVijay Sundaram,以及参与 TensorFlow.js 姿势检测 API 的人员:余平Sandeep GuptaJason MayesMasoud Charkhabi

下一篇文章
Next-Generation Pose Detection with MoveNet and TensorFlow.js

作者:Ronny Votel李娜,Google Research 我们很高兴在 TensorFlow.js 中推出最新的姿势检测模型 MoveNet,以及我们新的姿势检测 API。MoveNet 是一款超快速且准确的模型,可以检测身体的 17 个关键点。该模型在 TF Hub 上提供,有两个变体,称为 LightningThunder。Lightning 适用于延迟敏感的应用程序,而 Th…