esquilax.ml.rl.SharedPolicyAgent
- class esquilax.ml.rl.SharedPolicyAgent(*args, **kwargs)
Bases: Agent
Agent with a single trained policy, shared by all agents.
- apply_grads(*, grads, **kwargs) → typing_extensions.Self
Apply gradients to the agent parameters and optimiser state.
- Parameters:
grads – Gradients corresponding to the agent parameters.
**kwargs – Any keyword arguments to pass to the underlying apply_gradients function.
- Returns:
Updated agent.
- Return type:
typing_extensions.Self
- classmethod init(key: chex.PRNGKey, model: flax.linen.Module, tx: optax.GradientTransformation, observation_shape: Tuple[int, ...]) → typing_extensions.Self
Initialise the agent from a network and an optimiser.
- Parameters:
key – JAX random key.
model – Flax neural network model definition.
tx – Optax optimiser.
observation_shape – Shape of the agent observations.
- Returns:
Initialised agent.
- Return type:
typing_extensions.Self