使用 JAX 构建强化学习代理,并使用 TensorFlow Lite 部署到 Android
2022 年 10 月 3 日

作者:Wei Wei,开发者倡导者

在我们的上一篇博文中 使用 TensorFlow 构建棋盘游戏应用程序:新的 TensorFlow Lite 参考应用程序,我们向您展示了如何使用 TensorFlow 和 TensorFlow Agents 训练一个强化学习 (RL) 代理来玩一个简单的棋盘游戏“Plane Strike”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到一个功能齐全的 Android 应用程序中。在这篇博文中,我们将演示一条新的路径:使用 Flax/JAX 训练相同的 RL 代理,并将其部署到我们之前构建的同一 Android 应用程序中。完整的代码已在 tensorflow/examples 存储库中开源,供您参考。

为了刷新您的记忆,我们基于 RL 的代理需要根据人类玩家的棋盘位置预测打击位置,以便它可以在人类玩家之前完成游戏。有关更详细的游戏规则,请参阅我们之前的 博客

Demo game play in ‘Plane Strike’
“Plane Strike”中的演示游戏玩法

背景:JAX 和 TensorFlow

JAX 是 Google Research 开发的一个类似 NumPy 的库,用于高性能计算。它使用 XLA 来编译针对 GPU 和 TPU 优化的程序。 Flax 是一个基于 JAX 之上的流行神经网络库。研究人员一直在使用 JAX/Flax 训练具有数十亿个参数的超大型模型(例如用于语言理解和生成的 PaLM,或用于图像生成的 Imagen),充分利用了现代硬件。如果您不熟悉 JAX 和 Flax,请从 JAX 101 教程Flax 入门示例 开始。

TensorFlow 最初作为 ML 库在 2015 年底推出,此后发展成为一个丰富的生态系统,其中包括用于生产化 ML 管道 (TFX)、数据可视化 (TensorBoard)、将 ML 模型部署到边缘设备 (TensorFlow Lite) 以及在 Web 浏览器或任何能够执行 JavaScript 的设备上运行的设备 (TensorFlow.js) 的工具。在 JAX 或 Flax 中开发的模型可以通过首先将此类模型转换为 TensorFlow SavedModel 格式,然后使用与在 TensorFlow 中原生开发相同的工具,来利用这个丰富的生态系统。

如果您已经拥有 JAX 训练的模型并希望立即部署它,我们为您整理了一份资源列表
  • 这篇博文演示了如何将 Flax/JAX 模型转换为 TFLite 并在原生 Android 应用程序中运行它
总的来说,无论您的部署目标是什么(服务器、Web 或移动),我们都能满足您的需求。
使用 Flax/JAX 实现游戏代理

回到我们的棋盘游戏,为了实现我们的 RL 代理,我们将利用与之前相同的 gym 环境。我们将这次使用 Flax/JAX 训练相同的策略梯度模型。回想一下,从数学角度来说,策略梯度定义为
 

其中

  • T:每集的时间步长数,这在每个集之间可能有所不同
  • st:时间步长 t 处的状态
  • at:在给定状态 s 时,在时间步长 t 处选择的动作
  • πθ:由 θ 参数化的策略
  • R(*):给定策略所收集的奖励

我们将一个 3 层 MLP 定义为我们的策略网络,它预测代理的下一个打击位置。

class PolicyGradient(nn.Module):

  """用于预测下一个打击位置的神经网络"""

 

  @nn.compact

  def __call__(self, x):

    dtype = jnp.float32

    x = x.reshape((x.shape[0], -1))

    x = nn.Dense(

        features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(

           x)

    x = nn.relu(x)

    x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)

    x = nn.relu(x)

    x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)

    policy_probabilities = nn.softmax(x)

    return policy_probabilities


在我们的主要训练循环中,在每次迭代中,我们使用神经网络玩一轮游戏,收集轨迹信息(游戏棋盘位置、采取的动作和奖励),对奖励进行折现,然后使用轨迹训练模型。

