Evolutionary Boids

The Hard-Coded Boids example used hand-designed rules to update agent velocities, but it would be cool if the agents could learn these behaviours instead.

In this example we look at how we can use the Evosax and Flax libraries to learn behaviours via neuro-evolution. Esquilax provides several utilities to help implement this training process.

Agent steering updates will be generated by a small multi-layer network, with each agent's network parameters taken from a population sampled using neuro-evolution.

In this case each agent will use an individual parameter sample from the population (i.e. each agent will have slightly different behaviour), but we could also have the agents share a policy and assess the performance of each set of parameters across the whole flock.

As such, the training loop will broadly follow these steps (sketched in code below):

  • Sample a population of parameters from the evolutionary strategy

  • Randomly initialise a new simulation

  • Run the simulation using the sampled population as agent parameters

  • Use the rewards collected over the course of the simulation as a fitness measure for the population

  • Update the strategy state using these fitness scores
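
Schematically this is the standard ask/evaluate/tell pattern used by evolutionary strategies. As a rough, self-contained sketch of one version of this loop (assuming the pre-2.0 Evosax interface, and with a placeholder fitness in place of the simulation rewards):

import jax
import jax.numpy as jnp
import evosax

strategy = evosax.OpenES(popsize=16, num_dims=8)
es_params = strategy.default_params
k = jax.random.PRNGKey(0)
state = strategy.initialize(k, es_params)

for _ in range(10):
    k, k_ask = jax.random.split(k)
    # Sample a population of parameters from the strategy
    population, state = strategy.ask(k_ask, state, es_params)
    # Placeholder fitness; in this example it is replaced by rewards
    # collected over the course of a simulation
    fitness = jnp.sum(population**2, axis=1)
    # Update the strategy state from the population fitness
    state = strategy.tell(population, fitness, state, es_params)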

State

First we import JAX, Chex, Evosax, Flax, and Esquilax, along with some standard library utilities used below

from functools import partial
from typing import Callable, Tuple

import chex
import evosax
import flax
import jax
import jax.numpy as jnp

import esquilax
from esquilax.ml import evo

In this case the state of each agent consists of its position, and its velocity stored in polar co-ordinates (a heading and a speed)

@chex.dataclass
class Boid:
    pos: chex.Array
    heading: float
    speed: float
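
Since this is a Chex dataclass, the state of the whole flock can be held in a single Boid instance with array fields (one entry per agent), for example:

n_agents = 10
boids = Boid(
    pos=jnp.zeros((n_agents, 2)),
    heading=jnp.zeros((n_agents,)),
    speed=jnp.full((n_agents,), 0.025),
)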

and the simulation parameters include values controlling speed and acceleration limits, along with reward values

@chex.dataclass
class Params:
    max_speed: float = 0.05
    min_speed: float = 0.025
    max_rotate: float = 0.1
    max_accelerate: float = 0.005
    close_range: float = 0.01
    collision_penalty: float = 0.1
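
These defaults can be overridden per experiment when constructing the parameters, for example:

params = Params(max_speed=0.06, close_range=0.02)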

Updates

The observation function counts neighbours and calculates their relative positions, speeds, and headings; the spatial transform then sums these contributions over neighbours within range

@partial(
    esquilax.transforms.spatial,
    i_range=0.1,
    reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
    default=(0, jnp.zeros(2), 0.0, 0.0),
    include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, a: Boid, b: Boid):
    dh = esquilax.utils.shortest_vector(
        a.heading, b.heading, length=2 * jnp.pi
    )
    dx = esquilax.utils.shortest_vector(a.pos, b.pos)
    return 1, dx, b.speed, dh
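
To illustrate what the spatial transform does with the reduction and default arguments, here is a rough, self-contained sketch of the aggregation for a single agent with two in-range neighbours (using a hypothetical stand-in for the undecorated pairwise function, not Esquilax's actual implementation):

import jax.numpy as jnp

def pairwise(a_heading, a_pos, b_heading, b_speed, b_pos):
    # Stand-in for the undecorated observe: returns a count increment,
    # relative position, neighbour speed, and relative heading
    dh = (b_heading - a_heading + jnp.pi) % (2 * jnp.pi) - jnp.pi
    dx = b_pos - a_pos
    return 1, dx, b_speed, dh

a_heading, a_pos = 0.0, jnp.array([0.5, 0.5])
neighbours = [
    (0.1, 0.03, jnp.array([0.52, 0.50])),
    (6.0, 0.04, jnp.array([0.48, 0.51])),
]

# Start from the `default` values and combine contributions
# element-wise with the `reduction` functions
totals = (0, jnp.zeros(2), 0.0, 0.0)
for b_heading, b_speed, b_pos in neighbours:
    contribution = pairwise(a_heading, a_pos, b_heading, b_speed, b_pos)
    totals = tuple(jnp.add(t, c) for t, c in zip(totals, contribution))

# totals = (neighbour count, summed dx, summed speed, summed relative heading)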

The next update then flattens these aggregated values into an observation array to be passed to the steering neural network

@esquilax.transforms.amap
def flatten_observations(_k: chex.PRNGKey, params: Params, observations):
    boid, n_nb, x_nb, s_nb, h_nb = observations

    def obs_to_nbs():
        _dx_nb = x_nb / n_nb
        _s_nb = s_nb / n_nb
        _h_nb = h_nb / n_nb

        d = jnp.sqrt(jnp.sum(_dx_nb * _dx_nb)) / 0.1
        phi = jnp.arctan2(_dx_nb[1], _dx_nb[0]) + jnp.pi
        d_phi = esquilax.utils.shortest_vector(
            boid.heading, phi, 2 * jnp.pi
        ) / jnp.pi
        dh = _h_nb / jnp.pi
        ds = (_s_nb - boid.speed)
        ds = ds / (params.max_speed - params.min_speed)

        return jnp.array([d, d_phi, dh, ds])

    return jax.lax.cond(
        n_nb > 0,
        obs_to_nbs,
        lambda: jnp.array([-1.0, 0.0, 0.0, 0.0]),
    )

If a boid has neighbours, this function converts the aggregated values into a vector (in polar co-ordinates) to the average position of the local flock, along with the average heading and speed of the local flock relative to the boid's own heading and speed. If there are no neighbours it returns a default value. The result is a size-4 observation vector for each agent.

The observations can be fed to the network using the built-in esquilax.ml.get_actions() function, which maps the observations across the population parameter samples. The output of this function is the steering action for each agent.
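
Conceptually, with one parameter sample per agent this is similar to vmapping the network apply function over both the stacked parameter samples and the observations. A rough, self-contained sketch (using a stand-in Dense layer rather than the steering network defined below):

import jax
import jax.numpy as jnp
import flax.linen as nn

n_agents = 10
net = nn.Dense(features=2)

# One parameter sample per agent, stacked along a leading axis
keys = jax.random.split(jax.random.PRNGKey(0), n_agents)
agent_params = jax.vmap(lambda k: net.init(k, jnp.zeros(4)))(keys)

observations = jnp.zeros((n_agents, 4))

# Apply each agent's own parameters to its own observation
actions = jax.vmap(net.apply)(agent_params, observations)
assert actions.shape == (n_agents, 2)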

The outputs of the network are then converted to updated agent headings and speeds

@esquilax.transforms.amap
def update_velocity(
    _k: chex.PRNGKey, params: Params, x: Tuple[chex.Array, Boid]
):
    actions, boid = x
    rotation = actions[0] * params.max_rotate * jnp.pi
    acceleration = actions[1] * params.max_accelerate

    new_heading = (boid.heading + rotation) % (2 * jnp.pi)
    new_speeds = jnp.clip(
        boid.speed + acceleration,
        min=params.min_speed,
        max=params.max_speed,
    )

    return new_heading, new_speeds
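
For example, with the default parameters above, an action pair of [1.0, -1.0] rotates a boid by max_rotate * pi radians and decelerates it by max_accelerate, clipped to the speed limits:

params = Params()
heading, speed = 0.0, 0.03
actions = jnp.array([1.0, -1.0])

new_heading = (heading + actions[0] * params.max_rotate * jnp.pi) % (2 * jnp.pi)
new_speed = jnp.clip(
    speed + actions[1] * params.max_accelerate,
    params.min_speed,
    params.max_speed,
)
# new_heading ≈ 0.1 * pi, new_speed == 0.025 (the minimum speed)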

Finally, all the boid positions are updated from the new velocities

@esquilax.transforms.amap
def move(_key: chex.PRNGKey, _params: Params, x):
    pos, heading, speed = x
    d_pos = jnp.array(
        [speed * jnp.cos(heading), speed * jnp.sin(heading)]
    )
    return (pos + d_pos) % 1.0
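
Note that positions wrap around the unit square, so the boids effectively move on a torus. For example, starting near the right edge:

pos = jnp.array([0.98, 0.5])
heading, speed = 0.0, 0.05
d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)])
new_pos = (pos + d_pos) % 1.0  # ≈ [0.03, 0.5]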

We will score agents based on their distance to other agents, applying a penalty if they are too close, and otherwise a reward that decays exponentially as distance increases. We can again use the spatial transformation to sum reward contributions

@partial(
    esquilax.transforms.spatial,
    i_range=0.1,
    reduction=jnp.add,
    default=0.0,
    include_self=False,
)
def reward(_k: chex.PRNGKey, params: Params, a: chex.Array, b: chex.Array):
    d = esquilax.utils.shortest_distance(a, b, norm=True)

    reward = jax.lax.cond(
        d < params.close_range,
        lambda _: -params.collision_penalty,
        lambda _d: jnp.exp(-50 * _d),
        d,
    )
    return reward
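
With the default parameters, the pairwise reward as a function of distance can be checked standalone (without the decorated transform):

params = Params()
d = jnp.array([0.005, 0.02, 0.05])
pair_rewards = jnp.where(
    d < params.close_range,
    -params.collision_penalty,
    jnp.exp(-50.0 * d),
)
# ≈ [-0.1, 0.37, 0.08]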

Training Environment

To use the built-in training functionality we wrap the environment initialisation and update logic in a subclass of esquilax.Sim:

class BoidEnv(esquilax.Sim):
    def __init__(
        self,
        apply_fun: Callable,
        n_agents: int,
        min_speed: float,
        max_speed: float
    ):
        self.apply_fun = apply_fun
        self.n_agents = n_agents
        self.min_speed = min_speed
        self.max_speed = max_speed

    def default_params(self) -> Params:
        return Params()

    def initial_state(
        self, k: chex.PRNGKey, params: Params
    ) -> Boid:
        k1, k2, k3 = jax.random.split(k, 3)

        return Boid(
            pos=jax.random.uniform(k1, (self.n_agents, 2)),
            speed=jax.random.uniform(
                k2,
                (self.n_agents,),
                minval=self.min_speed,
                maxval=self.max_speed
            ),
            heading=jax.random.uniform(
                k3,
                (self.n_agents,),
                minval=0.0,
                maxval=2.0 * jnp.pi
            ),
        )

    def step(
        self,
        _i: int,
        k: chex.PRNGKey,
        params: Params,
        boids: Boid,
        *,
        agent_params,
    ) -> Tuple[Boid, evo.TrainingData]:

        n_nb, x_nb, s_nb, h_nb = observe(
            k, params, boids, boids, pos=boids.pos
        )
        obs = flatten_observations(
            k, params, (boids, n_nb, x_nb, s_nb, h_nb)
        )
        actions = esquilax.ml.get_actions(
            self.apply_fun, False, agent_params, obs
        )
        headings, speeds = update_velocity(
            k, params, (actions, boids)
        )
        pos = move(k, params, (boids.pos, headings, speeds))
        rewards = reward(k, params, pos, pos, pos=pos)
        boids = Boid(pos=pos, heading=headings, speed=speeds)
        return (
            boids,
            evo.TrainingData(rewards=rewards, records=pos)
        )

  • Static simulation parameters (in this case the number of agents and the network function) can be passed as attributes of the class.

  • The initialisation method initialises random initial positions and velocities of the boids.

  • The step method combines the simulation updates. The current parameter samples are provided via the keyword argument agent_params, and the method returns an esquilax.ml.evo.TrainingData instance containing the generated rewards and any state data to be recorded (the environment can also be stepped manually, as sketched below).
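
For debugging it can be useful to step the environment manually outside the training loop. A rough sketch, using a stand-in Dense network with vmapped per-agent parameters in place of samples drawn from a strategy (an assumption here is that this stacked parameter structure matches what esquilax.ml.get_actions expects):

import flax.linen as nn

n_agents = 50
net = nn.Dense(features=2)
keys = jax.random.split(jax.random.PRNGKey(0), n_agents)
agent_params = jax.vmap(lambda k: net.init(k, jnp.zeros(4)))(keys)

env = BoidEnv(net.apply, n_agents, min_speed=0.025, max_speed=0.05)
params = env.default_params()
k = jax.random.PRNGKey(1)
boids = env.initial_state(k, params)
boids, training_data = env.step(0, k, params, boids, agent_params=agent_params)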

Training

We can then run the training loop. First we use Flax to define the simple network the agents will use to steer:

class MLP(flax.linen.Module):
    layer_width: int
    actions: int

    @flax.linen.compact
    def __call__(self, x):
        x = flax.linen.Dense(features=self.layer_width)(x)
        x = flax.linen.tanh(x)
        x = flax.linen.Dense(features=self.layer_width)(x)
        x = flax.linen.tanh(x)
        x = flax.linen.Dense(features=self.actions)(x)
        x = flax.linen.tanh(x)

        return x

This is a simple multi-layer network with a tanh output layer, so the actions produced lie in the [-1, 1] range.
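
The network can be initialised with a dummy size-4 observation (matching the observation vectors generated above) and applied to produce a pair of actions:

network = MLP(layer_width=16, actions=2)
net_params = network.init(jax.random.PRNGKey(0), jnp.zeros(4))
actions = network.apply(net_params, jnp.zeros(4))  # shape (2,), values in [-1, 1]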

The full training process can then be run using built-in training functionality:

def evo_boids(
    env_params: Params,
    n_agents: int,
    n_generations: int,
    n_samples: int,
    n_steps: int,
    show_progress: bool = True,
    strategy=evosax.strategies.OpenES,
    layer_width: int = 16,
):
    k = jax.random.PRNGKey(101)

    network = MLP(layer_width=layer_width, actions=2)
    net_params = network.init(k, jnp.zeros(4))

    strategy = evo.BasicStrategy(
        net_params, strategy, n_agents
    )
    evo_params = strategy.default_params()
    evo_state = strategy.initialize(k, evo_params)

    env = BoidEnv(
        network.apply,
        n_agents,
        env_params.min_speed,
        env_params.max_speed
    )

    evo_state, agent_rewards = evo.train(
        strategy,
        env,
        n_generations,
        n_steps,
        n_samples,
        False,
        k,
        evo_params,
        evo_state,
        show_progress=show_progress,
        env_params=env_params,
    )

    params, evo_state = strategy.ask(
        k, evo_state, evo_params
    )
    params_shaped = strategy.reshape_params(params)

    test_data = evo.test(
        params_shaped,
        env,
        n_samples,
        n_steps,
        False,
        k,
        env_params=env_params,
        show_progress=show_progress,
    )

    return evo_state, agent_rewards, test_data.records, test_data.rewards

In this case we first initialise a random key and dummy parameters for the neural network, then initialise an evolutionary strategy from these parameters using esquilax.ml.evo.BasicStrategy, along with the strategy state and the training environment.

We can then use esquilax.ml.evo.train() to generate a trained strategy state and record of rewards over training, then use esquilax.ml.evo.test() to test the trained strategy, and to generate trajectories for analysis/visualisation.
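
For example, the full process could be run with (the argument values here are purely illustrative, not tuned):

evo_state, agent_rewards, trajectories, test_rewards = evo_boids(
    Params(),
    n_agents=50,
    n_generations=100,
    n_samples=4,
    n_steps=200,
)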

Multi-Strategy

The above can be (relatively) easily extended to accommodate the training of multiple strategies in the same training loop. Multiple strategies can be passed as a collection, e.g. as a tuple

strategies = (
    evo.BasicStrategy(net_params_a, strategy_a, n_agents_a),
    evo.BasicStrategy(net_params_b, strategy_b, n_agents_b),
)

or a Flax FrozenDict

from flax.core import FrozenDict

strategies = FrozenDict(
    a=evo.BasicStrategy(net_params_a, strategy_a, n_agents_a),
    b=evo.BasicStrategy(net_params_b, strategy_b, n_agents_b),
)

The strategy parameters and state should then have the same tree structure

evo_params = FrozenDict(
    a=strategies["a"].default_params(),
    b=strategies["b"].default_params(),
)
evo_states = FrozenDict(
    a=strategies["a"].initialize(k1, evo_params["a"]),
    b=strategies["b"].initialize(k2, evo_params["b"],
)

Finally we should ensure that the step function is updated to accommodate the tree structure. The agent_params argument will have the same tree structure, and the training data returned by the step method should also have this structure, e.g.:

def step(
    self,
    _i: int,
    k: chex.PRNGKey,
    params: Params,
    boids: Boid,
    *,
    agent_params,
) -> Tuple[Boid, evo.TrainingData]:
    # agent_params has structure FrozenDict(a=..., b=...)
    ...
    # Then return data with the same structure
    training_data = FrozenDict(
        a=evo.TrainingData(rewards=rewards_a, records=pos_a),
        b=evo.TrainingData(rewards=rewards_b, records=pos_b)
    )
    return boids, training_data

Esquilax will then handle mapping over the individual strategies during training. Strategies will be updated and queried independently, but can be made to interact via the simulation.