User Guide¶
Concepts¶
Parameters & State¶
Data in an Esquilax model is generally broken down into two components:
Parameters: Generally model hyperparameters, these are shared (i.e. broadcast) across agents and remain static over time.
State: These represent the current state of the model; they are mapped across agents and updated as the model steps forward.
For example, in a model of a swarm (sketched in code below), parameters might be:
The min/max speed of the agents
The steering parameters of the agents
and the state might be
The positions and velocities of each agent
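As a rough sketch (the names here are illustrative, not part of the Esquilax API), this might look like:

import jax.numpy as jnp

# Shared hyperparameters: scalar values, broadcast to all agents
params = {"min_speed": 0.01, "max_speed": 0.05, "steer_rate": 0.1}

# Per-agent state: arrays with one row per agent (10 agents here)
state = {
    "positions": jnp.zeros((10, 2)),
    "velocities": jnp.zeros((10, 2)),
}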
PyTrees¶
JAX has the concept of a PyTree, a tree structure generated from Python containers. In Esquilax the state of agents/entities is generally represented by a PyTree of arrays, where the length of the arrays corresponds to the number of entities/agents. For a simple model this could well be a single array
import jax.numpy as jnp

agent_state = jnp.zeros(10)
or a container/dataclass of multiple arrays of data
agent_state = {"x": jnp.zeros(10), "y": jnp.zeros(10)}
agent_state = (jnp.zeros(10), jnp.zeros(10))
import chex

@chex.dataclass
class State:
    x: chex.Array
    y: chex.Array

agent_state = State(x=jnp.zeros(10), y=jnp.zeros(10))
or nested combinations of these
agent_state = (
    State(x=jnp.zeros(10), y=jnp.zeros(10)),
    jnp.zeros(10),
)
As long as each array in the tree has the same length along axis 0, Esquilax will handle mapping over the tree structure.
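This mapping is analogous to jax.vmap over axis 0 of each leaf; a quick sketch using the nested state above:

import jax

def agent_fn(s):
    # s is a single agent's slice of the tree: (State(x, y), value)
    return s[0].x + s[0].y + s[1]

# Maps over axis 0 of every array in the tree, giving a result of shape (10,)
results = jax.vmap(agent_fn)(agent_state)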
Map-Reduce Interaction Patterns¶
Esquilax is largely designed around the notion that in a multi-agent system, state is updated by entities performing the following steps:
Observe the (local) state of the system
Take some action or update their state
and then we might also:
Update the state of the system according to some model dynamics
For example, in a swarm model we might have the following steps (see the sketch after this list):
Each agent observes the positions and velocities of agents within a given range
Each agent updates its velocity vector based on this observation
The position of each agent is updated using their updated velocity
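As a minimal pure-JAX sketch of these steps (using a dense O(n²) neighbour search purely for illustration, not how Esquilax implements this):

import jax.numpy as jnp

def swarm_step(positions, velocities, max_range, dt):
    # Observe: average velocity of neighbours within max_range
    diff = positions[:, None] - positions[None, :]
    dist = jnp.linalg.norm(diff, axis=-1)
    within = (dist < max_range) & ~jnp.eye(positions.shape[0], dtype=bool)
    counts = jnp.maximum(within.sum(axis=1, keepdims=True), 1)
    local_v = (within[..., None] * velocities[None, :]).sum(axis=1) / counts
    # Act: steer each agent towards the local average velocity
    new_velocities = velocities + 0.05 * (local_v - velocities)
    # Update: move each agent using its updated velocity
    return positions + dt * new_velocities, new_velocities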
This process may be familiar from the observation-action loop (or Markov decision process) often seen in reinforcement learning problems. Esquilax intends to extend this paradigm to large numbers of agents performing observations and updates in parallel.
As such Esquilax employs map-reduce patterns to apply updates:
Observe: Map an observation function over pairs of agents, then aggregate (i.e. reduce) the observations into a single observation for each agent.
Update: Map an update function over a set of agents.
Esquilax handles the process of mapping observations/updates over sets of agents, and allows these updates to be composed to build complex models. In each case model parameters are broadcast across all agents.
For example, in a model where agents are nodes on a graph, we might have an update where each agent observes its neighbours on the graph. Using Esquilax this could look like
@esquilax.transforms.graph_reduce(jnp.add, 0)
def collect_opinions(_, params, a, b, edge_weight):
    return params.scale * edge_weight * (a - b)
This function then:
Maps collect_opinions across graph edges (in parallel), calculating params.scale * edge_weight * (a - b), the difference of the node values scaled by the weight assigned to the edge and a shared scaling parameter.
Adds up contributions from edges grouped by the start node, creating an array of results for each agent on the graph. If an agent has no edges, the default value 0 is returned.
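The underlying pattern resembles a segmented reduction; as a rough pure-JAX sketch of what the transformation does (not the actual Esquilax implementation):

import jax
import jax.numpy as jnp

values = jnp.array([0.0, 1.0, 2.0])        # per-node state
edges = jnp.array([[0, 0, 1], [1, 2, 2]])  # start/end node indices
weights = jnp.array([0.5, 1.0, 0.25])      # per-edge weights
scale = 0.1                                # shared parameter

# Map: evaluate the observation over every edge in parallel
obs = scale * weights * (values[edges[0]] - values[edges[1]])

# Reduce: sum contributions per start node; nodes with no edges
# receive the default value 0
opinions = jax.ops.segment_sum(obs, edges[0], num_segments=3)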
Some transformations use variations on this pattern. For example, a transformation might select a random neighbour, and then apply an observation function to the agent and the randomly selected neighbour.
Parallelisation¶
Esquilax attempts to maintain performance as agent numbers scale by ensuring observations/updates can be performed in parallel where possible. The performance benefits come with some constraints.
In particular, any reductions must be implemented as a monoid, i.e. a function that takes two arguments of a given type and returns the same type, like (a, a) -> a, and that also has an identity/default value.
In the example above, the graph observation uses the function jnp.add along with identity 0, but other options could be:
jnp.min with default value jnp.finfo(jnp.float32).max to get the minimum over neighbours
jnp.logical_or with default value False to check if any neighbour satisfies a condition
or you can define your own reduction function.
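As an example of defining your own: a mean cannot be computed by a single monoid directly, but reducing (sum, count) pairs is a monoid, and the mean can be recovered afterwards (a sketch of the pattern, independent of any particular Esquilax API):

import jax.numpy as jnp

# The reduction: combine (sum, count) pairs; the identity is (0.0, 0)
def add_pairs(a, b):
    return a[0] + b[0], a[1] + b[1]

# After reducing over neighbours, recover the mean, guarding
# against agents with no neighbours (count of 0)
def to_mean(total, count):
    return total / jnp.maximum(count, 1)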
Neuro-Evolution and Reinforcement Learning¶
Esquilax provides utilities and functionality for training agent policies, where an Esquilax simulation is used as a multi-agent training environment. These allow multiple strategies or RL policies to be trained inside the same training loop. See esquilax.ml for more details.
Tips¶
Extending Functionality¶
Esquilax transformations can be used alongside other JAX code and with other JAX libraries, allowing Esquilax transformations to be combined with customised functionality. For custom behaviours to be used inside the Esquilax simulation runner, the only requirement is that they can be JIT compiled.
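For example, custom dynamics could be implemented as a plain JAX function and used alongside Esquilax transformations (a minimal sketch with hypothetical arguments):

import jax
import jax.numpy as jnp

@jax.jit
def apply_velocities(positions, velocities, dt):
    # Custom position update: an Euler step, wrapped to the unit square
    return (positions + dt * velocities) % 1.0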
Static Features¶
JAX requires certain values to be known at compile time, such as data dimensions and functions passed as arguments to JIT compiled functions.
In some cases you may want to use a function inside a transformation without hard-coding it when writing the model.
For example, you may want to use a Flax network inside a transformation without having to initialise the network when writing the model (for instance if you want to vary network parameters). The network's forward-pass function (created when the network is initialised) can be passed as an argument to the inner function, but it is required to be static, i.e. this
@esquilax.transforms.amap
def foo(_k, f, x):
    return f(x)
will not work. Instead, static arguments can be passed as additional keyword arguments alongside the expected transformation signature, like
@esquilax.transforms.amap
def foo(_k, _, x, *, f):
    return f(x)

results = foo(k, None, y, f=network_func)
which marks the keyword arguments as static during compilation.
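For example, a Flax network's forward pass could be supplied this way (a sketch assuming the foo transformation above and a simple Dense layer; how parameters are threaded through will depend on your model):

import jax
import jax.numpy as jnp
import flax.linen as nn

network = nn.Dense(features=4)
net_params = network.init(jax.random.PRNGKey(0), jnp.zeros((2,)))

k = jax.random.PRNGKey(101)
y = jnp.zeros((10, 2))  # hypothetical per-agent values

# The forward pass is passed as the static keyword argument; parameters
# are bound here for brevity, but varying parameters could instead be
# passed through the transformation's (non-static) params argument
results = foo(k, None, y, f=lambda x: network.apply(net_params, x))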