esquilax.sim.Sim
- class esquilax.sim.Sim
Bases: Generic[esquilax.typing.TSimParams, esquilax.typing.TSimState]
Base class wrapping simulation functionality for batch execution and ML use-cases.
Simulation environment base class with methods used by training functions. For use inside JIT-compiled training loops, static simulation parameters can be assigned as attributes of the derived class.
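A minimal sketch of the static-attribute pattern in plain JAX. `Flock` is a hypothetical stand-in for a `Sim` subclass, not part of esquilax. Because `self` is marked static for `jax.jit`, `self.n_agents` is an ordinary Python int at trace time and can size an array, which a traced argument could not.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Flock:
    """Hypothetical stand-in for a Sim subclass with a static attribute."""

    def __init__(self, n_agents: int):
        # Static: fixed for the lifetime of the object, usable as a shape
        self.n_agents = n_agents

    @partial(jax.jit, static_argnums=0)
    def random_headings(self, key: jax.Array) -> jax.Array:
        # self.n_agents is a concrete int while tracing, so it can size the array
        return jax.random.uniform(key, (self.n_agents,), maxval=2 * jnp.pi)


flock = Flock(n_agents=16)
headings = flock.random_headings(jax.random.PRNGKey(0))
```

Changing `n_agents` means constructing a new object, which triggers a fresh compilation; that is the cost of treating the value as static.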
- abstract default_params() → esquilax.typing.TSimParams
Return default simulation parameters
- Returns:
Simulation parameters
- Return type:
esquilax.typing.TSimParams
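A sketch of a `default_params` implementation, assuming the parameters are held in a `NamedTuple` (a valid JAX pytree). The `Params` fields and the `Flock` class are hypothetical.

```python
from typing import NamedTuple


class Params(NamedTuple):
    """Hypothetical time-independent parameters for a flocking model."""

    speed: float
    view_radius: float


class Flock:
    """Hypothetical stand-in for a Sim subclass."""

    def default_params(self) -> Params:
        # Baseline values used when no parameters are supplied
        return Params(speed=0.05, view_radius=0.1)


defaults = Flock().default_params()
# NamedTuples are immutable; _replace returns a modified copy
faster = defaults._replace(speed=0.1)
```

A modified copy such as `faster` can then be passed as the `params` argument of `init_and_run` or `run` in place of the defaults.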
- init_and_run(n_steps: int, key: chex.PRNGKey, show_progress: bool = True, params: esquilax.typing.TSimParams | None = None, pbar_id: int = 0, **step_kwargs) → Tuple[esquilax.typing.TSimState, chex.ArrayTree]
Convenience function to initialise and run the simulation
- Parameters:
n_steps – Number of simulation steps to run
key – JAX random key
show_progress – If True, a progress bar will be displayed. Default True.
params – Optional simulation parameters. If not provided, the default simulation parameters are used.
pbar_id – Optional progress bar index; can be used to display multiple progress bars.
**step_kwargs – Any additional keyword arguments passed to the step function. Arguments are static over the course of the simulation.
- Returns:
Tuple containing:
  - The final state of the simulation
  - Tree of recorded data
  - Updated JAX random key
- Return type:
tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]
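The sequence described above (fall back to default parameters, build an initial state, then run) can be sketched in plain JAX. This is an assumption about the flow, not esquilax's implementation: the `Counter` model is hypothetical, the key is split between initialisation and running, a Python loop stands in for the runner, and progress bars are omitted.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    rate: float  # hypothetical increment applied each step


class Counter:
    """Hypothetical toy model exposing the same hooks as Sim."""

    def default_params(self) -> Params:
        return Params(rate=1.0)

    def initial_state(self, key: jax.Array, params: Params) -> jax.Array:
        # Random starting value in [0, 1)
        return jax.random.uniform(key)

    def step(self, i, key, params: Params, state: jax.Array):
        new_state = state + params.rate
        return new_state, new_state  # (updated state, record)

    def run(self, n_steps, key, params, initial_state):
        # Simple Python loop standing in for the esquilax runner
        state, records = initial_state, []
        for i in range(n_steps):
            key, step_key = jax.random.split(key)
            state, record = self.step(i, step_key, params, state)
            records.append(record)
        return state, jnp.stack(records), key

    def init_and_run(self, n_steps, key, params: Optional[Params] = None):
        # Fall back to the defaults when no parameters are given
        params = self.default_params() if params is None else params
        init_key, run_key = jax.random.split(key)
        state = self.initial_state(init_key, params)
        return self.run(n_steps, run_key, params, state)


final, records, _ = Counter().init_and_run(5, jax.random.PRNGKey(0))
```

Passing `params=Params(rate=2.0)` instead of `None` skips the fallback and uses the supplied values.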
- abstract initial_state(key: chex.PRNGKey, params: esquilax.typing.TSimParams) → esquilax.typing.TSimState
Initialise the state of the simulation
- Parameters:
key – JAX random key.
params – Simulation parameters
- Returns:
The initial state of the environment.
- Return type:
esquilax.typing.TSimState
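A sketch of an `initial_state` implementation, assuming the state is a `NamedTuple` of agent positions and velocities drawn from the supplied key. The `State` and `Params` types, their fields, and the `Flock` class are hypothetical. Because the result is a pure function of `key` and `params`, the same key reproduces the same initial state.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    speed: float  # hypothetical initial speed of every agent


class State(NamedTuple):
    pos: jax.Array  # (n_agents, 2) positions in the unit square
    vel: jax.Array  # (n_agents, 2) velocities


class Flock:
    def __init__(self, n_agents: int):
        self.n_agents = n_agents  # static attribute

    def initial_state(self, key: jax.Array, params: Params) -> State:
        pos_key, heading_key = jax.random.split(key)
        pos = jax.random.uniform(pos_key, (self.n_agents, 2))
        heading = jax.random.uniform(
            heading_key, (self.n_agents,), maxval=2 * jnp.pi
        )
        # Unit direction vectors scaled by the common speed
        vel = params.speed * jnp.stack([jnp.cos(heading), jnp.sin(heading)], axis=1)
        return State(pos=pos, vel=vel)


sim = Flock(n_agents=10)
s1 = sim.initial_state(jax.random.PRNGKey(1), Params(speed=0.5))
s2 = sim.initial_state(jax.random.PRNGKey(1), Params(speed=0.5))
```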
- run(n_steps: int, key: chex.PRNGKey, params: esquilax.typing.TSimParams, initial_state: esquilax.typing.TSimState, show_progress: bool = True, pbar_id: int = 0, **step_kwargs: Any) → Tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]
Convenience function to run the simulation for a fixed number of steps
- Parameters:
n_steps – Number of simulation steps
key – JAX random key
params – Time-independent simulation parameters
initial_state – Initial state of the simulation
show_progress – If True, a progress bar will be displayed. Default True.
pbar_id – Optional progress bar index; can be used to display multiple progress bars.
**step_kwargs – Any additional keyword arguments passed to the step function. Arguments are static over the course of the simulation.
- Returns:
Tuple containing:
  - The final state of the simulation
  - Tree of recorded data
  - Updated JAX random key
- Return type:
tuple[esquilax.typing.TSimState, chex.ArrayTree, chex.PRNGKey]
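The sketch below shows one way a fixed-step run can thread the state and random key while collecting records. It assumes each leaf of the recorded tree gains a leading axis of length `n_steps`, which is how `jax.lax.scan` stacks its outputs; see `esquilax.runner.sim_runner()` for the exact behaviour. `walk_step` and `run_sketch` are hypothetical, not esquilax functions.

```python
import jax
import jax.numpy as jnp


def walk_step(i, key, params, state):
    # Hypothetical step: Gaussian random walk with scale params["sigma"]
    new_state = state + params["sigma"] * jax.random.normal(key, state.shape)
    records = {"pos": new_state, "mean": jnp.mean(new_state)}
    return new_state, records


def run_sketch(step, n_steps, key, params, initial_state):
    """Apply `step` n_steps times, threading the state and random key."""

    def body(carry, i):
        state, k = carry
        k, step_key = jax.random.split(k)
        state, records = step(i, step_key, params, state)
        return (state, k), records

    (final_state, key), records = jax.lax.scan(
        body, (initial_state, key), jnp.arange(n_steps)
    )
    return final_state, records, key


key = jax.random.PRNGKey(0)
final_state, records, new_key = run_sketch(
    walk_step, 20, key, {"sigma": 0.1}, jnp.zeros(8)
)
```

Because `jax.lax.scan` traces the loop body once rather than unrolling it, compile time does not grow with `n_steps`.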
- abstract step(i: int, key: chex.PRNGKey, params: esquilax.typing.TSimParams, state: esquilax.typing.TSimState, **kwargs: Any) → Tuple[esquilax.typing.TSimState, chex.ArrayTree]
A single step/update of the environment
The step function should return a tuple containing the updated simulation state and any data to be recorded at each step (see esquilax.runner.sim_runner() for more details). For example:

    class MySim(esquilax.sim.Sim):
        def step(self, i, key, params, state, **kwargs):
            ...
            return new_state, records

Any static arguments required by the simulation can be accessed from the self argument of the method.
- Parameters:
i – Current step number
key – JAX random key
params – Time-independent simulation parameters
state – Simulation state
**kwargs – Any additional keyword arguments.
- Returns:
Tuple containing the updated simulation state, and any data to be recorded over the course of the simulation.
- Return type:
tuple[esquilax.typing.TSimState, chex.ArrayTree]
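A sketch of a concrete `step` implementation, assuming a hypothetical `Flock` model whose static `box_size` is read from `self` and whose noise scale comes from `params`. The records are a dict so several quantities can be logged per step. The extra `damping` keyword stands in for the additional keyword arguments and is likewise hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    noise: float  # hypothetical per-step velocity noise scale


class State(NamedTuple):
    pos: jax.Array  # (n_agents, 2)
    vel: jax.Array  # (n_agents, 2)


class Flock:
    def __init__(self, box_size: float):
        self.box_size = box_size  # static, read via self inside step

    def step(self, i, key, params: Params, state: State, damping: float = 1.0):
        # Perturb velocities, then move agents on a periodic domain
        noise = params.noise * jax.random.normal(key, state.vel.shape)
        vel = damping * state.vel + noise
        pos = (state.pos + vel) % self.box_size
        records = {"pos": pos, "mean_speed": jnp.mean(jnp.linalg.norm(vel, axis=1))}
        return State(pos=pos, vel=vel), records


sim = Flock(box_size=1.0)
state = State(pos=jnp.full((4, 2), 0.5), vel=jnp.zeros((4, 2)))
new_state, records = sim.step(
    0, jax.random.PRNGKey(0), Params(noise=0.0), state, damping=0.9
)
```

With zero noise and zero initial velocity the agents stay put, which makes the call easy to check by hand before plugging `step` into a runner.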