RL Boids
Esquilax simulations can be used directly as reinforcement learning environments, which allows policies to be trained across batches of agents, and with multiple policies.
Environment
We’ll use the same updates and states from the Evolutionary Boids example, but wrap them up in an environment class:
from typing import Tuple

import chex
import esquilax
import jax
import jax.numpy as jnp

# Params, Boid, update_velocity, move, reward, observe and
# flatten_observations are reused from the Evolutionary Boids example


class BoidEnv(esquilax.ml.rl.Environment):
    def __init__(self, n_agents: int):
        self.n_agents = n_agents

    @property
    def default_params(self) -> Params:
        return Params()

    def reset(
        self, key: chex.PRNGKey, params: Params
    ) -> Tuple[chex.Array, Boid]:
        # Sample random initial positions, speeds and headings
        k1, k2, k3 = jax.random.split(key, 3)
        boids = Boid(
            pos=jax.random.uniform(k1, (self.n_agents, 2)),
            speed=jax.random.uniform(
                k2,
                (self.n_agents,),
                minval=params.min_speed,
                maxval=params.max_speed,
            ),
            heading=jax.random.uniform(
                k3, (self.n_agents,),
                minval=0.0, maxval=2.0 * jnp.pi
            ),
        )
        obs = self.get_obs(boids, params=params)
        return obs, boids

    def step(
        self,
        _key: chex.PRNGKey,
        params: Params,
        state: Boid,
        actions: chex.Array,
    ) -> Tuple[chex.Array, Boid, chex.Array, chex.Array]:
        # Apply actions to steer the boids, then move them
        headings, speeds = update_velocity(
            params, (actions, state)
        )
        pos = move(params, (state.pos, headings, speeds))
        rewards = reward(params, pos, pos, pos=pos)
        boids = Boid(pos=pos, heading=headings, speed=speeds)
        obs = self.get_obs(boids, params=params)
        # Return the updated state so the next step continues from it
        return obs, boids, rewards, False

    def get_obs(
        self, state, params=None
    ) -> chex.Array:
        # Aggregate neighbour data, then flatten it into per-agent observations
        n_nb, x_nb, s_nb, h_nb = observe(
            params, state, state, pos=state.pos
        )
        obs = flatten_observations(
            params, (state, n_nb, x_nb, s_nb, h_nb)
        )
        return obs
This structure is reasonably standard for reinforcement learning environments: a reset method initialises the environment state, and a step method accepts actions and updates the state of the environment accordingly. We’ve also included a convenience observation function that generates a flattened observation from the current environment state.
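To show the reset/step pattern in isolation, here is a minimal sketch that does not use esquilax at all. `ToyState` and `ToyEnv` are hypothetical names, and the dynamics (agents moving at a fixed speed on a wrapped unit square) and the reward are placeholders chosen for illustration, not the boid updates used above.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical state: positions in the unit square and headings in radians
    pos: jax.Array
    heading: jax.Array


class ToyEnv:
    # Hypothetical stand-in that mirrors the reset/step/get_obs layout above
    def __init__(self, n_agents: int, speed: float = 0.01):
        self.n_agents = n_agents
        self.speed = speed

    def get_obs(self, state: ToyState) -> jax.Array:
        # Flattened per-agent observation: (x, y, cos(heading), sin(heading))
        cos_h = jnp.cos(state.heading)[:, None]
        sin_h = jnp.sin(state.heading)[:, None]
        return jnp.concatenate([state.pos, cos_h, sin_h], axis=1)

    def reset(self, key: jax.Array) -> Tuple[jax.Array, ToyState]:
        k1, k2 = jax.random.split(key)
        state = ToyState(
            pos=jax.random.uniform(k1, (self.n_agents, 2)),
            heading=jax.random.uniform(
                k2, (self.n_agents,), minval=0.0, maxval=2.0 * jnp.pi
            ),
        )
        return self.get_obs(state), state

    def step(self, key, state, actions):
        # Actions are steering increments; positions wrap around the unit square
        heading = (state.heading + actions) % (2.0 * jnp.pi)
        delta = self.speed * jnp.stack(
            [jnp.cos(heading), jnp.sin(heading)], axis=1
        )
        new_state = ToyState(pos=(state.pos + delta) % 1.0, heading=heading)
        # Placeholder reward: agents are rewarded for being near the centre
        rewards = -jnp.linalg.norm(new_state.pos - 0.5, axis=1)
        return self.get_obs(new_state), new_state, rewards, False


env = ToyEnv(n_agents=5)
obs, state = env.reset(jax.random.PRNGKey(0))
actions = jnp.zeros((5,))
obs, state, rewards, done = env.step(jax.random.PRNGKey(1), state, actions)
```

Note that `step` returns the new state rather than mutating the environment, so the caller threads the state through successive calls. This keeps the functions pure, which is what allows them to be compiled and vectorised by JAX.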
RL Agent
We also define the RL agent. In this case the boid agents will share a single policy (though we could also initialise individual policies). We implement the shared policy agent class esquilax.ml.rl.SharedPolicyAgent:
Note
We’ll not implement the full RL agent functionality here (for brevity). The agent can be used to implement specific RL algorithms.
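from esquilax import ml


class RLAgent(ml.rl.Agent):
    def sample_actions(
        self, key, agent_state, observations, greedy=False,
    ):
        # Apply the shared network to the observations of all agents
        actions = agent_state.apply(observations)
        return actions, None

    def update(
        self, key, agent_state, trajectories,
    ):
        # Placeholder: no learning step, the agent state is returned unchanged
        return agent_state, -1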
The sample_actions function generates actions from observations; in this case we simply apply the agent network across the set of observations.
The update function should update the parameters and optimiser state of the agent, given trajectories collected over the course of training. Here it is a placeholder that returns the agent state unchanged.
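As a rough illustration of these two roles, the sketch below applies one set of network parameters to every agent's observation and performs a single gradient-descent step. It is written in plain JAX and is not the esquilax agent API: `init_params`, `policy` and the mean-squared-error loss against hypothetical target actions are assumptions made for illustration, standing in for whatever RL objective a real agent would use.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_obs=4, width=16, n_actions=2):
    # Hypothetical two-layer network parameters stored as a dict
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (n_obs, width)),
        "b1": jnp.zeros((width,)),
        "w2": 0.1 * jax.random.normal(k2, (width, n_actions)),
        "b2": jnp.zeros((n_actions,)),
    }


def policy(params, obs):
    # Maps a single agent's observation to actions in [-1, 1]
    h = jnp.tanh(obs @ params["w1"] + params["b1"])
    return jnp.tanh(h @ params["w2"] + params["b2"])


def sample_actions(params, observations):
    # Shared policy: the same parameters are applied to every agent's observation
    return jax.vmap(policy, in_axes=(None, 0))(params, observations)


def update(params, observations, targets, lr=0.1):
    # Placeholder update: one gradient-descent step on a mean-squared error
    # between the policy output and hypothetical target actions
    def loss_fn(p):
        return jnp.mean((sample_actions(p, observations) - targets) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, params, grads
    )
    return new_params, loss


params = init_params(jax.random.PRNGKey(0))
observations = jnp.ones((10, 4))  # 10 agents, 4 observation values each
actions = sample_actions(params, observations)
targets = jnp.zeros_like(actions)
new_params, loss = update(params, observations, targets)
```

Mapping over the observation axis only (`in_axes=(None, 0)`) is what makes the policy shared. Giving each agent its own policy would instead mean stacking the parameters along a leading axis and mapping over both arguments.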
Training
We can then run the training loop:
import jax
import optax
from esquilax import ml

# MLP is assumed to be a network module defined elsewhere
# (it is not defined in this section)


def rl_boids(
    env_params: Params,
    n_agents: int,
    n_epochs: int,
    n_env: int,
    n_steps: int,
    layer_width: int = 16,
    show_progress: bool = True,
):
    k = jax.random.PRNGKey(451)
    k_init, k_train = jax.random.split(k)

    env = BoidEnv(n_agents)

    # Policy network with two action outputs, and its optimiser
    network = MLP(layer_width=layer_width, actions=2)
    opt = optax.adam(1e-4)

    agent = RLAgent()
    # (4,) is the shape of a single agent's observation
    agent_state = ml.rl.AgentState.init_from_model(
        k_init, network, opt, (4,)
    )

    trained_agents, rewards, _ = ml.rl.train(
        k_train,
        agent,
        agent_state,
        env,
        env_params,
        n_epochs,
        n_env,
        n_steps,
        show_progress=show_progress,
    )

    return trained_agents, rewards
We initialise the environment, and the RL agent state from the neural network and optimiser. We can then run the training loop using the built-in esquilax.ml.rl.train() function.
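The n_env and n_steps arguments suggest that training collects experience from several environments over a fixed number of steps. The internals of esquilax.ml.rl.train() are not shown here, so the following is only a sketch of how such a batched rollout could be written in plain JAX, using jax.vmap over environments and jax.lax.scan over steps. All names (`toy_reset`, `toy_step`, `toy_policy`, `rollout`, `batched_rollout`) and the dynamics are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_reset(key, n_agents):
    # Hypothetical state: one scalar position per agent in [-1, 1]
    return jax.random.uniform(key, (n_agents,), minval=-1.0, maxval=1.0)


def toy_step(state, actions):
    # Placeholder dynamics and reward: move by the action, reward closeness to 0
    new_state = state + 0.1 * actions
    return new_state, -jnp.abs(new_state)


def toy_policy(weight, obs):
    # Hypothetical one-parameter policy shared by all agents
    return jnp.tanh(weight * obs)


def rollout(key, weight, n_agents, n_steps):
    # Roll a single environment forward for n_steps, recording rewards
    state = toy_reset(key, n_agents)

    def body(state, _):
        actions = toy_policy(weight, state)
        new_state, rewards = toy_step(state, actions)
        return new_state, rewards

    _, rewards = jax.lax.scan(body, state, None, length=n_steps)
    return rewards  # shape (n_steps, n_agents)


def batched_rollout(key, weight, n_env, n_agents, n_steps):
    # One independent random key per environment, vectorised with vmap
    keys = jax.random.split(key, n_env)
    return jax.vmap(lambda k: rollout(k, weight, n_agents, n_steps))(keys)


rewards = batched_rollout(
    jax.random.PRNGKey(451), -1.0, n_env=8, n_agents=5, n_steps=20
)
```

The resulting reward array has one axis each for environments, steps and agents. An update function like the one described above would consume trajectories arranged along these axes.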