2018 年 7 月 25 日 - 作者:Nick Kreeger
在这篇文章中,我们将使用 TensorFlow.js、D3.js 和 Web 的力量来可视化训练模型的过程,该模型可以从棒球数据中预测球(蓝色区域)和好球(橙色区域)。在过程中,我们将可视化模型在整个训练过程中理解的打击区。您可以在浏览器中访问 此 Observable 笔记本 来运行此模型。
如果您是…
上面的 GIF 可视化了神经网络学习如何调用球(蓝色区域)和好球(橙色区域)。在每个训练步骤之后,热图都会更新模型的预测。 |
const model = tf.sequential();
// Two fully connected layers with dropout between each:
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));
// Only two classes: "strike" and "ball":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
model.compile({
optimizer: tf.train.adam(0.01),
loss: 'categoricalCrossentropy',
metrics: ['accuracy']
});
const data = [];
csvData.forEach((values) => {
// 'logit' data uses the 5 fields:
const x = [];
x.push(parseFloat(values.px));
x.push(parseFloat(values.pz));
x.push(parseFloat(values.sz_top));
x.push(parseFloat(values.sz_bot));
x.push(parseFloat(values.left_handed_batter));
// The label is simply 'is strike' or 'is ball':
const y = parseInt(values.is_strike, 10);
data.push({x: x, y: y});
});
// Shuffle the contents to ensure the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);
解析 CSV 数据后,需要将 JS 类型转换为张量批次,以便进行训练和评估。有关此过程的更多详细信息,请参阅 代码实验室。TensorFlow.js 团队正在开发一个新的数据 API,以在将来简化此数据摄取过程。// Trains and reports loss+accuracy for one batch of training data:
async function trainBatch(index) {
const history = await model.fit(batches[index].x, batches[index].y, {
epochs: 1,
shuffle: false,
validationData: [batches[index].x, batches[index].y],
batchSize: CONSTANTS.BATCH_SIZE
});
// Don't block the UI frame by using tf.nextFrame()
await tf.nextFrame();
updateHeatmap();
await tf.nextFrame();
}
此视觉效果显示了打击区和预测矩阵与本垒板和比赛场地之间的关系。 |
function predictZone() {
const predictions = model.predictOnBatch(predictionMatrix.data);
const values = predictions.dataSync();
// Sort each value so the higher prediction is the first element in the array:
const results = [];
let index = 0;
for (let i = 0; i < values.length; i++) {
let list = [];
list.push({value: values[index++], strike: 0});
list.push({value: values[index++], strike: 1});
list = list.sort((a, b) => b.value - a.value);
results.push(list);
}
return results;
}
function updateHeatmap() {
rects.data(generateHeatmapData());
rects
.attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })
.attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })
.attr('width', CONSTANTS.HEATMAP_SIZE)
.attr('height', CONSTANTS.HEATMAP_SIZE)
.style('fill', (coord) => {
if (coord.strike) {
return strikeColorScale(coord.value);
} else {
return ballColorScale(coord.value);
}
});
}
有关使用 D3 绘制热图的完整详细信息,请参见 此部分。
2018 年 7 月 25 日 - 作者:Nick Kreeger
在这篇文章中,我们将使用 TensorFlow.js、D3.js 和 Web 的力量来可视化训练模型的过程,该模型可以从棒球数据中预测球(蓝色区域)和好球(橙色区域)。在过程中,我们将可视化模型在整个训练过程中理解的打击区。您可以在浏览器中访问 此 Observable 笔记本 来运行此模型。
如果您是…