baselax.agents
The baselax.agents module provides RL agent implementations built with JAX.
baselax.agents.BaseAgent
- class baselax.agents.BaseAgent(network: Callable[[Space], Transformed], env: Env, learning_rate: Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]])
Bases: ABC
An abstract BaseAgent class.
- Parameters
network (Callable[[gym.Space], haiku.Transformed]) – A network function creator that takes the action space as input and returns the network function.
env (gym.Env) – a gym environment wrapper with batch input and output.
learning_rate (Union[float, optax.Schedule]) – a learning rate or learning rate schedule for the optimizer.
- class OptimState(count, opt_state)
Bases: tuple
- count
Alias for field number 0
- opt_state
Alias for field number 1
- init_optimizer(params: Params, rng: PRNGSequence) → OptimState
Initialize the optimizer state.
- Parameters
params (BaseAgentParams) – the agent parameters.
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized optimizer state.
- Return type
BaseAgentOptimState
- init_params(rng: PRNGSequence) → Params
Initialize the agent parameters.
- Parameters
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized parameters.
- Return type
BaseAgentParams
- init_policy(rng: PRNGSequence) → PolicyState
Initialize the policy state.
- Parameters
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized policy state.
- Return type
BaseAgentPolicyState
- jit()
JIT-compile the agent's predict and update methods to speed them up.
Examples
agent = MyAgent(…)
agent.jit()
- abstract predict(params: Params, policy_state: PolicyState, obs: DeviceArrayBase, key: PRNGSequence, evaluation: bool, **kwargs) → Tuple[PolicyOutput, PolicyState]
Select actions for the agent with given observations.
- Parameters
params (BaseAgentParams) – the agent parameters.
policy_state (BaseAgentPolicyState) – the policy state.
obs (jnp.DeviceArray) – observations from the environment.
key (haiku.PRNGSequence) – the random number generator that can be used for exploration.
evaluation (bool) – whether to evaluate the policy or not.
- Returns
The policy output and the updated policy state.
- Return type
Tuple[BaseAgentPolicyOutput, BaseAgentPolicyState]
- abstract update(params: Params, optimizer_state: OptimState, data: Mapping[str, DeviceArrayBase], **kwargs) → Tuple[OptimOutput, Params, OptimState]
Update the agent policy parameters using the given training data.
- Parameters
params (BaseAgentParams) – the agent parameters.
optimizer_state (BaseAgentOptimState) – the current optimizer state.
data (Mapping[str, jnp.DeviceArray]) – the training data dict, which may include the observations, actions, rewards, etc.
- Returns
The optimization output, the updated agent parameters, and the updated optimizer state.
- Return type
Tuple[BaseAgentOptimOutput, BaseAgentParams, BaseAgentOptimState]
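The methods above are normally called in a fixed order: `init_params`, `init_optimizer`, and `init_policy` once, then `predict` and `update` in a loop. The sketch below illustrates that call order with a hypothetical, dependency-free toy agent. `StandInBaseAgent` is not baselax's class; it only reproduces the two abstract method names, `random.Random` stands in for `haiku.PRNGSequence`, and the data keys `"actions"` and `"rewards"` are assumptions.

```python
import abc
import random
from typing import Any, NamedTuple


class StandInBaseAgent(abc.ABC):
    """Hypothetical stand-in for the abstract part of baselax.agents.BaseAgent."""

    @abc.abstractmethod
    def predict(self, params, policy_state, obs, key, evaluation, **kwargs):
        """Return (policy_output, new_policy_state)."""

    @abc.abstractmethod
    def update(self, params, optimizer_state, data, **kwargs):
        """Return (optim_output, new_params, new_optimizer_state)."""


class OptimState(NamedTuple):
    count: int
    opt_state: Any


class ToyBanditAgent(StandInBaseAgent):
    """Hypothetical toy agent: one value estimate per action, no network."""

    def __init__(self, num_actions: int, learning_rate: float):
        self.num_actions = num_actions
        self.learning_rate = learning_rate

    def init_params(self, rng: random.Random) -> list:
        return [0.0] * self.num_actions

    def init_optimizer(self, params, rng: random.Random) -> OptimState:
        return OptimState(count=0, opt_state=None)

    def init_policy(self, rng: random.Random) -> None:
        return None  # this toy policy keeps no state

    def predict(self, params, policy_state, obs, key, evaluation, **kwargs):
        if evaluation:  # greedy action when evaluating
            action = max(range(self.num_actions), key=params.__getitem__)
        else:  # uniform exploration otherwise
            action = key.randrange(self.num_actions)
        return action, policy_state

    def update(self, params, optimizer_state, data, **kwargs):
        # Move each taken action's estimate toward its observed reward.
        new_params = list(params)
        for a, r in zip(data["actions"], data["rewards"]):
            new_params[a] += self.learning_rate * (r - new_params[a])
        new_state = optimizer_state._replace(count=optimizer_state.count + 1)
        optim_output = {"num_samples": len(data["actions"])}
        return optim_output, new_params, new_state


# Call order: init params, optimizer, policy, then alternate predict/update.
rng = random.Random(0)  # stand-in for haiku.PRNGSequence
agent = ToyBanditAgent(num_actions=3, learning_rate=0.5)
params = agent.init_params(rng)
opt_state = agent.init_optimizer(params, rng)
policy_state = agent.init_policy(rng)

rewards_per_action = [0.1, 0.9, 0.3]
for _ in range(20):
    action, policy_state = agent.predict(params, policy_state, None, rng, evaluation=False)
    batch = {"actions": [action], "rewards": [rewards_per_action[action]]}
    _, params, opt_state = agent.update(params, opt_state, batch)

best_action, _ = agent.predict(params, policy_state, None, rng, evaluation=True)
```

Returning new params and a new optimizer state instead of mutating them mirrors the functional signatures documented above, which is what makes `predict` and `update` suitable for JIT compilation.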
baselax.agents.DQN
- class baselax.agents.DQN(network: Callable[[Space], Transformed], env: Env, learning_rate: Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]], discount_factor: float = 0.99, epsilon_schedule: Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]] = optax.polynomial_schedule(init_value=0.9, end_value=0.05, power=1., transition_steps=50000), target_update_interval: int = 50)
Bases: BaseAgent
The implementation of a DQN agent.
- Parameters
network (Callable[[gym.Space], haiku.Transformed]) – A network function creator that takes the action space as input and returns the network function.
env (gym.Env) – A Gym environment for initializing the observation and action spaces.
learning_rate (Union[float, optax.Schedule]) – A learning rate that can be a float number or an optax.Schedule object.
discount_factor (float, optional) – The discount factor. Defaults to 0.99.
epsilon_schedule (optax.Schedule, optional) – The epsilon-greedy schedule. Defaults to optax.polynomial_schedule(init_value=0.9, end_value=0.05, power=1., transition_steps=50000).
target_update_interval (int, optional) – The update interval of the target network. Defaults to 50.
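With power=1, the default epsilon_schedule decays epsilon linearly from 0.9 to 0.05 over the first 50,000 steps and then holds it at 0.05. The sketch below re-implements the polynomial decay formula that optax.polynomial_schedule documents, in plain Python, only to show the shape of the default; in practice the optax function itself is what gets passed. Which counter DQN feeds into the schedule (environment steps or update steps) is not stated here.

```python
def polynomial_schedule(init_value: float, end_value: float, power: float,
                        transition_steps: int, transition_begin: int = 0):
    """Plain-Python sketch of the polynomial decay formula used by optax."""
    def schedule(count: int) -> float:
        # Clip the step into [0, transition_steps] after the optional delay.
        steps = min(max(count - transition_begin, 0), transition_steps)
        frac = 1.0 - steps / transition_steps
        return (init_value - end_value) * frac ** power + end_value
    return schedule


# Same arguments as the documented default epsilon_schedule.
epsilon = polynomial_schedule(init_value=0.9, end_value=0.05, power=1.0,
                              transition_steps=50000)
for step in (0, 25000, 50000, 80000):
    print(step, epsilon(step))
```

Halfway through the transition (step 25,000) epsilon is 0.475, the midpoint of 0.9 and 0.05.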
- class OptimState(count, opt_state)
Bases: tuple
- count
Alias for field number 0
- opt_state
Alias for field number 1
- class Params(policy, target)
Bases: tuple
- policy
Alias for field number 0
- target
Alias for field number 1
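DQN.Params holds two parameter sets: the online `policy` network and a `target` network used for bootstrapping. A common way to honor target_update_interval is a hard copy of the policy parameters into the target every N updates. The sketch below assumes that hard-copy scheme (rather than, say, a soft Polyak update); the `maybe_sync_target` helper is hypothetical and the dict parameters stand in for Haiku parameter pytrees.

```python
from typing import Any, NamedTuple


class Params(NamedTuple):
    """Stand-in with the same fields as DQN.Params."""
    policy: Any  # online network parameters (field number 0)
    target: Any  # target network parameters (field number 1)


def maybe_sync_target(params: Params, update_count: int,
                      target_update_interval: int = 50) -> Params:
    # Assumption: hard copy of policy into target every N updates.
    if update_count % target_update_interval == 0:
        return params._replace(target=params.policy)
    return params


params = Params(policy={"w": 1.0}, target={"w": 0.0})
for step in range(1, 101):
    params = params._replace(policy={"w": float(step)})  # pretend gradient step
    params = maybe_sync_target(params, step)
print(params.target)  # {'w': 100.0}
```

With real JAX arrays, sharing the policy pytree as the target is a valid "copy" because JAX arrays are immutable; later updates produce new arrays instead of overwriting the shared ones.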
- class PolicyOutput(actions, q_values, epsilon)
Bases: tuple
- actions
Alias for field number 0
- epsilon
Alias for field number 2
- q_values
Alias for field number 1
- init_optimizer(params: Params, rng: PRNGSequence) → OptimState
Initialize the optimizer state.
- Parameters
params (Params) – the agent parameters.
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized optimizer state.
- Return type
OptimState
- init_params(rng: PRNGSequence) → Params
Initialize the agent parameters.
- Parameters
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized parameters.
- Return type
Params
- init_policy(rng: PRNGSequence) → PolicyState
Initialize the policy state.
- Parameters
rng (haiku.PRNGSequence) – the random number generator.
- Returns
the initialized policy state.
- Return type
PolicyState
- predict(params: Params, policy_state: PolicyState, obs: DeviceArrayBase, key: PRNGSequence, evaluation: bool, **kwargs) → Tuple[PolicyOutput, PolicyState]
Select actions for the agent with given observations.
- Parameters
params (Params) – the agent parameters.
policy_state (PolicyState) – the policy state.
obs (jnp.DeviceArray) – observations from the environment.
key (haiku.PRNGSequence) – the random number generator that can be used for exploration.
evaluation (bool) – whether to evaluate the policy or not.
- Returns
The policy output and the updated policy state.
- Return type
Tuple[PolicyOutput, PolicyState]
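DQN.PolicyOutput carries the chosen actions, the Q-values they were chosen from, and the epsilon in effect, which matches epsilon-greedy action selection. The plain-Python sketch below shows that selection for a batch of Q-values. Two details are assumptions rather than documented behavior: `evaluation=True` is treated as fully greedy (epsilon 0), and the reported epsilon is the one actually used. The `epsilon_greedy` function is hypothetical.

```python
import random
from typing import List, NamedTuple, Sequence


class PolicyOutput(NamedTuple):
    """Stand-in with the same fields as DQN.PolicyOutput."""
    actions: List[int]
    q_values: List[List[float]]
    epsilon: float


def epsilon_greedy(q_values: Sequence[Sequence[float]], epsilon: float,
                   rng: random.Random, evaluation: bool) -> PolicyOutput:
    """Pick one action per row of a batch of Q-values."""
    # Assumption: evaluation disables exploration entirely.
    eps = 0.0 if evaluation else epsilon
    actions = []
    for row in q_values:
        if rng.random() < eps:
            actions.append(rng.randrange(len(row)))  # explore: random action
        else:
            actions.append(max(range(len(row)), key=row.__getitem__))  # exploit
    return PolicyOutput(actions=actions,
                        q_values=[list(r) for r in q_values],
                        epsilon=eps)


q = [[0.1, 0.7, 0.2], [0.5, 0.4, 0.0]]
out = epsilon_greedy(q, epsilon=0.9, rng=random.Random(0), evaluation=True)
print(out.actions)  # [1, 0]
```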
- update(params: Params, optimizer_state: OptimState, data: Mapping[str, DeviceArrayBase], **kwargs) → Tuple[OptimOutput, Params, OptimState]
Update the agent policy parameters using the given training data.
- Parameters
params (Params) – the agent parameters.
optimizer_state (OptimState) – the current optimizer state.
data (Mapping[str, jnp.DeviceArray]) – the training data dict, which may include the observations, actions, rewards, etc.
- Returns
The optimizer output, the updated agent parameters, and the updated optimizer state.
- Return type
Tuple[OptimOutput, Params, OptimState]
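In the standard DQN update, the policy network's Q(s, a) for each taken action is regressed toward a one-step target r + discount_factor * (1 - done) * max_a' Q_target(s', a'), computed with the target network. The plain-Python sketch below shows that target and a mean-squared TD loss. The data keys `"rewards"` and `"dones"`, the helper names, and the choice of squared error (some implementations use a Huber loss) are assumptions, not confirmed details of baselax's DQN.update.

```python
from typing import Mapping, Sequence


def td_targets(data: Mapping[str, Sequence],
               next_q_target: Sequence[Sequence[float]],
               discount_factor: float = 0.99) -> list:
    """One-step targets r + discount_factor * (1 - done) * max_a' Q_target(s', a')."""
    targets = []
    for reward, done, q_next in zip(data["rewards"], data["dones"], next_q_target):
        # Terminal transitions do not bootstrap from the next state.
        bootstrap = 0.0 if done else discount_factor * max(q_next)
        targets.append(reward + bootstrap)
    return targets


def mse(q_taken: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared TD error between Q(s, a) of the taken actions and the targets."""
    return sum((q - t) ** 2 for q, t in zip(q_taken, targets)) / len(targets)


batch = {"rewards": [1.0, 0.0], "dones": [False, True]}
next_q = [[2.0, 3.0], [5.0, 1.0]]  # target-network Q-values at s'
targets = td_targets(batch, next_q, discount_factor=0.5)
print(targets)  # [2.5, 0.0]
print(mse([2.0, 1.0], targets))  # 0.625
```

Using the separate target parameters for max_a' Q_target(s', a') is what the `target` field of DQN.Params is for; the gradient is taken only with respect to the `policy` parameters.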