RLPack
 
rlpack.utils.base.agent.Agent Class Reference

The base class for all agents. More...

Inheritance diagram for rlpack.utils.base.agent.Agent
Collaboration diagram for rlpack.utils.base.agent.Agent

Public Member Functions

Dict[str, Any] __getstate__ (self)
 To get the agent's current state (dict of attributes). More...
 
def __init__ (self)
 The class initializer. More...
 
None __setstate__ (self, Dict[str, Any] state)
 To load the agent's current state (dict of attributes). More...
 
None load (self, *args, **kwargs)
 Load method for the agent. More...
 
Any policy (self, *args, **kwargs)
 Policy method for the agent. More...
 
None save (self, *args, **kwargs)
 Save method for the agent. More...
 
Any train (self, *args, **kwargs)
 Training method for the agent. More...
 

Data Fields

 loss
 The list of losses accumulated after each backward call. More...
 
 save_path
 The path to save agent states and models. More...
 

Static Private Member Functions

pytorch.Tensor _adjust_dims_for_tensor (pytorch.Tensor tensor, int target_dim)
 Helper function to adjust dimensions of tensor. More...
 
pytorch.Tensor _cast_to_tensor (Union[List, Tuple, np.ndarray, pytorch.Tensor] data)
 Helper function to cast data to tensor. More...
 

Private Attributes

 __dict__
 
 _advantage_norm_codes
 The advantage normalisation codes, indicating the codes used to normalise advantages. More...
 
 _reward_norm_codes
 The reward normalisation codes, indicating the codes used to normalise rewards. More...
 
 _state_norm_codes
 The state normalisation codes, indicating the codes used to normalise states. More...
 
 _td_norm_codes
 The TD normalisation codes, indicating the codes used to normalise TD errors. More...
 

Detailed Description

The base class for all agents.
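As a sketch of the intended usage, a concrete agent overrides the public hooks this page documents (train, policy, save, load). The classes below are hypothetical stand-ins written against only what this page describes; they are not part of RLPack.

```python
# Hypothetical sketch: `StubAgent` mimics the documented base-class surface,
# and `RandomAgent` shows how a concrete agent would override the hooks.
# Neither class is part of RLPack.
import random
from typing import Any


class StubAgent:
    def __init__(self):
        self.loss = []          # losses accumulated after each backward call
        self.save_path = None   # path to save agent states and models

    def train(self, *args, **kwargs) -> Any:
        raise NotImplementedError

    def policy(self, *args, **kwargs) -> Any:
        raise NotImplementedError

    def save(self, *args, **kwargs) -> None:
        raise NotImplementedError

    def load(self, *args, **kwargs) -> None:
        raise NotImplementedError


class RandomAgent(StubAgent):
    """Toy agent that samples actions uniformly from `n_actions` choices."""

    def __init__(self, n_actions: int):
        super().__init__()
        self.n_actions = n_actions

    def policy(self, state, *args, **kwargs) -> int:
        return random.randrange(self.n_actions)

    def train(self, state, reward, *args, **kwargs) -> int:
        self.loss.append(0.0)  # stand-in for a real backward pass
        return self.policy(state)
```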

Constructor & Destructor Documentation

◆ __init__()

def rlpack.utils.base.agent.Agent.__init__ (   self)

The class initializer.

Member Function Documentation

◆ __getstate__()

Dict[str, Any] rlpack.utils.base.agent.Agent.__getstate__ (   self)

To get the agent's current state (dict of attributes).

Returns
Dict[str, Any]: The agent's state as a dictionary of attributes.

◆ __setstate__()

None rlpack.utils.base.agent.Agent.__setstate__ (   self,
Dict[str, Any]  state 
)

To load the agent's current state (dict of attributes).

Parameters
state	Dict[str, Any]: The agent's state as a dictionary of attributes.
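These two hooks are what the pickle protocol calls when serializing and restoring an agent. A minimal sketch of that round trip, using a hypothetical `ToyAgent` whose implementation follows the pattern described here (the base class's actual bodies may differ):

```python
# Illustrative only: __getstate__/__setstate__ let pickle capture and
# restore the agent's attribute dict. `ToyAgent` is a hypothetical stand-in.
import pickle


class ToyAgent:
    def __init__(self):
        self.loss = [0.5, 0.25]
        self.save_path = "/tmp/agent"

    def __getstate__(self):
        return self.__dict__  # current state as a dict of attributes

    def __setstate__(self, state):
        self.__dict__.update(state)  # restore attributes from the dict


agent = ToyAgent()
restored = pickle.loads(pickle.dumps(agent))  # full round trip
```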

◆ _adjust_dims_for_tensor()

pytorch.Tensor rlpack.utils.base.agent.Agent._adjust_dims_for_tensor ( pytorch.Tensor  tensor,
int   target_dim 
)
static, private

Helper function to adjust dimensions of tensor.

This only works when the adjustment can be made by adding or removing singleton axes (axes of size one) along any dimension; it does not change the underlying data or storage.

Parameters
tensor	pytorch.Tensor: The tensor whose dimensions are to be changed.
target_dim	int: The target number of dimensions.
Returns
pytorch.Tensor: The tensor with adjusted dimensions.
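A NumPy analogue of this behaviour, assuming the helper works by adding or removing singleton axes as the description suggests (the function below is illustrative, not RLPack's implementation; both NumPy calls return views, so the underlying data is untouched):

```python
# NumPy analogue of _adjust_dims_for_tensor (assumption: the RLPack helper
# unsqueezes/squeezes singleton axes until the target rank is reached).
import numpy as np


def adjust_dims(array: np.ndarray, target_dim: int) -> np.ndarray:
    """Add or remove leading singleton axes until `array.ndim == target_dim`."""
    while array.ndim < target_dim:
        array = np.expand_dims(array, axis=0)   # view; no data copy
    while array.ndim > target_dim and array.shape[0] == 1:
        array = np.squeeze(array, axis=0)       # view; no data copy
    return array
```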

◆ _cast_to_tensor()

pytorch.Tensor rlpack.utils.base.agent.Agent._cast_to_tensor ( Union[List, Tuple, np.ndarray, pytorch.Tensor]   data)
static, private

Helper function to cast data to tensor.

Parameters
data	Union[List, Tuple, np.ndarray, pytorch.Tensor]: The data to convert to a tensor.
Returns
pytorch.Tensor: The tensor from the input data.
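The same casting pattern sketched with NumPy instead of torch (assumption: the RLPack helper behaves like `torch.as_tensor`, accepting lists, tuples and arrays, and passing tensors through unchanged):

```python
# NumPy analogue of _cast_to_tensor: `np.asarray` converts lists/tuples to
# an ndarray and returns an existing ndarray unchanged, mirroring the
# assumed pass-through behaviour for tensors.
import numpy as np
from typing import List, Tuple, Union


def cast_to_array(data: Union[List, Tuple, np.ndarray]) -> np.ndarray:
    return np.asarray(data)  # no-op if `data` is already a compatible ndarray
```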

◆ load()

None rlpack.utils.base.agent.Agent.load (   self,
*  args,
**  kwargs 
)

Load method for the agent.

This method must be overridden by every agent that inherits this class. The implementation must load all necessary agent states and attributes so that training can be restarted.

Parameters
args	Positional arguments for the load method.
kwargs	Keyword arguments for the load method.

Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.

◆ policy()

Any rlpack.utils.base.agent.Agent.policy (   self,
*  args,
**  kwargs 
)

Policy method for the agent.

This method must be overridden by every agent that inherits this class.

Parameters
args	Positional arguments for the policy method.
kwargs	Keyword arguments for the policy method.
Returns
Any: Action to be taken.

Reimplemented in rlpack.dqn.dqn_agent.DqnAgent, and rlpack.actor_critic.a2c.A2C.

◆ save()

None rlpack.utils.base.agent.Agent.save (   self,
*  args,
**  kwargs 
)

Save method for the agent.

This method must be overridden by every agent that inherits this class. The implementation must save all necessary agent states and attributes so that training can be restarted.

Parameters
args	Positional arguments for the save method.
kwargs	Keyword arguments for the save method.

Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.
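One hedged sketch of how a subclass could satisfy the save/load contract so training can be restarted. `CheckpointAgent`, its pickle-based persistence, and the file layout are all illustrative assumptions; RLPack's concrete agents (e.g. DqnAgent) define their own logic.

```python
# Hypothetical checkpointing pattern: persist the whole attribute dict to
# `save_path` on save(), and restore it on load(). Not RLPack's code.
import os
import pickle
import tempfile


class CheckpointAgent:
    def __init__(self, save_path: str):
        self.save_path = save_path
        self.loss = []

    def save(self) -> None:
        with open(self.save_path, "wb") as f:
            pickle.dump(self.__dict__, f)  # persist all attributes

    def load(self) -> None:
        with open(self.save_path, "rb") as f:
            self.__dict__.update(pickle.load(f))  # restore all attributes


path = os.path.join(tempfile.mkdtemp(), "agent.pkl")
agent = CheckpointAgent(path)
agent.loss = [1.0, 0.5]
agent.save()

fresh = CheckpointAgent(path)  # new instance, empty state
fresh.load()                   # state restored from disk
```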

◆ train()

Any rlpack.utils.base.agent.Agent.train (   self,
*  args,
**  kwargs 
)

Training method for the agent.

This method must be overridden by every agent that inherits this class.

Parameters
args	Positional arguments for the train method.
kwargs	Keyword arguments for the train method.
Returns
Any: Action to be taken.

Reimplemented in rlpack.actor_critic.a2c.A2C, and rlpack.dqn.dqn_agent.DqnAgent.
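A hypothetical driver loop showing how train() and policy() are typically used together, with train() returning the next action as documented. The environment and agent below are minimal stand-ins, not Gym or RLPack classes.

```python
# Toy interaction loop: train() records a loss and returns the next action.
class ToyEnv:
    """Minimal stand-in environment: episode ends after 5 steps."""

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        self.t += 1
        return float(self.t), 1.0, self.t >= 5  # state, reward, done


class GreedyAgent:
    def __init__(self):
        self.loss = []

    def policy(self, state):
        return 0  # always take action 0 in this toy example

    def train(self, state, reward):
        self.loss.append(0.0)   # stand-in for a real backward pass
        return self.policy(state)


env, agent = ToyEnv(), GreedyAgent()
state, done, total_reward = env.reset(), False, 0.0
while not done:
    action = agent.train(state, 0.0)   # train() returns the action to take
    state, reward, done = env.step(action)
    total_reward += reward
```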

Field Documentation

◆ __dict__

rlpack.utils.base.agent.Agent.__dict__
private

◆ _advantage_norm_codes

rlpack.utils.base.agent.Agent._advantage_norm_codes
private

The advantage normalisation codes, indicating the codes used to normalise advantages.

◆ _reward_norm_codes

rlpack.utils.base.agent.Agent._reward_norm_codes
private

The reward normalisation codes, indicating the codes used to normalise rewards.

◆ _state_norm_codes

rlpack.utils.base.agent.Agent._state_norm_codes
private

The state normalisation codes, indicating the codes used to normalise states.

◆ _td_norm_codes

rlpack.utils.base.agent.Agent._td_norm_codes
private

The TD normalisation codes, indicating the codes used to normalise TD errors.

◆ loss

rlpack.utils.base.agent.Agent.loss

The list of losses accumulated after each backward call.

◆ save_path

rlpack.utils.base.agent.Agent.save_path

The path to save agent states and models.