Evolutionary Boids¶
The Hard-Coded Boids example used hand-designed rules to update agent velocities, but it would also be interesting to learn these behaviours rather than designing them by hand.
In this example we look at how we can use the Evosax and Flax libraries to learn behaviours via neuro-evolution. Esquilax provides several utilities to help implement this training process.
Agent steering updates will be generated by a small multi-layer network, with the network parameters for each agent taken from a population sampled using neuro-evolution.
In this case each agent will use its own parameter sample from the population (i.e. each agent will have slightly different behaviour), but we could also have the agents share a policy and assess the performance of each set of parameters across the whole flock.
As such, the training loop will broadly follow these steps (a schematic sketch follows the list):
Sample a population of parameters from the evolutionary strategy
Randomly initialise a new simulation
Run the simulation using the sampled population as agent parameters
Use the rewards collected over the course of the simulation as a fitness measure for the population
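As a loose sketch of this pattern (not the exact Esquilax API; evaluate_population here is a hypothetical stand-in for initialising and running a simulation, and an Evosax-style ask/tell strategy interface is assumed), a single generation might look like:

def evaluate_population(key, population):
    # Hypothetical placeholder: run a randomly initialised simulation with
    # the sampled parameters and return accumulated rewards per sample
    return jnp.zeros(population.shape[0])

def generation(key, strategy, evo_params, evo_state):
    k_ask, k_sim = jax.random.split(key)
    # 1. Sample a population of parameters from the strategy
    population, evo_state = strategy.ask(k_ask, evo_state, evo_params)
    # 2-3. Initialise and run a simulation using the sampled parameters
    rewards = evaluate_population(k_sim, population)
    # 4. Feed the rewards back as fitness to update the strategy state
    evo_state = strategy.tell(population, rewards, evo_state, evo_params)
    return evo_state

In practice this loop is implemented for us by the esquilax.ml.evo training utilities shown below.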
State¶
First we import JAX, Chex, Evosax, Flax, and Esquilax, along with some standard-library helpers used below
from functools import partial
from typing import Callable, Tuple

import chex
import evosax
import flax
import jax
import jax.numpy as jnp

import esquilax
from esquilax.ml import evo
In this case the state of each agent consists of its position, and its velocity in polar co-ordinates (heading and speed)
@chex.dataclass
class Boid:
pos: chex.Array
heading: float
speed: float
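The state of the whole flock is then stored as arrays with a leading agent dimension, for example (a hypothetical flock of three stationary agents):

boids = Boid(
    pos=jnp.zeros((3, 2)),        # x-y positions in the unit square
    heading=jnp.zeros((3,)),      # headings in radians
    speed=jnp.full((3,), 0.025),  # speeds
)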
and the shared simulation parameters include values that control speed and acceleration limits, along with reward values
@chex.dataclass
class Params:
max_speed: float = 0.05
min_speed: float = 0.025
max_rotate: float = 0.1
max_accelerate: float = 0.005
close_range: float = 0.01
collision_penalty: float = 0.1
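Since this is a dataclass, the default values can be overridden when constructing the parameters, for example:

params = Params(max_speed=0.06, collision_penalty=0.2)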
Updates¶
The observation function counts neighbours within range and calculates their relative headings and positions, aggregating the contributions from all in-range neighbours
@partial(
esquilax.transforms.spatial,
i_range=0.1,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, a: Boid, b: Boid):
dh = esquilax.utils.shortest_vector(
a.heading, b.heading, length=2 * jnp.pi
)
dx = esquilax.utils.shortest_vector(a.pos, b.pos)
return 1, dx, b.speed, dh
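The decorated function can then be mapped over all pairs of agents inside the interaction range by passing the flock state twice (as senders and receivers) along with the agent positions, mirroring the call made in the step method below (params and boids are a Params instance and Boid state as constructed above):

k = jax.random.PRNGKey(0)
# Aggregate neighbour counts, relative positions, speeds and headings
n_nb, x_nb, s_nb, h_nb = observe(k, params, boids, boids, pos=boids.pos)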
The next update then aggregates the observations into an observation array to be passed to the steering neural network
@esquilax.transforms.amap
def flatten_observations(_k: chex.PRNGKey, params: Params, observations):
boid, n_nb, x_nb, s_nb, h_nb = observations
def obs_to_nbs():
_dx_nb = x_nb / n_nb
_s_nb = s_nb / n_nb
_h_nb = h_nb / n_nb
d = jnp.sqrt(jnp.sum(_dx_nb * _dx_nb)) / 0.1
phi = jnp.arctan2(_dx_nb[1], _dx_nb[0]) + jnp.pi
d_phi = esquilax.utils.shortest_vector(
boid.heading, phi, 2 * jnp.pi
) / jnp.pi
dh = _h_nb / jnp.pi
ds = (_s_nb - boid.speed)
ds = ds / (params.max_speed - params.min_speed)
return jnp.array([d, d_phi, dh, ds])
return jax.lax.cond(
n_nb > 0,
obs_to_nbs,
lambda: jnp.array([-1.0, 0.0, 0.0, 0.0]),
)
If a boid has neighbours, this function converts the aggregated observations into the polar co-ordinates of the average position of the local flock relative to the boid, along with the average heading and speed of the local flock relative to the boid's own heading and speed. If there are no neighbours it returns a default value. The result is a size-4 observation vector for each agent.
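For example, following on from the observation call above, this produces an (n_agents, 4) observation array:

obs = flatten_observations(k, params, (boids, n_nb, x_nb, s_nb, h_nb))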
The observations can then be fed to the network using the built-in esquilax.ml.get_actions() function, which maps the observations across the population parameter samples. The output of this function is the steering update for each agent.
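For example, mirroring the call made in the step method below (network.apply and agent_params are the Flax network apply function and the sampled population parameters, as defined in the training section):

actions = esquilax.ml.get_actions(
    network.apply, False, agent_params, obs
)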
The outputs of the network are then converted to updated agent headings and speeds
@esquilax.transforms.amap
def update_velocity(
_k: chex.PRNGKey, params: Params, x: Tuple[chex.Array, Boid]
):
actions, boid = x
rotation = actions[0] * params.max_rotate * jnp.pi
acceleration = actions[1] * params.max_accelerate
new_heading = (boid.heading + rotation) % (2 * jnp.pi)
new_speeds = jnp.clip(
boid.speed + acceleration,
min=params.min_speed,
max=params.max_speed,
)
return new_heading, new_speeds
Finally, the positions of all the boids are updated from their new velocities
@esquilax.transforms.amap
def move(_key: chex.PRNGKey, _params: Params, x):
pos, heading, speed = x
d_pos = jnp.array(
[speed * jnp.cos(heading), speed * jnp.sin(heading)]
)
return (pos + d_pos) % 1.0
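Positions wrap around the unit square. As a small sketch (assuming, as in the step method below, that the arguments are arrays with a leading agent dimension), a single boid near the right-hand edge travelling along the x-axis re-enters on the left:

# A boid at x=0.99 with heading 0 and speed 0.05 wraps around to x=0.04
new_pos = move(
    k,
    params,
    (jnp.array([[0.99, 0.5]]), jnp.zeros((1,)), jnp.full((1,), 0.05)),
)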
We will score agents based on their distance to other agents, assigning a negative score if they get too close, and a reward that decays exponentially as distance increases. We can again use the spatial transformation to calculate the reward contributions
@partial(
esquilax.transforms.spatial,
i_range=0.1,
reduction=jnp.add,
default=0.0,
include_self=False,
)
def reward(_k: chex.PRNGKey, params: Params, a: chex.Array, b: chex.Array):
d = esquilax.utils.shortest_distance(a, b, norm=True)
reward = jax.lax.cond(
d < params.close_range,
lambda _: -params.collision_penalty,
lambda _d: jnp.exp(-50 * _d),
d,
)
return reward
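For intuition, with the default parameters a neighbour closer than close_range contributes a fixed penalty, while more distant neighbours contribute an exponentially decaying reward:

p = Params()
r_close = -p.collision_penalty  # any distance < 0.01 contributes -0.1
r_far = jnp.exp(-50 * 0.05)     # a neighbour at distance 0.05 contributes ~0.082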
Training Environment¶
To use the built-in training functionality we wrap the environment
initialisation and model update in a esquilax.SimEnv
class:
class BoidEnv(esquilax.Sim):
def __init__(
self,
apply_fun: Callable,
n_agents: int,
min_speed: float,
max_speed: float
):
self.apply_fun = apply_fun
self.n_agents = n_agents
self.min_speed = min_speed
self.max_speed = max_speed
def default_params(self) -> Params:
return Params()
def initial_state(
self, k: chex.PRNGKey, params: Params
) -> Boid:
k1, k2, k3 = jax.random.split(k, 3)
return Boid(
pos=jax.random.uniform(k1, (self.n_agents, 2)),
speed=jax.random.uniform(
k2,
(self.n_agents,),
minval=self.min_speed,
maxval=self.max_speed
),
heading=jax.random.uniform(
k3,
(self.n_agents,),
minval=0.0,
maxval=2.0 * jnp.pi
),
)
def step(
self,
_i: int,
k: chex.PRNGKey,
params: Params,
boids: Boid,
*,
agent_params,
) -> Tuple[Boid, evo.TrainingData]:
n_nb, x_nb, s_nb, h_nb = observe(
k, params, boids, boids, pos=boids.pos
)
obs = flatten_observations(
k, params, (boids, n_nb, x_nb, s_nb, h_nb)
)
actions = esquilax.ml.get_actions(
self.apply_fun, False, agent_params, obs
)
headings, speeds = update_velocity(
k, params, (actions, boids)
)
pos = move(k, params, (boids.pos, headings, speeds))
rewards = reward(k, params, pos, pos, pos=pos)
boids = Boid(pos=pos, heading=headings, speed=speeds)
return (
boids,
evo.TrainingData(rewards=rewards, records=pos)
)
Static simulation parameters (in this case the number of agents and the network apply function) can be passed as attributes of the class.
The initial_state method generates random initial positions and velocities for the boids.
The step method combines the simulation updates. The current population of parameter samples is provided via the agent_params keyword argument. The step method should also return an esquilax.ml.evo.TrainingData instance containing the rewards generated at that step and any state data to be recorded over the course of the simulation.
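As a small sketch of how the environment is used (evo.train() below performs these steps internally; network is the Flax module defined in the next section, and the argument values are illustrative):

k = jax.random.PRNGKey(101)
env = BoidEnv(network.apply, 200, min_speed=0.025, max_speed=0.05)
env_params = env.default_params()
boids = env.initial_state(k, env_params)  # randomly initialised flock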
Training¶
We can then run the training loop. First we define a simple network that the agents will use to steer, using Flax
class MLP(flax.linen.Module):
layer_width: int
actions: int
@flax.linen.compact
def __call__(self, x):
x = flax.linen.Dense(features=self.layer_width)(x)
x = flax.linen.tanh(x)
x = flax.linen.Dense(features=self.layer_width)(x)
x = flax.linen.tanh(x)
x = flax.linen.Dense(features=self.actions)(x)
x = flax.linen.tanh(x)
return x
This defines a simple multi-layer network with a tanh output layer, corresponding to desired actions in the [-1, 1] range.
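For example, initialising the network with a dummy size-4 observation (as is done in the training function below) produces a pair of steering actions in the [-1, 1] range:

k = jax.random.PRNGKey(101)
network = MLP(layer_width=16, actions=2)
net_params = network.init(k, jnp.zeros(4))
steering = network.apply(net_params, jnp.zeros(4))  # shape (2,)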
The full training process can then be run using built-in training functionality:
def evo_boids(
env_params: Params,
n_agents: int,
n_generations: int,
n_samples: int,
n_steps: int,
show_progress: bool = True,
strategy=evosax.strategies.OpenES,
layer_width: int = 16,
):
k = jax.random.PRNGKey(101)
network = MLP(layer_width=layer_width, actions=2)
net_params = network.init(k, jnp.zeros(4))
strategy = evo.BasicStrategy(
net_params, strategy, n_agents
)
evo_params = strategy.default_params()
evo_state = strategy.initialize(k, evo_params)
env = BoidEnv(
network.apply,
n_agents,
env_params.min_speed,
env_params.max_speed
)
evo_state, agent_rewards = evo.train(
strategy,
env,
n_generations,
n_steps,
n_samples,
False,
k,
evo_params,
evo_state,
show_progress=show_progress,
env_params=env_params,
)
params, evo_state = strategy.ask(
k, evo_state, evo_params
)
params_shaped = strategy.reshape_params(params)
test_data = evo.test(
params_shaped,
env,
n_samples,
n_steps,
False,
k,
env_params=env_params,
show_progress=show_progress,
)
return evo_state, agent_rewards, test_data.records, test_data.rewards
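The training function can then be called as, for example (the argument values here are arbitrary illustrative choices):

evo_state, agent_rewards, positions, test_rewards = evo_boids(
    Params(),
    n_agents=200,
    n_generations=100,
    n_samples=5,
    n_steps=200,
)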
In this case we first initialise a random key and dummy parameters for the neural network. We then initialise an evolutionary strategy from these parameters using esquilax.ml.evo.BasicStrategy, along with the evolutionary strategy state and the training environment.
We can then use esquilax.ml.evo.train() to generate a trained strategy state and a record of rewards over training, and esquilax.ml.evo.test() to test the trained strategy and to generate trajectories for analysis/visualisation.
Multi-Strategy¶
The above can be (relatively) easily extended to accommodate the training of multiple strategies in the same training loop. Multiple strategies can be passed as a collection, e.g. as a tuple
strategies = (
evo.BasicStrategy(net_params_a, strategy_a, n_agents_a),
evo.BasicStrategy(net_params_b, strategy_b, n_agents_b),
)
or a Flax FrozenDict
strategies = FrozenDict(
a=evo.BasicStrategy(net_params_a, strategy_a, n_agents_a),
b=evo.BasicStrategy(net_params_b, strategy_b, n_agents_b),
)
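where FrozenDict is Flax's immutable dictionary:

from flax.core import FrozenDict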
The strategy parameters and state should then have the same tree structure
evo_params = FrozenDict(
a=strategies["a"].default_params(),
b=strategies["b"].default_params(),
)
evo_states = FrozenDict(
a=strategies["a"].initialize(k1, evo_params["a"]),
b=strategies["b"].initialize(k2, evo_params["b"]),
)
Finally we should ensure that the step function is updated to accommodate the
tree structure. The agent_params
argument will have the same tree structure,
and the training data returned by the step method should also have this structure,
e.g.:
def step(
self,
_i: int,
k: chex.PRNGKey,
params: Params,
boids: Boid,
*,
agent_params,
) -> Tuple[Boid, evo.TrainingData]:
# agent_params has structure FrozenDict(a=..., b=...)
...
# Then return data with the same structure
training_data = FrozenDict(
a=evo.TrainingData(rewards=rewards_a, records=pos_a),
b=evo.TrainingData(rewards=rewards_b, records=pos_b)
)
return boids, training_data
Esquilax will then handle mapping over the individual strategies during training. Strategies will be updated and queried independently, but can be made to interact via the simulation.