2022 年 10 月 3 日 — 作者:Wei Wei,开发者倡导者在我们的上一篇博文中 使用 TensorFlow 构建棋盘游戏应用程序:新的 TensorFlow Lite 参考应用程序,我们向您展示了如何使用 TensorFlow 和 TensorFlow Agents 训练一个强化学习 (RL) 代理来玩一个简单的棋盘游戏“Plane Strike”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到一个功能齐全的 Android 应用程序中。在这篇博文中,我们将演示一条新的路径:使用 Flax/JAX 训练相同的 RL 代理,并将其部署到我们之前构建的同一 Android 应用程序中。完整的代码已在 tensorflow/examples 存储库中开源,供您参考。
作者:Wei Wei,开发者倡导者
在我们的上一篇博文中 使用 TensorFlow 构建棋盘游戏应用程序:新的 TensorFlow Lite 参考应用程序,我们向您展示了如何使用 TensorFlow 和 TensorFlow Agents 训练一个强化学习 (RL) 代理来玩一个简单的棋盘游戏“Plane Strike”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到一个功能齐全的 Android 应用程序中。在这篇博文中,我们将演示一条新的路径:使用 Flax/JAX 训练相同的 RL 代理,并将其部署到我们之前构建的同一 Android 应用程序中。完整的代码已在 tensorflow/examples 存储库中开源,供您参考。
为了刷新您的记忆,我们基于 RL 的代理需要根据人类玩家的棋盘位置预测打击位置,以便它可以在人类玩家之前完成游戏。有关更详细的游戏规则,请参阅我们之前的 博客。
“Plane Strike”中的演示游戏玩法 |
其中
我们将一个 3 层 MLP 定义为我们的策略网络,它预测代理的下一个打击位置。
class PolicyGradient(nn.Module): """用于预测下一个打击位置的神经网络"""
@nn.compact def __call__(self, x): dtype = jnp.float32 x = x.reshape((x.shape[0], -1)) x = nn.Dense( features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)( x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x) policy_probabilities = nn.softmax(x) return policy_probabilities |
for i in tqdm(range(iterations)): predict_fn = functools.partial(run_inference, params) board_log, action_log, result_log = common.play_game(predict_fn) rewards = common.compute_rewards(result_log) optimizer, params, opt_state = train_step(optimizer, params, opt_state, board_log, action_log, rewards) |
def compute_loss(logits, labels, rewards): one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2) loss = -jnp.mean( jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards)) return loss
def train_step(model_optimizer, params, opt_state, game_board_log, predicted_action_log, action_result_log): """运行一个训练步骤。"""
def loss_fn(model_params): logits = run_inference(model_params, game_board_log) loss = compute_loss(logits, predicted_action_log, action_result_log) return loss
def compute_grads(params): return jax.grad(loss_fn)(params)
grads = compute_grads(params) updates, opt_state = model_optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) return model_optimizer, params, opt_state
@jax.jit def run_inference(model_params, board): logits = PolicyGradient().apply({'params': model_params}, board) return logits |
# 转换为 tflite 模型 model = PolicyGradient() jax_predict_fn = lambda input: model.apply({'params': params}, input)
tf_predict = tf.function( jax2tf.convert(jax_predict_fn, enable_xla=False), input_signature=[ tf.TensorSpec( shape=[1, common.BOARD_SIZE, common.BOARD_SIZE], dtype=tf.float32, name='input') ], autograph=False, )
converter = tf.lite.TFLiteConverter.from_concrete_functions( [tf_predict.get_concrete_function()], tf_predict)
tflite_model = converter.convert()
# 保存模型 with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f: f.write(tflite_model) |
使用 Netron 可视化由 Flax/JAX 转换的 TFLite 模型 |
convertBoardStateToByteBuffer(board); |
2022 年 10 月 3 日 — 作者:Wei Wei,开发者倡导者在我们之前的博文 使用 TensorFlow 构建棋盘游戏应用程序:新的 TensorFlow Lite 参考应用程序中,我们向您展示了如何使用 TensorFlow 和 TensorFlow Agents 训练一个强化学习 (RL) 智能体来玩简单的棋盘游戏“飞机攻击”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到功能齐全的 Android a…