esquilax.sim.Sim

class esquilax.sim.Sim

Bases: Generic[esquilax.typing.TSimParams, esquilax.typing.TSimState]

Base class wrapping simulation functionality for batch execution and machine-learning use cases.

Simulation environment base class providing the methods used by training functions. For use inside JIT-compiled training loops, static simulation parameters can be assigned as attributes of the derived class.

abstract default_params() → esquilax.typing.TSimParams

Return the default simulation parameters.

Returns:

Simulation parameters

Return type:

esquilax.typing.TSimParams

init_and_run(n_steps: int, key: chex.PRNGKey, show_progress: bool = True, params: esquilax.typing.TSimParams | None = None, **step_kwargs) → Tuple[esquilax.typing.TSimState, chex.ArrayTree]

Convenience function to initialise and run the simulation

Parameters:
  • n_steps – Number of simulation steps to run

  • key – JAX random key

  • show_progress – If True, a progress bar will be displayed. Default True.

  • params – Optional simulation parameters. If not provided, the default parameters returned by default_params() are used.

  • **step_kwargs – Any additional keyword arguments passed to the step function. Arguments are static over the course of the simulation.

Returns:

Tuple containing

  • The final state of the simulation

  • Tree of recorded data

  • Updated JAX random key

Return type:

tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]

abstract initial_state(key: chex.PRNGKey, params: esquilax.typing.TSimParams) → esquilax.typing.TSimState

Generate the initial state of the simulation.

Parameters:
  • key – JAX random key.

  • params – Simulation parameters

Returns:

The initial state of the environment.

Return type:

esquilax.typing.TSimState

run(n_steps: int, key: chex.PRNGKey, params: esquilax.typing.TSimParams, initial_state: esquilax.typing.TSimState, show_progress: bool = True, **step_kwargs: Any) → Tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]

Convenience function to run the simulation for a fixed number of steps

Parameters:
  • n_steps – Number of simulation steps

  • key – JAX random key

  • params – Time-independent simulation parameters

  • initial_state – Initial state of the simulation

  • show_progress – If True, a progress bar will be displayed. Default True.

  • **step_kwargs – Any additional keyword arguments passed to the step function. Arguments are static over the course of the simulation.

Returns:

Tuple containing

  • The final state of the simulation

  • Tree of recorded data

  • Updated JAX random key

Return type:

tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]

abstract step(i: int, key: chex.PRNGKey, params: esquilax.typing.TSimParams, state: esquilax.typing.TSimState, **kwargs: Any) → Tuple[esquilax.typing.TSimState, chex.ArrayTree]

A single step (update) of the simulation.

The step function should return a tuple containing the updated simulation state and any data to be recorded at each step (see esquilax.runner.sim_runner() for more details). For example:

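from esquilax.sim import Sim

class MySim(Sim):
    def step(self, i, key, params, state, **kwargs):
        new_state = state  # compute the updated state here
        records = new_state  # data to record at this step
        return new_state, records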

Any static values required by the simulation (for example attributes set on the derived class) can be accessed through the self argument of the method.

Parameters:
  • i – Current step number

  • key – JAX random key

  • params – Time-independent simulation parameters

  • state – Simulation state

  • **kwargs – Any additional keyword arguments (the **step_kwargs passed to run or init_and_run).

Returns:

Tuple containing the updated simulation state and any data to be recorded over the course of the simulation.

Return type:

tuple[esquilax.typing.TSimState, chex.ArrayTree]