esquilax.sim_runner
- esquilax.sim_runner(step_fun: Callable, params: Any, initial_state: chex.ArrayTree, n_steps: int, rng: chex.PRNGKey | int, show_progress: bool = True, pbar_id: int = 0, **static_kwargs) → Tuple[Any, Any, chex.PRNGKey]
Run a simulation and track state.
Repeatedly applies a provided simulation update function and records simulation state over the course of execution.
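The loop described above can be pictured with the minimal sketch below. It is an assumption about the semantics, not esquilax's implementation: `run_sketch` and `count_step` are hypothetical names, the per-step key splitting is a guess, the progress bar is omitted, and records from each step are assumed to be stacked along a leading axis (as `jax.lax.scan` does).

```python
import functools

import jax
import jax.numpy as jnp


def run_sketch(step_fun, params, initial_state, n_steps, rng, **static_kwargs):
    """Hypothetical re-creation of the documented sim_runner behaviour.

    Not esquilax's code: no progress bar, and the key-splitting scheme
    is an assumption.
    """
    # Accept either an integer seed or an existing PRNG key
    if isinstance(rng, int):
        rng = jax.random.PRNGKey(rng)
    # Bind static kwargs as plain Python values so they are fixed at trace time
    step = functools.partial(step_fun, **static_kwargs)

    def body(carry, i):
        key, state = carry
        # Split off a fresh subkey for this step, keep the rest for later
        key, k = jax.random.split(key)
        new_state, records = step(i, k, params, state)
        return (key, new_state), records

    # scan stacks the per-step records along a new leading axis
    (rng, final_state), records = jax.lax.scan(
        body, (rng, initial_state), jnp.arange(n_steps)
    )
    return final_state, records, rng


def count_step(i, k, params, state, *, scale):
    # Hypothetical deterministic update: add params * scale, record the step index
    new_state = state + params * scale
    return new_state, i
```

For example, `run_sketch(count_step, 2.0, jnp.zeros(()), 5, 0, scale=0.5)` advances the state by `1.0` per step and records the step indices `0` to `4`.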
- Parameters:
step_fun (Callable) – Update function that should have the signature
    def step(i, k, params, state, **static_kwargs):
        ...
        return new_state, records
where the arguments are
    i: The current step number
    k: A JAX random key
    params: Simulation parameters
    state: The current simulation state
    **static_kwargs: Static keyword arguments
and returns
    new_state: The updated simulation state
    records: State data to be recorded
params – Simulation parameters. Parameters are constant over the course of the simulation.
initial_state – Initial simulation state.
n_steps – Number of steps to run.
rng – Either an integer random seed, or a JAX PRNGKey.
show_progress – If True, a progress bar will be shown.
pbar_id – Optional progress bar index; can be used to print multiple progress bars.
**static_kwargs – Any static keyword values to pass to the step function. Use these for any values or functionality that JAX needs to know at compile time.
- Returns:
Tuple containing
The final state of the simulation
Recorded values
Updated random key
- Return type:
Tuple[Any, Any, chex.PRNGKey]
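As an illustration, a hypothetical random-walk step function following the documented `step_fun` signature is sketched below. `params` is assumed to be a scalar noise scale and `state` an array of positions; `reflect` is a made-up static keyword argument that is used in Python control flow, which is why its value must be known at compile time.

```python
import jax
import jax.numpy as jnp


def step(i, k, params, state, *, reflect: bool):
    """Hypothetical random-walk update matching the documented signature.

    i: current step number (unused here), k: JAX random key,
    params: noise scale, state: array of positions.
    reflect: static keyword argument used in a Python `if`.
    """
    # Gaussian noise scaled by the (assumed scalar) parameter
    noise = params * jax.random.normal(k, state.shape)
    new_state = state + noise
    if reflect:
        # Reflect negative positions about zero so they stay non-negative
        new_state = jnp.abs(new_state)
    # Record the mean position at this step
    records = jnp.mean(new_state)
    return new_state, records
```

Assuming the documented signature, it could then be run with `final_state, records, key = esquilax.sim_runner(step, 0.1, jnp.zeros(8), 100, 101, reflect=True)`, where `101` is an integer seed.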