Hard-Coded Boids

This example implements the popular boids swarming model first developed by Reynolds. The update algorithm implemented here is adapted from this demo.

State

We first import JAX, Chex, and Esquilax, along with partial from functools:

from functools import partial
import chex
import jax
import jax.numpy as jnp
import esquilax

State can be represented by any PyTree (e.g. a dict or tuple), but in this case we use a chex dataclass for readability:

@chex.dataclass
class Boid:
    pos: chex.Array
    vel: chex.Array

@chex.dataclass
class Params:
    cohesion: float
    avoidance: float
    alignment: float
    max_speed: float
    min_speed: float
    close_range: float

The Boid class stores the state of the boids: their current positions and velocities. The Params class stores the parameters used by the steering algorithm and when updating agent positions.
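
Because the state only needs to be a PyTree, the same data could equally be held in a plain dict. The sketch below is illustrative only (the state dict and its values are made up for this example) and uses jax.tree_util to show that JAX treats the container as a collection of array leaves.

```python
import jax
import jax.numpy as jnp

# Hypothetical state for 3 agents in 2D, stored as a plain dict PyTree
state = {
    "pos": jnp.zeros((3, 2)),
    "vel": jnp.ones((3, 2)),
}

# Every array in the PyTree is a leaf
leaves = jax.tree_util.tree_leaves(state)
leaf_shapes = [leaf.shape for leaf in leaves]

# tree_map applies a function to each leaf and preserves the container structure
halved = jax.tree_util.tree_map(lambda x: 0.5 * x, state)
```

A dataclass gives the same behaviour while allowing attribute access such as boids.pos, which is why it is used here.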

Updates

We then use Esquilax to implement observation and update transformations. First, agents observe the state of neighbours within a given range:

reduction = esquilax.reductions.Reduction(
    fn=(jnp.add, jnp.add, jnp.add, jnp.add),
    id=(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
)

@partial(
    esquilax.transforms.spatial,
    i_range=0.2,
    reduction=reduction,
    include_self=False,
)
def observe(params: Params, a: Boid, b: Boid):
    v = esquilax.utils.shortest_vector(b.pos, a.pos)
    d = jnp.sum(v**2)

    close_vec = jax.lax.cond(
        d < params.close_range**2,
        lambda: v,
        lambda: jnp.zeros(2),
    )

    return 1, b.pos, b.vel, close_vec

This function is mapped across all pairs of agents within range of each other. It calculates the squared distance between the two agents, and then returns a tuple containing:

  • 1 to count the neighbour

  • The position of the neighbour

  • The velocity of the neighbour

  • The vector between the agent and the neighbour, if the neighbour is within collision range (otherwise a zero vector)

The values are summed, as specified by the tuple of reduction functions (jnp.add, jnp.add, jnp.add, jnp.add), which together with the identity values (0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)) form a monoid. With i_range=0.2 the space is subdivided into 5 cells along each dimension, and setting include_self=False excludes each agent from its own observation. The result of this transformation is a tuple of arrays holding the combined observations for each individual agent.
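
To make the reduction concrete, the sketch below computes the same four sums with an all-pairs comparison in plain JAX. It is not how Esquilax implements the spatial transformation: brute_force_observe is a hypothetical helper, and it assumes a periodic unit square, a Euclidean range check, and a close-range vector pointing from the neighbour to the agent.

```python
import jax.numpy as jnp

def brute_force_observe(pos, vel, i_range=0.2, close_range=0.02):
    """All-pairs version of the neighbour observation on a unit torus."""
    n = pos.shape[0]
    # diff[i, j]: wrapped displacement from agent j to agent i
    diff = pos[:, None, :] - pos[None, :, :]
    diff = (diff + 0.5) % 1.0 - 0.5
    d2 = jnp.sum(diff**2, axis=-1)

    # Neighbours within range, excluding the agent itself (include_self=False)
    in_range = (d2 < i_range**2) & ~jnp.eye(n, dtype=bool)
    close = in_range & (d2 < close_range**2)

    # Summing masked contributions matches reducing with jnp.add from zeros
    n_nb = jnp.sum(in_range, axis=1)
    x_nb = jnp.sum(jnp.where(in_range[..., None], pos[None, :, :], 0.0), axis=1)
    v_nb = jnp.sum(jnp.where(in_range[..., None], vel[None, :, :], 0.0), axis=1)
    v_cl = jnp.sum(jnp.where(close[..., None], diff, 0.0), axis=1)
    return n_nb, x_nb, v_nb, v_cl

# Two agents close together, one far away from both
pos = jnp.array([[0.10, 0.10], [0.15, 0.10], [0.90, 0.90]])
vel = jnp.array([[0.01, 0.0], [0.0, 0.01], [0.01, 0.01]])
n_nb, x_nb, v_nb, v_cl = brute_force_observe(pos, vel)
```

An agent with no neighbours receives the identity values (a count of 0 and zero vectors), which is why the steering step has to check the neighbour count before averaging.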

The next transformation combines the observations into a steering vector:

@esquilax.transforms.amap
def steering(params: Params, observations):
    x, v, n_nb, x_nb, v_nb, v_cl = observations

    def steer():
        x_nb_avg = x_nb / n_nb
        v_nb_avg = v_nb / n_nb
        _dv_x = params.cohesion * esquilax.utils.shortest_vector(x, x_nb_avg)
        _dv_v = params.alignment * esquilax.utils.shortest_vector(v, v_nb_avg)
        return _dv_x + _dv_v

    dv_nb = jax.lax.cond(n_nb > 0, steer, lambda: jnp.zeros(2))
    # Scale the close-range vector by the avoidance parameter
    v = v + dv_nb + params.avoidance * v_cl

    return v

observations is a tuple of the agent's own position and velocity, followed by the outputs of observe. The function checks whether the agent observed any neighbours and, if so, combines the averaged neighbour position and velocity into a single steering vector. The close-range vector is added in either case. Because the function is mapped across the argument data, it produces a new velocity for each agent.
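
Conceptually, mapping a per-agent function across the argument data resembles jax.vmap with the parameters held fixed. The sketch below is plain JAX rather than the Esquilax implementation; steer_one and the toy numbers are hypothetical, only the alignment term is shown, and the velocity difference is computed directly.

```python
import jax
import jax.numpy as jnp

def steer_one(alignment, obs):
    """Per-agent update: nudge velocity towards the neighbours' average velocity."""
    v, n_nb, v_nb = obs
    # Guard the divisor so an agent with no neighbours cannot produce NaNs
    v_avg = v_nb / jnp.maximum(n_nb, 1)
    dv = jnp.where(n_nb > 0, alignment * (v_avg - v), jnp.zeros_like(v))
    return v + dv

# Map over the leading (agent) axis of each observation; the parameter is shared
steer_all = jax.vmap(steer_one, in_axes=(None, 0))

v = jnp.array([[0.02, 0.0], [0.0, 0.02]])
n_nb = jnp.array([2, 0])
v_nb = jnp.array([[0.0, 0.04], [0.0, 0.0]])
new_v = steer_all(0.5, (v, n_nb, v_nb))
```

Here the first agent moves halfway towards its neighbours' average velocity, while the second agent, which saw no neighbours, keeps its velocity unchanged.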

We then have two functions: one that rescales each agent's velocity, and one that updates its position:

@esquilax.transforms.amap
def limit_speed(params: Params, v: chex.Array):
    s = jnp.sqrt(jnp.sum(v * v))

    v = jax.lax.cond(
        s < params.min_speed,
        lambda _v: params.min_speed * _v / s,
        lambda _v: _v,
        v,
    )

    v = jax.lax.cond(
        s > params.max_speed,
        lambda _v: params.max_speed * _v / s,
        lambda _v: _v,
        v
    )

    return v


@esquilax.transforms.amap
def move(_params: Params, x):
    pos, vel = x
    return (pos + vel) % 1.0

These functions are also mapped across all the argument data, and so effectively scale the velocity and update positions of all the agents.
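
As a quick check of the arithmetic, the plain-JAX sketch below clamps a speed into a range and wraps a position on the unit square. clamp_speed is a hypothetical stand-in for the two conditionals, and the eps guard against zero-length velocities is an addition that is not part of the example above.

```python
import jax.numpy as jnp

def clamp_speed(v, min_speed, max_speed, eps=1e-12):
    """Rescale v so its length lies in [min_speed, max_speed], keeping direction."""
    s = jnp.sqrt(jnp.sum(v * v))
    target = jnp.clip(s, min_speed, max_speed)
    # eps avoids dividing by zero; a zero vector stays zero
    return v * target / jnp.maximum(s, eps)

fast = clamp_speed(jnp.array([0.3, 0.4]), 0.01, 0.05)    # speed 0.5 is reduced
slow = clamp_speed(jnp.array([0.0, 0.001]), 0.01, 0.05)  # speed 0.001 is raised

# Positions wrap at the boundary of the unit square
new_pos = (jnp.array([0.98, 0.50]) + jnp.array([0.05, 0.0])) % 1.0
```

The modulo in the position update is what makes the space periodic: an agent leaving through one edge re-enters from the opposite edge.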

Step Function

The step function defines how the state of the simulation is updated. It should have the signature:

step(i, k, params, state) -> (state, records)

where i is the current step number, k a JAX random key, params any parameters that are static over the simulation, and state the simulation state. It should return a tuple containing the updated state, and any data to be recorded over the course of the simulation.
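
To see why this signature is convenient, the sketch below threads a toy state through a step function with jax.lax.scan, passing a fresh random key at each step and stacking the records. It is a minimal illustration in plain JAX under assumed names (toy_step, run), not a description of how Esquilax runs simulations internally.

```python
import jax
import jax.numpy as jnp

def toy_step(i, k, params, state):
    """Random walk on the unit interval; records the state at every step."""
    new_state = (state + params * jax.random.normal(k, state.shape)) % 1.0
    return new_state, new_state

def run(step, params, state, n_steps, key):
    keys = jax.random.split(key, n_steps)
    steps = jnp.arange(n_steps)

    def body(carry, x):
        i, k = x
        carry, record = step(i, k, params, carry)
        return carry, record

    # scan carries the state forward and stacks records along a leading axis
    return jax.lax.scan(body, state, (steps, keys))

final_state, records = run(
    toy_step, 0.01, jnp.full((4,), 0.5), 10, jax.random.PRNGKey(0)
)
```

With 4 agents and 10 steps, records has shape (10, 4): one row of recorded data per step.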

For the boids model this looks like:

def step(_i, _k, params: Params, boids: Boid):
    n_nb, x_nb, v_nb, v_cl = observe(params, boids, boids, pos=boids.pos)

    vel = steering(
        params,
        (boids.pos, boids.vel, n_nb, x_nb, v_nb, v_cl)
    )
    vel = limit_speed(params, vel)
    pos = move(params, (boids.pos, vel))

    return Boid(pos=pos, vel=vel), pos

At each step the agents observe their neighbours, update and rescale their velocities, and update their positions. The function then returns the updated state, along with the agent positions, which are recorded at each step.

Initialise and Run

We can then initialise and run the simulation using JAX random sampling and the Esquilax sim_runner function:

def boids_sim(n: int, n_steps: int, show_progress: bool = True):
    k = jax.random.PRNGKey(101)
    # Split off separate keys for sampling so the key passed to the runner is not reused
    k, k1, k2 = jax.random.split(k, 3)

    pos = jax.random.uniform(k1, (n, 2))
    vel = 0.01 * jax.random.uniform(k2, (n, 2))
    boids = Boid(pos=pos, vel=vel)

    params = Params(
        cohesion=0.001,
        avoidance=0.05,
        alignment=0.05,
        max_speed=0.05,
        min_speed=0.01,
        close_range=0.02,
    )

    _, history, _ = esquilax.sim_runner(
        step, params, boids, n_steps, k, show_progress=show_progress
    )

    return history

trajectories = boids_sim(
    5, 20, show_progress=False
)
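
Assuming the recorded history stacks the per-step positions into an array of shape (n_steps, n, 2) (an assumption about the runner's output layout, so it is worth checking the shape in practice), per-step speeds can be recovered from consecutive positions once the periodic boundary is accounted for. The sketch below uses a small synthetic array in place of real simulation output, and step_speeds is a hypothetical helper.

```python
import jax.numpy as jnp

def step_speeds(trajectories):
    """Speeds between consecutive recorded positions on a unit torus."""
    d = trajectories[1:] - trajectories[:-1]
    # Undo wrap-around so a jump across the boundary counts as a short move
    d = (d + 0.5) % 1.0 - 0.5
    return jnp.sqrt(jnp.sum(d**2, axis=-1))

# Synthetic stand-in: one agent moving right by 0.03 per step across the boundary
synthetic = jnp.array([[[0.95, 0.5]], [[0.98, 0.5]], [[0.01, 0.5]]])
speeds = step_speeds(synthetic)  # shape (n_steps - 1, n)
```

The unwrapping is only valid while agents move less than half the width of the space per step, which holds comfortably for the max_speed of 0.05 used above.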