RL Boids

We can naturally use esquilax simulations as reinforcement learning environments, allowing policies to be trained across batches of agents, and with multiple policies.

Environment

We’ll use the same updates and states from the Evolutionary Boids example, but wrap them up in an environment class:

from typing import Tuple

import chex
import jax
import jax.numpy as jnp

import esquilax

# Params, Boid, and the update functions (update_velocity, move, reward,
# observe, flatten_observations) are those defined in the Evolutionary
# Boids example


class BoidEnv(esquilax.ml.rl.Environment):
    def __init__(self, n_agents: int):
        self.n_agents = n_agents

    @property
    def default_params(self) -> Params:
        return Params()

    def reset(
        self, key: chex.PRNGKey, params: Params
    ) -> Tuple[chex.Array, Boid]:
        k1, k2, k3 = jax.random.split(key, 3)

        # Initialise boids with random positions, speeds, and headings
        boids = Boid(
            pos=jax.random.uniform(k1, (self.n_agents, 2)),
            speed=jax.random.uniform(
                k2,
                (self.n_agents,),
                minval=params.min_speed,
                maxval=params.max_speed,
            ),
            heading=jax.random.uniform(
                k3, (self.n_agents,),
                minval=0.0, maxval=2.0 * jnp.pi
            ),
        )
        obs = self.get_obs(boids, params=params, key=key)
        return obs, boids

    def step(
        self,
        key: chex.PRNGKey,
        params: Params,
        state: Boid,
        actions: chex.Array,
    ) -> Tuple[chex.Array, Boid, chex.Array, chex.Array]:
        # Update agent headings and speeds from the sampled actions
        headings, speeds = update_velocity(
            key, params, (actions, state)
        )
        # Move the agents and calculate rewards at the new positions
        pos = move(key, params, (state.pos, headings, speeds))
        rewards = reward(key, params, pos, pos, pos=pos)
        boids = Boid(pos=pos, heading=headings, speed=speeds)
        obs = self.get_obs(boids, params=params, key=key)
        # Return the updated state along with observations and rewards;
        # the final value flags episode termination
        return obs, boids, rewards, False

    def get_obs(
        self, state, params=None, key=None,
    ) -> chex.Array:
        # Gather observations of neighbouring boids
        n_nb, x_nb, s_nb, h_nb = observe(
            key, params, state, state, pos=state.pos
        )
        # Flatten the neighbourhood data into a fixed-size array per agent
        obs = flatten_observations(
            key, params, (state, n_nb, x_nb, s_nb, h_nb)
        )
        return obs

This structure is reasonably standard for reinforcement learning environments, with a method to reset the environment state and a step method that accepts actions and updates the state of the environment accordingly. We’ve also included a convenience get_obs method that generates a flattened observation from the current environment state.
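
As a quick illustration, the environment can also be reset and stepped manually; the agent count, random key, and zero-valued actions below are placeholder values:

env = BoidEnv(n_agents=10)
params = env.default_params
key = jax.random.PRNGKey(101)

obs, state = env.reset(key, params)
# Step the environment with arbitrary zero-valued actions for each agent
actions = jnp.zeros((10, 2))
obs, state, rewards, done = env.step(key, params, state, actions)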

RL Agent

We also define the RL agent. In this case the boid agents will share a single policy (though we could also initialise individual policies), so we extend the shared-policy agent class esquilax.ml.rl.SharedPolicyAgent.

Note

For brevity we won’t implement full RL agent functionality here; the agent class can be extended to implement specific RL algorithms.

from esquilax import ml


class RLAgent(ml.rl.SharedPolicyAgent):
    def sample_actions(self, _k, observations):
        # Apply the shared policy network across the batch of observations
        actions = ml.get_actions(
            self.apply_fn, True, self.params, observations
        )
        return actions, None

    def update(self, _k, trajectories):
        # Placeholder update: return the agent unchanged along with a
        # dummy value in place of training statistics
        return self, -1

The sample_actions function generates actions given observations; in this case we simply apply the agent network across the set of observations.
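
For example, given an initialised agent (as in the training function below), actions for a batch of observations could be sampled with something like the following; the batch size is arbitrary and the observation width of 4 matches the shape passed to RLAgent.init later:

k = jax.random.PRNGKey(0)
# A dummy batch of observations, one row per agent
observations = jnp.zeros((10, 4))
actions, _ = agents.sample_actions(k, observations)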

The update function should update the parameters and optimiser of the agent, given trajectories collected over the course of training (here it simply returns the agent unchanged).
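
A concrete implementation would typically compute gradients of a policy loss from the collected trajectories and apply them with the optimiser. The following is a rough, self-contained sketch of that pattern using jax and optax directly; the loss function, trajectory fields, and parameter handling are placeholders rather than the esquilax API:

import jax
import optax


def apply_policy_update(
    params, opt_state, optimiser, loss_fn, observations, returns
):
    # Compute gradients of the policy loss w.r.t. the network parameters
    grads = jax.grad(loss_fn)(params, observations, returns)
    # Transform the gradients and apply the resulting parameter updates
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state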

Training

We can then run the training loop:

import optax

# MLP is the policy network definition from the Evolutionary Boids example


def rl_boids(
    env_params: Params,
    n_agents: int,
    n_epochs: int,
    n_env: int,
    n_steps: int,
    layer_width: int = 16,
    show_progress: bool = True,
):
    k = jax.random.PRNGKey(451)
    k_init, k_train = jax.random.split(k)

    env = BoidEnv(n_agents)

    # Initialise the policy network, optimiser, and shared-policy agent
    network = MLP(layer_width=layer_width, actions=2)
    opt = optax.adam(1e-4)
    agents = RLAgent.init(k_init, network, opt, (4,))

    # Run the built-in RL training loop across batched environments
    trained_agents, rewards, _ = ml.rl.train(
        k_train,
        agents,
        env,
        env_params,
        n_epochs,
        n_env,
        n_steps,
        show_progress=show_progress,
    )

    return trained_agents, rewards

We initialise the environment and the RL agent from the neural network, then run the training loop using the built-in esquilax.ml.rl.train() function.
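
For example, training could then be run with something like the following, where the argument values are purely illustrative:

env_params = Params()
trained_agents, rewards = rl_boids(
    env_params, n_agents=50, n_epochs=10, n_env=4, n_steps=200
)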