esquilax.ml.rl.SharedPolicyAgent

class esquilax.ml.rl.SharedPolicyAgent(*args, **kwargs)

Bases: Agent

Agent with a single trained policy, shared by all agents.
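Below is a minimal sketch of what a shared policy means in practice. It assumes, from the class name, that every simulated agent's observation is passed through the same set of parameters. SHARED_WEIGHTS, policy, and act_all are hypothetical illustrations and are not part of esquilax.

```python
from typing import List, Sequence

# Hypothetical shared parameters: a linear map from a 2-element
# observation to a scalar action (not esquilax's representation).
SHARED_WEIGHTS: List[float] = [0.5, -1.0]


def policy(weights: Sequence[float], observation: Sequence[float]) -> float:
    """Evaluate the hypothetical linear policy on one observation."""
    return sum(w * o for w, o in zip(weights, observation))


def act_all(weights: Sequence[float],
            observations: Sequence[Sequence[float]]) -> List[float]:
    """Apply the same parameters to every agent's observation.

    In JAX code this loop would typically be a ``jax.vmap`` over the agent axis.
    """
    return [policy(weights, obs) for obs in observations]


# Three agents, each with its own observation, all acting with one policy
observations = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
actions = act_all(SHARED_WEIGHTS, observations)
```

Because there is only one parameter set, training improves the behaviour of every agent at once.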

apply_grads(*, grads, **kwargs) → typing_extensions.Self

Apply gradients to the agent parameters and update the optimiser state.

Parameters:
  • grads – Gradients corresponding to the agent parameters.

  • **kwargs – Additional keyword arguments passed to the underlying apply_gradients function.

Returns:

Updated agent.

Return type:

esquilax.ml.rl.SharedPolicyAgent
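The following is a framework-free sketch of the behaviour documented for apply_grads: it updates the parameters and optimiser state from a set of gradients, then returns an updated agent. Everything here is hypothetical. ToyAgent, momentum_sgd_update, and the choice of SGD with momentum stand in for the real agent and its Optax optimiser. The learning_rate and beta arguments illustrate how **kwargs might be forwarded to an underlying update function.

```python
from dataclasses import dataclass, replace
from typing import Tuple

Vector = Tuple[float, ...]


def momentum_sgd_update(params: Vector, momentum: Vector, grads: Vector,
                        learning_rate: float = 0.1,
                        beta: float = 0.9) -> Tuple[Vector, Vector]:
    """Hypothetical optimiser step (SGD with momentum) on flat parameter tuples."""
    # New optimiser state: decayed running sum of gradients
    new_momentum = tuple(beta * m + g for m, g in zip(momentum, grads))
    # New parameters: step against the updated momentum
    new_params = tuple(p - learning_rate * m for p, m in zip(params, new_momentum))
    return new_params, new_momentum


@dataclass(frozen=True)
class ToyAgent:
    """Hypothetical immutable agent holding parameters and optimiser state."""

    params: Vector
    momentum: Vector
    step: int = 0

    def apply_grads(self, *, grads: Vector, **kwargs) -> "ToyAgent":
        # grads is keyword-only, matching the documented signature;
        # any extra keyword arguments are forwarded to the update function
        params, momentum = momentum_sgd_update(
            self.params, self.momentum, grads, **kwargs
        )
        # Return a new agent rather than mutating this one
        return replace(self, params=params, momentum=momentum, step=self.step + 1)


agent = ToyAgent(params=(1.0, -2.0), momentum=(0.0, 0.0))
updated = agent.apply_grads(grads=(0.5, -1.0), learning_rate=0.5, beta=0.5)
```

In this sketch the frozen dataclass reflects the functional style common in JAX code. The original agent is left untouched, so the updated agent has to be taken from the return value.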

classmethod init(key: chex.PRNGKey, model: flax.linen.Module, tx: optax.GradientTransformation, observation_shape: Tuple[int, ...]) → typing_extensions.Self

Initialise the agent's parameters and optimiser state from a network definition.

Parameters:
  • key – JAX random key.

  • model – Flax neural network model definition.

  • tx – Optax optimiser.

  • observation_shape – Shape of the observations the model receives, used to initialise its parameters.

Returns:

Initialised agent.

Return type:

esquilax.ml.rl.SharedPolicyAgent
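The following is a framework-free sketch of the initialisation pattern that these parameters suggest. A dummy observation of observation_shape is passed to the model's initialiser to create parameters, and the optimiser state is then built from those parameters. ToyLinearModel, ToyOptimiser, and ToyAgent are hypothetical. Their init methods only loosely mirror flax's `Module.init(key, x)` and optax's `GradientTransformation.init(params)`, and whether esquilax proceeds exactly this way is an assumption.

```python
import random
from dataclasses import dataclass
from math import prod
from typing import Tuple

Vector = Tuple[float, ...]


class ToyLinearModel:
    """Hypothetical model: one weight per observation element."""

    def init(self, key: int, dummy_obs: Vector) -> Vector:
        rng = random.Random(key)  # stand-in for a JAX PRNG key
        return tuple(rng.gauss(0.0, 0.1) for _ in dummy_obs)


class ToyOptimiser:
    """Hypothetical optimiser whose state is built from the parameters."""

    def init(self, params: Vector) -> Vector:
        return tuple(0.0 for _ in params)


@dataclass(frozen=True)
class ToyAgent:
    """Hypothetical agent holding parameters and optimiser state."""

    params: Vector
    opt_state: Vector

    @classmethod
    def init(cls, key: int, model: ToyLinearModel, tx: ToyOptimiser,
             observation_shape: Tuple[int, ...]) -> "ToyAgent":
        # Zero-filled, flattened dummy observation: only its size matters,
        # since it determines how many parameters the model creates
        dummy_obs = (0.0,) * prod(observation_shape)
        params = model.init(key, dummy_obs)
        # Optimiser state is derived from the freshly initialised parameters
        return cls(params=params, opt_state=tx.init(params))


agent = ToyAgent.init(42, ToyLinearModel(), ToyOptimiser(), (2, 3))
```

Reusing the same key reproduces the same parameters, which mirrors the deterministic behaviour of JAX random keys.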