The base class for all agents.
Public Member Functions

Dict[str, Any]  __getstate__(self)
    To get the agent's current state (dict of attributes).
def  __init__(self)
    The class initializer.
None  __setstate__(self, Dict[str, Any] state)
    To load the agent's current state (dict of attributes).
None  load(self, *args, **kwargs)
    Load method for the agent.
Any  policy(self, *args, **kwargs)
    Policy method for the agent.
None  save(self, *args, **kwargs)
    Save method for the agent.
Any  train(self, *args, **kwargs)
    Training method for the agent.
Data Fields

loss
    The list of losses accumulated after each backward call.
save_path
    The path to save agent states and models.
Static Private Member Functions

pytorch.Tensor  _adjust_dims_for_tensor(pytorch.Tensor tensor, int target_dim)
    Helper function to adjust the dimensions of a tensor.
pytorch.Tensor  _cast_to_tensor(Union[List, Tuple, np.ndarray, pytorch.Tensor] data)
    Helper function to cast data to a tensor.
Private Attributes

__dict__
_advantage_norm_codes
    The advantage normalisation codes; indicating the codes to normalise advantages.
_reward_norm_codes
    The reward normalisation codes; indicating the codes to normalise rewards.
_state_norm_codes
    The state normalisation codes; indicating the codes to normalise states.
_td_norm_codes
    The TD normalisation codes; indicating the codes to normalise TD errors.
The base class for all agents.
def rlpack.utils.base.agent.Agent.__init__(self)
The class initializer.
Defines basic variables useful for all agents.
Reimplemented in rlpack.actor_critic.a2c.A2C, rlpack.actor_critic.a3c.A3C, rlpack.dqn.dqn_proportional_prioritization_agent.DqnProportionalPrioritizationAgent, rlpack.dqn.dqn_rank_based_prioritization_agent.DqnRankBasedPrioritizationAgent, and rlpack.dqn.dqn_agent.DqnAgent.
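Based on the data fields and private attributes documented on this page, the base initializer plausibly sets up something like the following. The attribute names come from this page; the initial values and type annotations are assumptions, not rlpack's actual source:

```python
from typing import List


class Agent:
    """Sketch of the documented base class; initial values are assumed."""

    def __init__(self):
        # Public data fields documented above.
        self.loss: List[float] = []   # losses accumulated after each backward call
        self.save_path: str = ""      # path to save agent states and models
        # Private normalisation-code attributes documented above.
        self._state_norm_codes: List[int] = []
        self._reward_norm_codes: List[int] = []
        self._td_norm_codes: List[int] = []
        self._advantage_norm_codes: List[int] = []
```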
Dict[str, Any] rlpack.utils.base.agent.Agent.__getstate__(self)
To get the agent's current state (dict of attributes).
None rlpack.utils.base.agent.Agent.__setstate__(self, Dict[str, Any] state)
To load the agent's current state (dict of attributes).
Parameters
    state    Dict[str, Any]: The agent's state as a dictionary.
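Because __getstate__ returns the attribute dictionary and __setstate__ restores it, these hooks plug into Python's pickle protocol. A minimal illustration with a stand-in class (the class body here is hypothetical and only mirrors the documented contract, not rlpack's implementation):

```python
import pickle
from typing import Any, Dict


class StandInAgent:
    # Hypothetical stand-in mirroring the documented __getstate__/__setstate__ contract.
    def __init__(self):
        self.loss = [0.5, 0.25]
        self.save_path = "models/agent"

    def __getstate__(self) -> Dict[str, Any]:
        # Return the agent's current state (dict of attributes).
        return self.__dict__

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Load the agent's current state (dict of attributes).
        self.__dict__.update(state)


agent = StandInAgent()
restored = pickle.loads(pickle.dumps(agent))  # round-trips via the two hooks
```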
pytorch.Tensor rlpack.utils.base.agent.Agent._adjust_dims_for_tensor(pytorch.Tensor tensor, int target_dim)  [static, private]

Helper function to adjust the dimensions of a tensor.
This only works for tensors that have a single axis along some dimension, and it doesn't change the underlying data or storage.
Parameters
    tensor        pytorch.Tensor: The tensor whose dimensions are required to be changed.
    target_dim    int: The target number of dimensions.
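One way such a helper could look, adding or removing size-1 axes so the underlying storage is untouched. This is a hypothetical re-implementation consistent with the description above, not rlpack's source:

```python
import torch


def adjust_dims_for_tensor(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
    # Hypothetical sketch: add or remove size-1 axes until `target_dim` is reached.
    # unsqueeze/squeeze return views, so the underlying data and storage are unchanged.
    while tensor.dim() < target_dim:
        tensor = tensor.unsqueeze(0)
    while tensor.dim() > target_dim:
        if tensor.size(0) != 1:
            raise ValueError("can only drop axes of size 1")
        tensor = tensor.squeeze(0)
    return tensor
```

For example, a tensor of shape (3,) adjusted to two dimensions becomes shape (1, 3), sharing storage with the original.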
pytorch.Tensor rlpack.utils.base.agent.Agent._cast_to_tensor(Union[List, Tuple, np.ndarray, pytorch.Tensor] data)  [static, private]

Helper function to cast data to a tensor.
Parameters
    data    Union[List, Tuple, np.ndarray, pytorch.Tensor]: The data to convert to a tensor.
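One plausible shape for such a cast, covering each documented input type. This is a hypothetical sketch; rlpack's actual implementation may differ, for example in dtype or device handling:

```python
from typing import List, Tuple, Union

import numpy as np
import torch


def cast_to_tensor(data: Union[List, Tuple, np.ndarray, torch.Tensor]) -> torch.Tensor:
    # Hypothetical sketch of the documented helper.
    if isinstance(data, torch.Tensor):
        return data                    # already a tensor: pass through
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data)  # shares memory with the ndarray
    return torch.tensor(data)          # lists/tuples: copy into a new tensor
```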
None rlpack.utils.base.agent.Agent.load(self, *args, **kwargs)
Load method for the agent.
This method needs to be overridden by every agent that inherits this class. All necessary agent states and attributes must be loaded in the implementation so that training can be restarted.
Parameters
    args      Positional arguments for load method.
    kwargs    Keyword arguments for load method.
Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.
Any rlpack.utils.base.agent.Agent.policy(self, *args, **kwargs)
Policy method for the agent.
This method needs to be overridden by every agent that inherits this class.
Parameters
    args      Positional arguments for policy method.
    kwargs    Keyword arguments for policy method.
Reimplemented in rlpack.dqn.dqn_agent.DqnAgent, and rlpack.actor_critic.a2c.A2C.
None rlpack.utils.base.agent.Agent.save(self, *args, **kwargs)
Save method for the agent.
This method needs to be overridden by every agent that inherits this class. All necessary agent states and attributes must be saved in the implementation so that training can be restarted.
Parameters
    args      Positional arguments for save method.
    kwargs    Keyword arguments for save method.
Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.
Any rlpack.utils.base.agent.Agent.train(self, *args, **kwargs)
Training method for the agent.
This method needs to be overridden by every agent that inherits this class.
Parameters
    args      Positional arguments for train method.
    kwargs    Keyword arguments for train method.
Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.
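Putting the four overridable methods together, a minimal hypothetical subclass might look as follows. A stand-in base class is defined inline so the sketch is self-contained; with rlpack installed, one would inherit from rlpack.utils.base.agent.Agent instead, and the toy method bodies are illustrative only:

```python
from typing import Any


class Agent:
    # Stand-in for rlpack.utils.base.agent.Agent with the documented interface.
    def __init__(self):
        self.loss = []
        self.save_path = ""

    def train(self, *args, **kwargs) -> Any:
        raise NotImplementedError

    def policy(self, *args, **kwargs) -> Any:
        raise NotImplementedError

    def save(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def load(self, *args, **kwargs) -> None:
        raise NotImplementedError


class ToyAgent(Agent):
    # Hypothetical agent overriding all four methods.
    def __init__(self):
        super().__init__()
        self.steps = 0

    def train(self, reward: float, **kwargs) -> float:
        self.loss.append(-reward)  # toy "loss" so training progress can be tracked
        return self.loss[-1]

    def policy(self, state: Any, **kwargs) -> int:
        self.steps += 1
        return self.steps % 2      # toy deterministic action selection

    def save(self, *args, **kwargs) -> None:
        pass  # a real agent would persist its states to self.save_path

    def load(self, *args, **kwargs) -> None:
        pass  # a real agent would restore states so training can be restarted
```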
rlpack.utils.base.agent.Agent.__dict__  [private]
rlpack.utils.base.agent.Agent._advantage_norm_codes  [private]
The Advantage normalisation codes; indicating the codes to normalise Advantages.
rlpack.utils.base.agent.Agent._reward_norm_codes  [private]
The reward normalisation codes; indicating the codes to normalise rewards.
rlpack.utils.base.agent.Agent._state_norm_codes  [private]
The state normalisation codes; indicating the codes to normalise states.
rlpack.utils.base.agent.Agent._td_norm_codes  [private]
The TD normalisation codes; indicating the codes to normalise TD Errors.
rlpack.utils.base.agent.Agent.loss |
The list of losses accumulated after each backward call.
rlpack.utils.base.agent.Agent.save_path |
The path to save agent states and models.