RLPack
 
rlpack.models.mlp.Mlp Class Reference

This class is a PyTorch model implementing an MLP for 1-D or 2-D state values.


Public Member Functions

def __init__ (self, int sequence_length, List[int] hidden_sizes, int num_actions, Activation activation=pytorch.nn.ReLU(), float dropout=0.5)
 Initialize the Mlp model.
 
pytorch.Tensor forward (self, pytorch.Tensor x)
 The forward method of the nn.Module.
 

Data Fields

 final_head
 The final head that produces the logits, one per action.
 
 flatten
 The object that flattens the output of the feature extractor.
 
 mlp_feature_extractor
 The feature extractor instance of rlpack.models._mlp_feature_extractor._MlpFeatureExtractor.
 

Detailed Description

This class is a PyTorch model implementing an MLP for 1-D or 2-D state values.

Constructor & Destructor Documentation

◆ __init__()

def rlpack.models.mlp.Mlp.__init__ (   self,
int  sequence_length,
List[int]  hidden_sizes,
int  num_actions,
Activation   activation = pytorch.nn.ReLU(),
float   dropout = 0.5 
)

Initialize Mlp model.

Parameters
sequence_length (int): The sequence length of the expected input tensor.
hidden_sizes (List[int]): The list of hidden sizes, one per layer.
num_actions (int): The number of actions in the environment.
activation (Activation): The activation function for the model. Must be an initialized activation object (an instance, not the class) from PyTorch's nn (torch.nn) module.
dropout (float): The dropout probability used in the final Linear (FC) layer.
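The constructor arguments fix the shapes of the model's layers. The sketch below computes those shapes under assumptions this page does not confirm: `hidden_sizes[0]` is the number of features per sequence element, each consecutive pair in `hidden_sizes` defines one Linear layer of the feature extractor, and `final_head` maps the flattened `sequence_length * hidden_sizes[-1]` features to `num_actions` logits. `layer_plan` is a hypothetical helper, not part of RLPack.

```python
from typing import List, Tuple


def layer_plan(
    sequence_length: int, hidden_sizes: List[int], num_actions: int
) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each Linear layer.

    Hypothetical helper: the layout below is an assumption,
    not RLPack's confirmed design.
    """
    if len(hidden_sizes) < 2:
        raise ValueError("hidden_sizes needs an input size and at least one hidden size")
    # Assumed feature extractor: one Linear layer per consecutive pair of sizes,
    # applied to every element of the input sequence.
    plan = list(zip(hidden_sizes[:-1], hidden_sizes[1:]))
    # Assumed final head: flattened (sequence_length * last hidden size) -> num_actions logits.
    plan.append((sequence_length * hidden_sizes[-1], num_actions))
    return plan


print(layer_plan(sequence_length=4, hidden_sizes=[8, 64, 32], num_actions=2))
# [(8, 64), (64, 32), (128, 2)]
```

If the assumption holds, the in_features of the final head grows linearly with `sequence_length`, so long sequences with a wide last hidden layer make the head the largest layer in the model.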

Member Function Documentation

◆ forward()

pytorch.Tensor rlpack.models.mlp.Mlp.forward (   self,
pytorch.Tensor  x 
)

The forward method of the nn.Module.

Parameters
x (pytorch.Tensor): The model input.
Returns
pytorch.Tensor: The model output (logits).
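Conceptually, forward chains the three data fields of this class: the feature extractor, the flatten step, and the final head. The pure-Python sketch below (no PyTorch, a single sample without a batch dimension) illustrates that data flow under assumptions this page does not confirm: the feature extractor applies the same Linear + activation stack to every sequence element, dropout is inactive (as in eval mode), and the head returns raw logits. All functions here are hypothetical.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # one row per output feature
Layer = Tuple[Matrix, Vector]  # (weight, bias)


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # y = W x + b, one dot product per output feature.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


def mlp_forward(
    state: List[Vector],
    extractor: List[Layer],
    head: Layer,
    activation: Callable[[Vector], Vector] = relu,
) -> Vector:
    # 1) Feature extractor (assumed layout): the same Linear + activation stack
    #    is applied to every element of the sequence. Dropout is omitted (eval mode).
    features = []
    for element in state:
        h = element
        for weight, bias in extractor:
            h = activation(linear(h, weight, bias))
        features.append(h)
    # 2) Flatten: (sequence_length, hidden) -> (sequence_length * hidden,).
    flat = [v for h in features for v in h]
    # 3) Final head: one raw logit per action, no activation.
    head_weight, head_bias = head
    return linear(flat, head_weight, head_bias)


state = [[1.0, -1.0], [0.5, 2.0]]  # sequence_length=2, 2 features per element
extractor = [([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])]  # identity layer, 2 -> 2
head = ([[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]], [0.0, 0.5])  # 4 -> 2 actions
print(mlp_forward(state, extractor, head))  # [3.5, 1.5]
```

In real PyTorch code the input carries a leading batch dimension, and the flatten step keeps it: `torch.nn.Flatten()` flattens from `start_dim=1` by default.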

Field Documentation

◆ final_head

rlpack.models.mlp.Mlp.final_head

The final head that produces the logits, one per action.
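In value-based agents such as DQN, the logits from a head like this are commonly turned into an action greedily, by taking the index of the largest value (often combined with epsilon-greedy exploration). The sketch below is illustrative only; this page does not say how RLPack agents select actions, and `greedy_action` is a hypothetical helper.

```python
from typing import List


def greedy_action(logits: List[float]) -> int:
    # Index of the largest logit; on ties, max() keeps the first (lowest) index.
    if not logits:
        raise ValueError("logits must be non-empty")
    return max(range(len(logits)), key=lambda i: logits[i])


print(greedy_action([0.2, 1.7, -0.3]))  # 1
```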

◆ flatten

rlpack.models.mlp.Mlp.flatten

The object that flattens the output of the feature extractor.

◆ mlp_feature_extractor

rlpack.models.mlp.Mlp.mlp_feature_extractor