for i in tqdm(range(iterations)):

   predict_fn = functools.partial(run_inference, params)

   board_log, action_log, result_log = common.play_game(predict_fn)

   rewards = common.compute_rewards(result_log)

   optimizer, params, opt_state = train_step(optimizer, params, opt_state,

                                             board_log, action_log, rewards)


在 train_step() 方法中,我们首先使用轨迹计算损失。然后我们使用 jax.grad() 计算梯度。最后,我们使用 Optax,一个用于 JAX 的梯度处理和优化库,来更新模型参数。

def compute_loss(logits, labels, rewards):

  one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)

  loss = -jnp.mean(

      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))

  return loss

 

 

def train_step(model_optimizer, params, opt_state, game_board_log,

              predicted_action_log, action_result_log):

"""运行一个训练步骤。"""

 

  def loss_fn(model_params):

    logits = run_inference(model_params, game_board_log)

    loss = compute_loss(logits, predicted_action_log, action_result_log)

    return loss

 

  def compute_grads(params):

    return jax.grad(loss_fn)(params)

 

  grads = compute_grads(params)

  updates, opt_state = model_optimizer.update(grads, opt_state)

  params = optax.apply_updates(params, updates)

  return model_optimizer, params, opt_state

 

 

@jax.jit

def run_inference(model_params, board):

  logits = PolicyGradient().apply({'params': model_params}, board)

  return logits


这就是训练循环。我们可以使用 TensorBoard 可视化训练进度,如下所示;这里我们使用代理指标“game_length”(完成游戏的步数)来跟踪进度。直观地,当代理变得更智能时,它可以在更少的步数内完成游戏。

将 Flax/JAX 模型转换为 TensorFlow Lite 并集成到 Android 应用程序中

模型训练完成后,我们使用 jax2tf(一个 TensorFlow-JAX 交互工具)将 JAX 模型转换为 TensorFlow 具体函数。最后一步是调用 TensorFlow Lite 转换器将具体函数转换为 TFLite 模型。

# 转换为 tflite 模型

 model = PolicyGradient()

 jax_predict_fn = lambda input: model.apply({'params': params}, input)

 

 tf_predict = tf.function(

     jax2tf.convert(jax_predict_fn, enable_xla=False),

     input_signature=[

         tf.TensorSpec(

             shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],

             dtype=tf.float32,

             name='input')

     ],

     autograph=False,

 )

 

 converter = tf.lite.TFLiteConverter.from_concrete_functions(

     [tf_predict.get_concrete_function()], tf_predict)

 

 tflite_model = converter.convert()

 

 # 保存模型

 with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:

   f.write(tflite_model)


JAX 转换后的 TFLite 模型的行为与任何 TensorFlow 训练的 TFLite 模型完全一致。您可以使用 Netron 对其进行可视化
Visualizing TFLite model converted from Flax/JAX using Netron
使用 Netron 可视化由 Flax/JAX 转换的 TFLite 模型
我们可以使用与之前完全相同的 Java 代码来调用模型并获取预测结果。

convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}

结论

总之,本文将引导您学习如何使用 Flax/JAX 训练一个简单的强化学习模型,利用 jax2tf 将其转换为 TensorFlow Lite,并将转换后的模型集成到 Android 应用程序中。

现在您已经学习了如何使用 Flax/JAX 构建神经网络模型,并利用强大的 TensorFlow 生态系统将您的模型部署到几乎任何地方。我们迫不及待地想要看到您使用 JAX 和 TensorFlow 构建的精彩应用程序!
下一篇文章
Building a reinforcement learning agent with JAX, and deploying it on Android with TensorFlow Lite

作者:Wei Wei,开发者倡导者在我们之前的博文 使用 TensorFlow 构建棋盘游戏应用程序:新的 TensorFlow Lite 参考应用程序中,我们向您展示了如何使用 TensorFlow 和 TensorFlow Agents 训练一个强化学习 (RL) 智能体来玩简单的棋盘游戏“飞机攻击”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到功能齐全的 Android a…