esquilax.sim_runner

esquilax.sim_runner(step_fun: Callable, params: Any, initial_state: chex.ArrayTree, n_steps: int, rng: chex.PRNGKey | int, show_progress: bool = True, **static_kwargs) → Tuple[Any, Any, chex.PRNGKey]

Run a simulation and track state

Repeatedly applies the provided simulation update function and records simulation state over the course of execution.
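As a rough mental model, the run loop can be sketched with `jax.lax.scan`. This is a hypothetical stand-in written only from the documented contract, not esquilax's implementation. `run_sketch` and `counter_step` are illustrative names, and the sketch assumes that an integer seed is converted with `jax.random.PRNGKey`, that a fresh sub-key is split off the carried key at every step, and that per-step records are stacked along a leading axis of length `n_steps`. The progress bar is omitted.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


def run_sketch(step_fun: Callable, params: Any, initial_state: Any,
               n_steps: int, rng, **static_kwargs):
    """Hypothetical stand-in for the documented loop (NOT esquilax's code)."""
    # Assumption: accept either an integer seed or an existing PRNG key.
    key = jax.random.PRNGKey(rng) if isinstance(rng, int) else rng

    def body(carry, i):
        k, state = carry
        # Assumption: split off a fresh sub-key for this step, carry the rest.
        k, step_key = jax.random.split(k)
        new_state, records = step_fun(i, step_key, params, state, **static_kwargs)
        return (k, new_state), records

    # lax.scan stacks each step's records along a new leading axis.
    (key, final_state), records = jax.lax.scan(
        body, (key, initial_state), jnp.arange(n_steps)
    )
    return final_state, records, key


def counter_step(i, k, params, state):
    # Toy deterministic step: add `params` to the state and record the result.
    new_state = state + params
    return new_state, new_state


final, recs, key = run_sketch(counter_step, 2, jnp.int32(0), 4, 0)
```

Because the scan body is traced once, anything that alters Python-level control flow or array shapes inside the step function has to be fixed before tracing, which is the role the static keyword arguments play.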

Parameters:
  • step_fun (Callable) –

    Update function that should have the signature

    def step(i, k, params, state, **static_kwargs):
        ...
        return new_state, records
    

    where the arguments are

    • i: The current step number

    • k: A JAX random key

    • params: Sim parameters

    • state: The current simulation state

    • **static_kwargs: Static keyword arguments forwarded from sim_runner

    and returns

    • new_state: Updated simulation state

    • records: State data to be recorded

  • params – Simulation parameters. Parameters are constant over the course of the simulation.

  • initial_state – Initial simulation state.

  • n_steps – Number of steps to run.

  • rng – Either an integer random seed, or a JAX PRNGKey.

  • show_progress – If True, a progress bar is shown.

  • **static_kwargs – Static keyword arguments passed to the step function. Use these for any values or functionality that JAX must know at compile time.
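The step function below follows the documented signature `step(i, k, params, state, **static_kwargs)` and illustrates why some values must be static. The names `noisy_step`, `add_noise`, `drift` and `scale` are hypothetical. Because `add_noise` drives a Python `if`, JAX must know its value at trace time. The sketch makes that explicit with `jax.jit(..., static_argnames=...)`; it only demonstrates the compile-time constraint and is not a claim about how sim_runner forwards these values internally.

```python
import jax
import jax.numpy as jnp


def noisy_step(i, k, params, state, add_noise: bool = False):
    # `add_noise` selects a Python-level branch, so it cannot be a traced
    # array; it must be a static (compile-time) value.
    drift = params["drift"]
    if add_noise:
        new_state = state + drift + params["scale"] * jax.random.normal(k, state.shape)
    else:
        new_state = state + drift
    # Record the mean of the updated state for this step.
    records = {"mean": jnp.mean(new_state)}
    return new_state, records


# Mark `add_noise` as static so each distinct value gets its own compilation.
jitted_step = jax.jit(noisy_step, static_argnames=("add_noise",))

params = {"drift": 1.0, "scale": 0.1}
state = jnp.zeros(3)
key = jax.random.PRNGKey(0)

det_state, det_records = jitted_step(0, key, params, state, add_noise=False)
noisy_state, noisy_records = jitted_step(1, key, params, state, add_noise=True)
```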

Returns:

Tuple containing

  • The final state of the simulation

  • Recorded values

  • Updated random key

Return type:

Tuple[Any, Any, chex.PRNGKey]