使用 TensorFlow.js 预测球和好球
2018 年 7 月 25 日
作者:Nick Kreeger

在这篇文章中,我们将使用 TensorFlow.jsD3.js 和 Web 的力量来可视化训练模型的过程,该模型可以从棒球数据中预测球(蓝色区域)和好球(橙色区域)。在过程中,我们将可视化模型在整个训练过程中理解的打击区。您可以在浏览器中访问 此 Observable 笔记本 来运行此模型。

如果您不熟悉棒球中的打击区,这里有一篇 文章 详细介绍。
上面的 GIF 可视化了神经网络学习如何调用球(蓝色区域)和好球(橙色区域)。在每个训练步骤之后,热图都会更新模型的预测。
使用 Observable 直接在您的浏览器中运行此模型。

体育中的高级指标

如今的职业体育环境充斥着大量数据。这些数据正在被球队、爱好者和球迷应用于各种用例。得益于 TensorFlow 等框架,这些数据集已准备好应用机器学习。

MLBAM 的 PITCHf/x

美国职业棒球大联盟先进媒体 (MLBAM) 发布了一个 大型数据集,供公众进行研究。该数据集包含过去几年中美国职业棒球大联盟比赛中投出的球的传感器信息。我使用该数据集整理了一个包含 5,000 个样本(2,500 个球和 2,500 个好球)的 训练集

以下是训练数据中前几个字段的示例

以下是训练数据在打击区绘制时的样子。蓝色点标记为球,橙色点标记为好球(由美国职业棒球大联盟裁判判定)

使用 TensorFlow.js 构建模型

TensorFlow.js 将机器学习带到了 JavaScript 和 Web。我们将使用这个很棒的框架来构建一个深度神经网络模型。该模型将能够像美国职业棒球大联盟裁判一样精确地调用球和好球。

输入

此模型在 PITCHf/x 中的以下字段上进行训练
  • 球越过本垒板的位置坐标(‘px’ 和 ‘pz’)。
  • 击球手站在本垒板的哪一边。
  • 打击区高度(击球手的躯干)以英尺为单位。
  • 打击区底部高度(击球手的膝盖)以英尺为单位。
  • 裁判判定的球的实际标签(球或好球)。

架构

此模型将使用 TensorFlow.js Layers API 定义。Layers API 基于 Keras,对于使用过该框架的人来说应该很熟悉
const model = tf.sequential();

// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));

// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));

model.compile({
  optimizer: tf.train.adam(0.01),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

加载和准备数据

整理后的训练集可通过 GitHub gist 获得。需要下载此数据集才能开始将 CSV 数据转换为 TensorFlow.js 用于训练的格式。
const data = [];
csvData.forEach((values) => {
  // 'logit' data uses the 5 fields:
  const x = [];
  x.push(parseFloat(values.px));
  x.push(parseFloat(values.pz));
  x.push(parseFloat(values.sz_top));
  x.push(parseFloat(values.sz_bot));
  x.push(parseFloat(values.left_handed_batter));
  // The label is simply 'is strike' or 'is ball':
  const y = parseInt(values.is_strike, 10);
  data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);
解析 CSV 数据后,需要将 JS 类型转换为张量批次,以便进行训练和评估。有关此过程的更多详细信息,请参阅 代码实验室。TensorFlow.js 团队正在开发一个新的数据 API,以在将来简化此数据摄取过程。

训练模型

让我们将所有这些整合在一起。模型已定义,训练数据已准备就绪,现在我们可以开始训练了。以下异步方法训练一批训练样本并更新热图
// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
  const history = await model.fit(batches[index].x, batches[index].y, {
    epochs: 1,
    shuffle: false,
    validationData: [batches[index].x, batches[index].y],
    batchSize: CONSTANTS.BATCH_SIZE
  });

  // Don't block the UI frame by using tf.nextFrame()
  await tf.nextFrame();
  updateHeatmap();
  await tf.nextFrame();
}

可视化模型的准确性

热图是使用样本 4 英尺 x 4 英尺网格(均匀放置在本垒板上方)的预测矩阵构建的。在每个训练步骤之后,将该矩阵传递到模型中,以检查模型的准确性。预测结果使用 D3 库呈现为热图。

构建预测矩阵

热图中使用的预测矩阵从本垒板的中心开始,向左和向右延伸 2 英尺。它还从本垒板底部延伸到 4 英尺高。示例打击区范围在本垒板上方 1.5 到 3.5 英尺之间。以下视觉效果有助于可视化这些 2D 面板
此视觉效果显示了打击区和预测矩阵与本垒板和比赛场地之间的关系。

将预测矩阵与模型一起使用

在模型训练完一批之后,将预测矩阵传递到模型中,以请求矩阵中的球或好球预测
function predictZone() {
  const predictions = model.predictOnBatch(predictionMatrix.data);
  const values = predictions.dataSync();

  // Sort each value so the higher prediction is the first element in the array:
  const results = [];
  let index = 0;
  for (let i = 0; i < values.length; i++) {
    let list = [];
    list.push({value: values[index++], strike: 0});
    list.push({value: values[index++], strike: 1});
    list = list.sort((a, b) => b.value - a.value);
    results.push(list);
  }
  return results;
}

使用 D3 的热图

现在可以使用 D3 可视化预测结果。50x50 网格中的每个元素将以 SVG 中的 10 像素 x 10 像素矩形呈现。每个矩形的颜色将取决于预测结果(球或好球)以及模型对该结果的确定程度(比例范围为 50%-100%)。以下代码片段显示了如何从 D3 svg 矩形分组中更新数据
function updateHeatmap() {
  rects.data(generateHeatmapData());
  rects
    .attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
    .attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
    .attr('width', CONSTANTS.HEATMAP_SIZE)
    .attr('height', CONSTANTS.HEATMAP_SIZE)
    .style('fill', (coord) => {
      if (coord.strike) {
        return strikeColorScale(coord.value);
      } else {
        return ballColorScale(coord.value);
      }
  });
}
有关使用 D3 绘制热图的完整详细信息,请参见 此部分

总结

如今,Web 有许多用于创建视觉效果的奇妙库和工具。将这些工具与 TensorFlow.js 的机器学习能力相结合,使开发人员能够创建一些真正有趣的演示。

想了解更多信息,请查看以下链接
下一篇文章
Predicting balls and strikes using TensorFlow.js

- 作者:Nick Kreeger

在这篇文章中,我们将使用 TensorFlow.jsD3.js 和 Web 的力量来可视化训练模型的过程,该模型可以从棒球数据中预测球(蓝色区域)和好球(橙色区域)。在过程中,我们将可视化模型在整个训练过程中理解的打击区。您可以在浏览器中访问 此 Observable 笔记本 来运行此模型。

如果您是…