baselax.agents

The baselax.agents module provides reinforcement learning (RL) agent implementations built with JAX.

baselax.agents.BaseAgent

class baselax.agents.BaseAgent(network: Callable[[Space], Transformed], env: Env, learning_rate: Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]])

Bases: ABC

An abstract BaseAgent class.

Parameters
  • network (Callable[[gym.Space], haiku.Transformed]) – A network function creator that takes the action space as input and returns the network function.

  • env (gym.Env) – A Gym environment wrapper with batched input and output.

  • learning_rate (Union[float, optax.Schedule]) – A learning rate or learning rate schedule for the optimizer.

class OptimOutput(loss)

Bases: tuple

loss

Alias for field number 0

class OptimState(count, opt_state)

Bases: tuple

count

Alias for field number 0

opt_state

Alias for field number 1

class Params(policy)

Bases: tuple

policy

Alias for field number 0

class PolicyOutput(actions)

Bases: tuple

actions

Alias for field number 0

class PolicyState(count)

Bases: tuple

count

Alias for field number 0
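The nested classes above are tuple subclasses whose fields can also be read by position ("Alias for field number N"), which is how collections.namedtuple documents its fields. The stdlib-only sketch below re-creates them for illustration (these definitions are assumptions, not the library source). It shows two practical consequences: fields are read by name or by index, and updates go through _replace, which returns a new tuple instead of mutating the old one. Tuples of arrays are also valid JAX pytrees, so containers like these can be passed through jax.jit.

```python
from collections import namedtuple

# Illustrative re-creations of BaseAgent's nested containers (not the library source).
Params = namedtuple("Params", ["policy"])
PolicyState = namedtuple("PolicyState", ["count"])
OptimState = namedtuple("OptimState", ["count", "opt_state"])
PolicyOutput = namedtuple("PolicyOutput", ["actions"])
OptimOutput = namedtuple("OptimOutput", ["loss"])

state = OptimState(count=0, opt_state=None)

# `count` is an alias for field number 0 and `opt_state` for field number 1.
same_fields = state.count == state[0] and state.opt_state is state[1]

# Tuples are immutable: _replace builds a new state and leaves `state` unchanged.
next_state = state._replace(count=state.count + 1)
```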

init_optimizer(params: Params, rng: PRNGSequence) → OptimState

Initialize the optimizer state.

Parameters
  • params (BaseAgentParams) – the agent parameters.

  • rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized optimizer state.

Return type

BaseAgentOptimState

init_params(rng: PRNGSequence) → Params

Initialize the parameters for the agent.

Parameters

rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized parameters.

Return type

BaseAgentParams

init_policy(rng: PRNGSequence) → PolicyState

Initialize the policy state.

Parameters

rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized policy state.

Return type

BaseAgentPolicyState

jit()

JIT-compile the agent's predict and update methods to speed them up.

Examples

agent = MyAgent(…)
agent.jit()
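These docs do not show how jit() is implemented. A common JAX pattern, sketched below under that assumption, is to wrap a pure predict function with jax.jit and mark boolean flags such as evaluation as static, because Python if statements cannot branch on traced values. The linear Q-function and the rule that the step counter only advances during training are hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PolicyState(NamedTuple):
    count: jnp.ndarray


def predict(params, policy_state, obs, evaluation: bool):
    """Hypothetical pure predict function: greedy actions from a linear Q-function."""
    q_values = obs @ params
    actions = jnp.argmax(q_values, axis=-1)
    # Python branching on `evaluation` only works because it is static under jit.
    if evaluation:
        return actions, policy_state
    return actions, policy_state._replace(count=policy_state.count + 1)


# Each distinct value of `evaluation` triggers its own compilation.
jitted_predict = jax.jit(predict, static_argnames="evaluation")
```

Because PolicyState is a NamedTuple, jax.jit treats it as a pytree and returns a new PolicyState from the compiled call.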

abstract predict(params: Params, policy_state: PolicyState, obs: DeviceArrayBase, key: PRNGSequence, evaluation: bool, **kwargs) → Tuple[PolicyOutput, PolicyState]

Select actions for the agent with given observations.

Parameters
  • params (BaseAgentParams) – the agent parameters.

  • policy_state (BaseAgentPolicyState) – the policy state.

  • obs (jnp.DeviceArray) – observations from the environment.

  • key (haiku.PRNGSequence) – the random number generator that can be used for exploration.

  • evaluation (bool) – whether to select actions in evaluation mode rather than training mode.

Returns

the policy output and the updated policy state.

Return type

Tuple[BaseAgentPolicyOutput, BaseAgentPolicyState]

abstract update(params: Params, optimizer_state: OptimState, data: Mapping[str, DeviceArrayBase], **kwargs) → Tuple[OptimOutput, Params, OptimState]

Update the agent policy from the given parameters and training data.

Parameters
  • params (BaseAgentParams) – the agent parameters.

  • optimizer_state (BaseAgentOptimState) – the current optimizer state.

  • data (Mapping[str, jnp.DeviceArray]) – the training data dict, which may include the observations, actions, rewards, etc.

Returns

the optimization output, the updated agent parameters, and the updated optimizer state.

Return type

Tuple[BaseAgentOptimOutput, BaseAgentParams, BaseAgentOptimState]
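Because update returns new parameters and a new optimizer state instead of mutating them, an implementation is typically a pure function. The sketch below shows that contract with plain gradient descent on a mean squared error loss. The agents actually accept an optax learning rate or schedule, so the swap to plain SGD, the data keys "obs" and "targets", and the learning_rate default are assumptions for illustration only.

```python
from typing import Any, Mapping, NamedTuple

import jax
import jax.numpy as jnp


class OptimOutput(NamedTuple):
    loss: jnp.ndarray


class OptimState(NamedTuple):
    count: int
    opt_state: Any  # plain SGD keeps no state; an optax state would go here


def update(params, optimizer_state: OptimState, data: Mapping[str, jnp.ndarray],
           learning_rate: float = 0.1):
    """Hypothetical update: one SGD step on a mean squared error loss."""

    def loss_fn(p):
        preds = data["obs"] @ p
        return jnp.mean((preds - data["targets"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Gradient descent applied leaf by leaf, so `params` may be any pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    new_state = optimizer_state._replace(count=optimizer_state.count + 1)
    return OptimOutput(loss=loss), new_params, new_state
```

The return order (optimization output, parameters, optimizer state) matches the documented return type.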

baselax.agents.DQN

class baselax.agents.DQN(network: Callable[[Space], Transformed], env: Env, learning_rate: Union[float, Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]]], discount_factor: float = 0.99, epsilon_schedule: Callable[[Union[ndarray, float, int]], Union[ndarray, float, int]] = <function polynomial_schedule.<locals>.schedule>, target_update_interval: int = 50)

Bases: BaseAgent

A Deep Q-Network (DQN) agent implementation.

Parameters
  • network (Callable[[gym.Space], haiku.Transformed]) – A network function creator that takes the action space as input and returns the network function.

  • env (gym.Env) – A Gym environment used to initialize the observation and action spaces.

  • learning_rate (Union[float, optax.Schedule]) – A learning rate, given either as a float or as an optax.Schedule.

  • discount_factor (float, optional) – The discount factor. Defaults to 0.99.

  • epsilon_schedule (optax.Schedule, optional) – The schedule for epsilon in epsilon-greedy exploration. Defaults to optax.polynomial_schedule(init_value=0.9, end_value=0.05, power=1., transition_steps=50000).

  • target_update_interval (int, optional) – The update interval of the target network. Defaults to 50.
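With power=1., optax.polynomial_schedule decays linearly from init_value to end_value over transition_steps and then stays at end_value. The stdlib-only sketch below reproduces that arithmetic for the default epsilon schedule, assuming optax's default transition_begin=0, so the value at any step can be checked by hand. linear_epsilon is a hypothetical helper, not part of baselax or optax.

```python
def linear_epsilon(count: int,
                   init_value: float = 0.9,
                   end_value: float = 0.05,
                   transition_steps: int = 50_000,
                   power: float = 1.0) -> float:
    """Mirror of optax.polynomial_schedule's formula with transition_begin=0."""
    # Clamp the step into [0, transition_steps] so epsilon stays at end_value afterwards.
    count = min(max(count, 0), transition_steps)
    frac = 1.0 - count / transition_steps
    return (init_value - end_value) * frac ** power + end_value
```

For example, epsilon is 0.9 at step 0, 0.475 at step 25,000, and 0.05 from step 50,000 onward (up to floating-point rounding).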

class OptimOutput(loss)

Bases: tuple

loss

Alias for field number 0

class OptimState(count, opt_state)

Bases: tuple

count

Alias for field number 0

opt_state

Alias for field number 1

class Params(policy, target)

Bases: tuple

policy

Alias for field number 0

target

Alias for field number 1
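DQN keeps a separate target parameter set next to the policy parameters, and target_update_interval controls how often the target is refreshed. The sketch below is a minimal jit-friendly version. It assumes the target is overwritten by the policy parameters whenever the update counter is a multiple of the interval, though the exact counting convention baselax uses is not documented here. The Params definition and the periodic_target_update helper are illustrative.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    policy: Any
    target: Any


def periodic_target_update(params: Params, count, interval: int = 50) -> Params:
    """Copy policy parameters into the target every `interval` updates."""
    should_update = (count % interval) == 0
    # jnp.where selects leaf by leaf, so this also works when `count` is a traced value.
    new_target = jax.tree_util.tree_map(
        lambda policy_leaf, target_leaf: jnp.where(should_update, policy_leaf, target_leaf),
        params.policy,
        params.target,
    )
    return params._replace(target=new_target)
```

Using jnp.where instead of a Python if keeps the function traceable, so it can run inside a jitted update step.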

class PolicyOutput(actions, q_values, epsilon)

Bases: tuple

actions

Alias for field number 0

epsilon

Alias for field number 2

q_values

Alias for field number 1
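PolicyOutput bundles the chosen actions with the Q-values and epsilon that produced them. The sketch below shows standard epsilon-greedy selection in JAX that would yield those three fields. It assumes a batch of Q-values with shape (batch, num_actions) and is an illustration, not the library's predict.

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(key: jnp.ndarray, q_values: jnp.ndarray, epsilon: float) -> jnp.ndarray:
    """Greedy actions, each replaced by a uniformly random action with probability epsilon."""
    explore_key, action_key = jax.random.split(key)
    batch_size, num_actions = q_values.shape
    greedy_actions = jnp.argmax(q_values, axis=-1)
    random_actions = jax.random.randint(action_key, (batch_size,), 0, num_actions)
    # One Bernoulli(epsilon) draw per environment in the batch.
    explore = jax.random.uniform(explore_key, (batch_size,)) < epsilon
    return jnp.where(explore, random_actions, greedy_actions)
```

With epsilon = 0 this reduces to pure greedy selection. With epsilon = 1 every action is drawn uniformly at random.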

class PolicyState(count)

Bases: tuple

count

Alias for field number 0

init_optimizer(params: Params, rng: PRNGSequence) → OptimState

Initialize the optimizer state.

Parameters
  • params (Params) – the agent parameters.

  • rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized optimizer state.

Return type

OptimState

init_params(rng: PRNGSequence) → Params

Initialize the parameters for the agent.

Parameters

rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized parameters.

Return type

Params

init_policy(rng: PRNGSequence) → PolicyState

Initialize the policy state.

Parameters

rng (haiku.PRNGSequence) – the random number generator.

Returns

the initialized policy state.

Return type

PolicyState

predict(params: Params, policy_state: PolicyState, obs: DeviceArrayBase, key: PRNGSequence, evaluation: bool, **kwargs) → Tuple[PolicyOutput, PolicyState]

Select actions for the agent with given observations.

Parameters
  • params (Params) – the agent parameters.

  • policy_state (PolicyState) – the policy state.

  • obs (jnp.DeviceArray) – observations from the environment.

  • key (haiku.PRNGSequence) – the random number generator that can be used for exploration.

  • evaluation (bool) – whether to select actions in evaluation mode rather than training mode.

Returns

the policy output and the updated policy state.

Return type

Tuple[PolicyOutput, PolicyState]

update(params: Params, optimizer_state: OptimState, data: Mapping[str, DeviceArrayBase], **kwargs) → Tuple[OptimOutput, Params, OptimState]

Update the agent policy from the given parameters and training data.

Parameters
  • params (Params) – the agent parameters.

  • optimizer_state (OptimState) – the current optimizer state.

  • data (Mapping[str, jnp.DeviceArray]) – the training data dict, which may include the observations, actions, rewards, etc.

Returns

the optimizer output, the updated agent parameters, and the updated optimizer state.

Return type

Tuple[OptimOutput, Params, OptimState]
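The discount_factor enters the DQN loss through the temporal-difference target, rewards + discount_factor * (1 - done) * max over a' of Q_target(s', a'), which is computed with the target network. The sketch below is the textbook squared TD error. These docs do not say whether baselax uses squared error, a Huber loss, or a double-DQN target, and the argument names (including the done flags) are assumptions about what the data mapping carries.

```python
import jax
import jax.numpy as jnp


def dqn_loss(q_values, target_q_values, actions, rewards, dones, discount_factor=0.99):
    """Mean squared TD error for a batch of transitions.

    q_values:        policy-network Q(s, .) with shape (batch, num_actions)
    target_q_values: target-network Q(s', .) with shape (batch, num_actions)
    actions, rewards, dones: shape (batch,); dones is 1.0 at terminal transitions
    """
    # Q(s, a) for the action actually taken in each transition.
    q_taken = jnp.take_along_axis(q_values, actions[:, None], axis=-1).squeeze(-1)
    # Bootstrap from the best next-state value unless the episode ended.
    targets = rewards + discount_factor * (1.0 - dones) * jnp.max(target_q_values, axis=-1)
    # The target is treated as a constant: no gradient flows through it.
    td_error = q_taken - jax.lax.stop_gradient(targets)
    return jnp.mean(td_error ** 2)
```

In an update step, q_values would come from applying the policy parameters and target_q_values from the target parameters, and jax.grad of this loss with respect to the policy parameters gives the gradient to apply.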