Hard-Coded Boids
This example implements the popular boids swarming model first developed by Reynolds. The update algorithm implemented here is adapted from this demo.
State
We first import JAX, Chex, and Esquilax:

```python
from functools import partial

import chex
import jax
import jax.numpy as jnp

import esquilax
```
State can be represented by any PyTree (e.g. a dict or tuple), but in this case we will use a chex dataclass for readability:

```python
@chex.dataclass
class Boid:
    pos: chex.Array  # positions, shape (n, 2)
    vel: chex.Array  # velocities, shape (n, 2)


@chex.dataclass
class Params:
    cohesion: float
    avoidance: float
    alignment: float
    max_speed: float
    min_speed: float
    close_range: float
```
The `Boid` class stores the state of the boids, i.e. their current positions and velocities. The `Params` class stores the parameters used by the steering algorithm and when updating agent positions.
Updates
We then use Esquilax to implement the observation and update transformations. First, agents observe the state of neighbours within a given range:
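```python
reduction = esquilax.reductions.Reduction(
    fn=(jnp.add, jnp.add, jnp.add, jnp.add),
    id=(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
)


@partial(
    esquilax.transforms.spatial,
    i_range=0.2,
    reduction=reduction,
    include_self=False,
)
def observe(params: Params, a: Boid, b: Boid):
    # Shortest vector between the two agents, accounting for periodic boundaries
    v = esquilax.utils.shortest_vector(b.pos, a.pos)
    # Squared distance (avoids taking a square root)
    d = jnp.sum(v**2)

    # Only contribute an avoidance vector if the neighbour is within close range
    close_vec = jax.lax.cond(
        d < params.close_range**2,
        lambda: v,
        lambda: jnp.zeros(2),
    )
    # (neighbour count, neighbour position, neighbour velocity, avoidance vector)
    return 1, b.pos, b.vel, close_vec
```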
This function is mapped across all pairs of agents within range of each other. It calculates the squared distance between the two agents on the periodic domain and then returns a tuple containing:

- `1`, to count the neighbour
- The position of the neighbour
- The velocity of the neighbour
- The shortest vector between the two agents if the neighbour is within collision range (`close_range`), otherwise a zero vector
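The distance check relies on `esquilax.utils.shortest_vector`, which measures displacements on the periodic unit square used by this model (positions wrap with `% 1.0`). The plain-Python sketch below illustrates the idea; `torus_vector` is a hypothetical helper, not Esquilax's implementation, and it assumes a box of side `length` and that the result points from the first argument to the second.

```python
def torus_vector(a, b, length=1.0):
    """Shortest displacement from point a to point b on a periodic box (hypothetical helper)."""
    out = []
    for ai, bi in zip(a, b):
        d = bi - ai
        # Crossing the boundary is shorter whenever |d| exceeds half the box size
        if d > length / 2:
            d -= length
        elif d < -length / 2:
            d += length
        out.append(d)
    return tuple(out)


# Two agents near opposite edges of the box are actually close neighbours
v = torus_vector((0.95, 0.5), (0.05, 0.5))
print(v)  # roughly (0.1, 0.0)
```

Without the wrap, these two agents would appear 0.9 apart and would never interact.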
These values are summed, as specified by the tuple of reduction functions `(jnp.add, jnp.add, jnp.add, jnp.add)`, which together with the identity values `(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2))` form a monoid. With `i_range=0.2`, the space is subdivided into 5 cells along each dimension, and setting `include_self=False` means agents do not include themselves in their own observation.
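Because addition is associative and `0` and `jnp.zeros(2)` are its identities, contributions can be accumulated in any order or grouping, and an agent with no neighbours simply receives the identity. A minimal plain-Python sketch of this tuple-wise monoid follows; `combine` and `IDENTITY` are illustrative names, not Esquilax API, and 2D vectors are plain tuples. The values are chosen to be exactly representable so that the sums are exact.

```python
from functools import reduce

# Identity: (neighbour count, summed position, summed velocity, summed close-range vector)
IDENTITY = (0, (0.0, 0.0), (0.0, 0.0), (0.0, 0.0))


def add_vec(u, v):
    """Element-wise sum of two 2D vectors stored as tuples."""
    return tuple(a + b for a, b in zip(u, v))


def combine(x, y):
    """Component-wise addition, mirroring fn=(jnp.add, jnp.add, jnp.add, jnp.add)."""
    return (x[0] + y[0], add_vec(x[1], y[1]), add_vec(x[2], y[2]), add_vec(x[3], y[3]))


# Contributions from three hypothetical neighbours of one agent
a = (1, (0.25, 0.5), (0.015625, 0.0), (0.0, 0.0))
b = (1, (0.75, 0.5), (0.0, 0.015625), (0.0, 0.0))
c = (1, (0.5, 0.125), (0.03125, 0.03125), (0.0078125, -0.0078125))

total = reduce(combine, [a, b, c], IDENTITY)
print(total)  # (3, (1.5, 1.125), (0.046875, 0.046875), (0.0078125, -0.0078125))

# Grouping does not matter, and no neighbours means the identity is returned
assert combine(combine(a, b), c) == combine(a, combine(b, c))
assert reduce(combine, [], IDENTITY) == IDENTITY
```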
The result of this transformation is a tuple of arrays containing the combined observations for each individual agent.
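As a reference for what this transformation computes, assuming (as described above) that only pairs closer than `i_range` interact, the same per-agent tuples can be produced by a brute-force loop over all pairs; the cell subdivision is what lets Esquilax avoid checking every pair. This is a plain-Python O(n²) sketch with hypothetical names (`wrap`, `observe_all`) and this example's parameter values as defaults; it may differ from Esquilax in edge cases such as pairs exactly at the range boundary.

```python
def wrap(d, length=1.0):
    """Map a coordinate difference onto its shortest periodic image."""
    return d - length * round(d / length)


def observe_all(pos, vel, i_range=0.2, close_range=0.02):
    """O(n^2) reference for the observe + reduce step (illustrative only)."""
    results = []
    for i, a in enumerate(pos):
        count = 0
        x_sum, v_sum, close_sum = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
        for j, b in enumerate(pos):
            if i == j:
                continue  # include_self=False
            # Vector from neighbour b to agent a on the torus
            ba = [wrap(a[k] - b[k]) for k in range(2)]
            d2 = ba[0] ** 2 + ba[1] ** 2
            if d2 >= i_range ** 2:
                continue  # outside the interaction range
            count += 1
            for k in range(2):
                x_sum[k] += b[k]
                v_sum[k] += vel[j][k]
                if d2 < close_range ** 2:
                    close_sum[k] += ba[k]
        results.append((count, tuple(x_sum), tuple(v_sum), tuple(close_sum)))
    return results


obs = observe_all([(0.1, 0.1), (0.15, 0.1), (0.9, 0.9)], [(0.01, 0.0), (0.0, 0.01), (0.01, 0.01)])
print([o[0] for o in obs])  # [1, 1, 0]
```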
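The next transformation combines the observations into a steering vector:

```python
@esquilax.transforms.amap
def steering(params: Params, observations):
    # The agent's own position and velocity, followed by the reduced observations
    x, v, n_nb, x_nb, v_nb, v_cl = observations

    def steer():
        # Average position and velocity of the observed neighbours
        x_nb_avg = x_nb / n_nb
        v_nb_avg = v_nb / n_nb
        # Cohesion: steer towards the neighbours' average position
        _dv_x = params.cohesion * esquilax.utils.shortest_vector(x, x_nb_avg)
        # Alignment: steer towards the neighbours' average velocity (not periodic)
        _dv_v = params.alignment * (v_nb_avg - v)
        return _dv_x + _dv_v

    # Only use the neighbour averages if at least one neighbour was observed
    dv_nb = jax.lax.cond(n_nb > 0, steer, lambda: jnp.zeros(2))
    # Avoidance: push away from close neighbours, weighted by params.avoidance
    v = v + dv_nb + params.avoidance * v_cl
    return v
```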
`observations` is a tuple containing the agents' own positions and velocities, followed by the observations from `observe`.
This function checks if the agent observed any neighbours, and if so combines
these values into a single steering vector. The function is mapped across the
argument data, and so produces a new velocity for each agent.
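For a single agent, the cohesion and alignment part of the update (`dv_nb`) amounts to the arithmetic below. This plain-Python sketch uses a hypothetical function name and, for simplicity, ignores the periodic wrap that `shortest_vector` applies to the cohesion term, so it is only accurate away from the domain edges.

```python
def neighbour_steering(x, v, n_nb, x_nb, v_nb, cohesion=0.001, alignment=0.05):
    """Cohesion + alignment increment for one agent (plain-Python sketch).

    x, v: the agent's position and velocity; n_nb: neighbour count;
    x_nb, v_nb: summed neighbour positions and velocities.
    """
    if n_nb == 0:
        return (0.0, 0.0)  # no neighbours observed: no steering
    x_avg = [s / n_nb for s in x_nb]
    v_avg = [s / n_nb for s in v_nb]
    # Cohesion: towards the neighbours' mean position (no periodic wrap here)
    dv_x = [cohesion * (xa - xi) for xa, xi in zip(x_avg, x)]
    # Alignment: towards the neighbours' mean velocity
    dv_v = [alignment * (va - vi) for va, vi in zip(v_avg, v)]
    return tuple(a + b for a, b in zip(dv_x, dv_v))


# Two neighbours with summed position (1.2, 1.0) and summed velocity (0.0, 0.04)
dv = neighbour_steering((0.5, 0.5), (0.01, 0.0), 2, (1.2, 1.0), (0.0, 0.04))
print(dv)  # approximately (-0.0004, 0.001)
```

With these numbers, cohesion contributes (0.0001, 0) and alignment (-0.0005, 0.001), so the alignment term dominates.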
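We then have two functions: one rescales each agent's velocity so that its speed lies between `min_speed` and `max_speed`, and the other updates the agents' positions:

```python
@esquilax.transforms.amap
def limit_speed(params: Params, v: chex.Array):
    # Current speed (before any rescaling)
    s = jnp.sqrt(jnp.sum(v * v))
    # Scale slow agents up to the minimum speed
    v = jax.lax.cond(
        s < params.min_speed,
        lambda _v: params.min_speed * _v / s,
        lambda _v: _v,
        v,
    )
    # Scale fast agents down to the maximum speed
    v = jax.lax.cond(
        s > params.max_speed,
        lambda _v: params.max_speed * _v / s,
        lambda _v: _v,
        v,
    )
    return v


@esquilax.transforms.amap
def move(_params: Params, x):
    pos, vel = x
    # Wrap positions back into the unit square (periodic boundaries)
    return (pos + vel) % 1.0
```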
These functions are also mapped across all of the argument data, so together they rescale the velocities and update the positions of all the agents.
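Per agent, the two functions reduce to the plain-Python logic sketched below (hypothetical standalone versions without the `params` argument, defaulting to this example's speed limits). One deliberate difference is marked in the code: a zero velocity is returned unchanged here, whereas dividing by a zero speed in the JAX version would produce NaNs.

```python
import math


def limit_speed(v, min_speed=0.01, max_speed=0.05):
    """Rescale a 2D velocity so its speed lies in [min_speed, max_speed]."""
    s = math.hypot(*v)
    if s == 0.0:
        # Direction undefined: leave unchanged instead of dividing by zero
        return v
    if s < min_speed:
        return tuple(min_speed * c / s for c in v)
    if s > max_speed:
        return tuple(max_speed * c / s for c in v)
    return v


def move(pos, vel):
    """Advance a position by its velocity and wrap it into the unit square."""
    # Python's % (like jnp's) takes the sign of the divisor, so -0.01 wraps to 0.99
    return tuple((p + u) % 1.0 for p, u in zip(pos, vel))


print(limit_speed((0.06, 0.08)))  # speed 0.1 is scaled down to 0.05
print(move((0.98, 0.5), (0.05, 0.0)))  # x wraps past 1.0 back to about 0.03
```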
Step Function
The step function defines how the state of the simulation is updated. It should have the signature `step(i, k, params, state) -> (state, records)`, where `i` is the current step number, `k` a JAX random key, `params` any parameters that are static over the simulation, and `state` the simulation state. It should return a tuple containing the updated state and any data to be recorded over the course of the simulation.
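Esquilax's `sim_runner` drives this loop for you. Purely to illustrate the calling convention, here is a plain-Python loop with hypothetical names (`run`, `counter_step`); it passes an integer where Esquilax would pass a JAX random key and simply collects the records in a list, so it is not a description of how `sim_runner` is implemented.

```python
import random


def run(step, params, state, n_steps, seed=0):
    """Hypothetical runner: repeatedly calls step(i, k, params, state)."""
    rng = random.Random(seed)
    records = []
    for i in range(n_steps):
        k = rng.getrandbits(32)  # stand-in for a per-step JAX random key
        state, record = step(i, k, params, state)
        records.append(record)
    return state, records


# A toy step function with the same signature: a counter that records its value
def counter_step(i, k, params, state):
    new_state = state + params["increment"]
    return new_state, new_state


final, history = run(counter_step, {"increment": 2}, 0, n_steps=3)
print(final, history)  # 6 [2, 4, 6]
```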
For the boids model this looks like:
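```python
def step(_i, _k, params: Params, boids: Boid):
    # Reduced observations of neighbours for every agent
    n_nb, x_nb, v_nb, v_cl = observe(params, boids, boids, pos=boids.pos)
    # New velocities from cohesion, alignment and avoidance
    vel = steering(
        params,
        (boids.pos, boids.vel, n_nb, x_nb, v_nb, v_cl)
    )
    vel = limit_speed(params, vel)
    pos = move(params, (boids.pos, vel))
    # Return the updated state, and record the positions
    return Boid(pos=pos, vel=vel), pos
```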
Each step, the agents observe their neighbours, update and rescale their velocities, and update their positions. The function then returns the updated state, and the agents' positions are recorded at each step.
Initialise and Run
We can then initialise and run the simulation using JAX random sampling and the Esquilax `sim_runner` function:
```python
def boids_sim(n: int, n_steps: int, show_progress: bool = True):
    k = jax.random.PRNGKey(101)
    # Separate keys for initial positions, initial velocities and the simulation
    k, k1, k2 = jax.random.split(k, 3)
    pos = jax.random.uniform(k1, (n, 2))
    vel = 0.01 * jax.random.uniform(k2, (n, 2))
    boids = Boid(pos=pos, vel=vel)

    params = Params(
        cohesion=0.001,
        avoidance=0.05,
        alignment=0.05,
        max_speed=0.05,
        min_speed=0.01,
        close_range=0.02,
    )

    # Only the recorded data (the agent positions at each step) is kept
    _, history, _ = esquilax.sim_runner(
        step, params, boids, n_steps, k, show_progress=show_progress
    )
    return history


trajectories = boids_sim(5, 20, show_progress=False)
```