This class is a generic class to train any agent in any environment.
Public Member Functions

def __init__(self, Agent agent, Dict[str, Any] config, Optional[Callable[[np.ndarray, Tuple[int, ...]], np.ndarray]] reshape_func=None)

None evaluate_agent(self)
    Method to evaluate a trained model.

bool is_eval(self)
    Check if the environment is to be run in evaluation mode or not.

bool is_train(self)
    Check if the environment is to be run in training mode or not.

None train_agent(self, bool render=False, bool load=False, bool plot=False, int verbose=-1, bool distributed_mode=False)
    Method to train the agent in the specified environment.
Data Fields

agent
    The input RLPack agent to be run.

config
    The input config for setup.

env
    The gym environment on which the agent will run.

new_shape
    The new shape requested in config to be used with reshape_func.

reshape_func
    The input reshape function for states.
Private Member Functions

None _generate_plot(self, Dict[int, List[float]] rewards_collector)
    Generates a plot with matplotlib for episodes vs. rewards.

None _log(self, int ep, float mean_reward, bool distributed_mode, int verbose)
    Helper method to perform logging operations (both on console and cache).

None _remove_log_file(self)
    Removes the log.txt file if it is present in the set save_path.

None _write_log_file(self, List[str] log)
    Writes the logging messages to the set save_path as log.txt.
Static Private Member Functions

Union[None, float] _list_mean(List[Union[float, int]] x)
    This function computes the mean of the input list.

np.ndarray _reshape_func_default(np.ndarray x, Optional[Tuple[int, ...]] shape=None)
    This is the default reshape function.
This class is a generic class to train any agent in any environment.
def rlpack.environments.environments.Environments.__init__(self, Agent agent, Dict[str, Any] config, Optional[Callable[[np.ndarray, Tuple[int, ...]], np.ndarray]] reshape_func=None)
Parameters
    agent           Agent: The agent to be trained and/or evaluated in the environment specified in config.
    config          Dict[str, Any]: The configuration settings for the experiment.
    reshape_func    Optional[Callable[[np.ndarray, Tuple[int, ...]], np.ndarray]]: The function to reshape the input states. Default: None, i.e. no reshaping is performed.
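Per the signature above, a custom reshape_func must accept a state array and a target shape and return an np.ndarray. A minimal, RLPack-independent sketch of such a callable (the name flatten_states and the flattening behavior are illustrative assumptions, not part of RLPack):

```python
from typing import Optional, Tuple

import numpy as np


def flatten_states(x: np.ndarray, shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
    """Illustrative reshape_func: flattens the state to 1-D, ignoring `shape`."""
    return np.asarray(x).reshape(-1)
```

Any callable with this signature can be passed as the reshape_func argument to __init__.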
None rlpack.environments.environments.Environments._generate_plot(self, Dict[int, List[float]] rewards_collector)

private

Generates a plot with matplotlib for episodes vs. rewards.

Parameters
    rewards_collector    Dict[int, List[float]]: Dict of lists of rewards collected in each episode. Each episode is present as a key.
Union[None, float] rlpack.environments.environments.Environments._list_mean(List[Union[float, int]] x)

staticprivate

This function computes the mean of the input list.

Parameters
    x    List[Union[float, int]]: The list for which the mean is to be computed.
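The Union[None, float] return type suggests that an empty list yields None; the following sketch is consistent with that reading (an assumption about the behavior, not RLPack's actual implementation):

```python
from typing import List, Union


def list_mean(x: List[Union[float, int]]) -> Union[None, float]:
    # Assumed behavior: None for an empty list, otherwise the arithmetic mean.
    if not x:
        return None
    return sum(x) / len(x)
```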
None rlpack.environments.environments.Environments._log(self, int ep, float mean_reward, bool distributed_mode, int verbose)

private

Helper method to perform logging operations (both on console and cache).

Parameters
    ep                  int: The episode which is currently being logged.
    mean_reward         float: The mean reward acquired between two successive calls of this method.
    distributed_mode    bool: Indicates if the environment is being run in distributed mode.
    verbose             int: Indicates the verbose level; refer to the notes for more details. This also applies to values logged on screen. To disable logging on screen, set the logging level to WARNING. Default: -1.
None rlpack.environments.environments.Environments._remove_log_file(self)

private

Removes the log.txt file if it is present in the set save_path.
np.ndarray rlpack.environments.environments.Environments._reshape_func_default(np.ndarray x, Optional[Tuple[int, ...]] shape=None)

staticprivate

This is the default reshape function. If new_shape has been set in config, input states are reshaped to the new shape; otherwise the input is returned as-is. The default behavior is to not perform any reshaping.

Parameters
    x        np.ndarray: The input numpy array to reshape.
    shape    Optional[Tuple[int, ...]]: The new shape to which we want states to be reshaped. Default: None.
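The described default behavior can be sketched as follows (reshape when a shape is provided, pass through otherwise); this is a reading of the documentation, not the actual RLPack source:

```python
from typing import Optional, Tuple

import numpy as np


def reshape_func_default(x: np.ndarray, shape: Optional[Tuple[int, ...]] = None) -> np.ndarray:
    # Reshape to `shape` (e.g., `new_shape` from config) when given; otherwise pass through.
    if shape is not None:
        return np.reshape(x, shape)
    return x
```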
None rlpack.environments.environments.Environments._write_log_file(self, List[str] log)

private

Writes the logging messages to the set save_path as log.txt. This method opens the file in append mode.

Parameters
    log    List[str]: The logging messages to write.
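The append-mode behavior can be sketched as below; the helper name and the per-line formatting are assumptions, while the append mode and the log.txt filename come from the documentation:

```python
import os
from typing import List


def write_log_file(log: List[str], save_path: str) -> None:
    # Open <save_path>/log.txt in append mode, as documented, and write each message.
    with open(os.path.join(save_path, "log.txt"), "a") as f:
        for line in log:
            f.write(line if line.endswith("\n") else line + "\n")
```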
None rlpack.environments.environments.Environments.evaluate_agent(self)

Method to evaluate a trained model. This method renders the environment and loads the model from save_path. config must have mode='eval' set to run evaluation.
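A minimal config fragment for evaluation might look like the following; only the mode and save_path keys appear in this documentation, and the concrete values are placeholders:

```python
# Hypothetical config fragment for running evaluate_agent.
config = {
    "mode": "eval",              # required to run evaluation
    "save_path": "models/run1",  # placeholder: directory the trained model was saved to
}
```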
bool rlpack.environments.environments.Environments.is_eval(self)

Check if the environment is to be run in evaluation mode or not.
bool rlpack.environments.environments.Environments.is_train(self)

Check if the environment is to be run in training mode or not.
None rlpack.environments.environments.Environments.train_agent(self, bool render=False, bool load=False, bool plot=False, int verbose=-1, bool distributed_mode=False)
Method to train the agent in the specified environment.
Parameters
    render              bool: Indicates if we wish to render the environment (as an animation). Default: False.
    load                bool: Indicates whether to load a previously saved model or train a new one. If set True, save_path must be set in config or the environment variable SAVE_PATH must be set. Default: False.
    plot                bool: Indicates whether to plot the training progress. If set True, rewards and episodes are recorded and the plot is saved in save_path. Default: False.
    verbose             int: Indicates the verbose level; refer to the notes for more details. This also applies to values logged on screen. To disable logging on screen, set the logging level to WARNING. Default: -1.
    distributed_mode    bool: Indicates if the environment is being run in distributed mode. Rewards are logged on console every reward_logging_frequency set in the config. Default: False.
Notes
Verbose levels:
rlpack.environments.environments.Environments.agent

The input RLPack agent to be run.

rlpack.environments.environments.Environments.config

The input config for setup.

rlpack.environments.environments.Environments.env

The gym environment on which the agent will run.

rlpack.environments.environments.Environments.new_shape

The new shape requested in config to be used with reshape_func.

rlpack.environments.environments.Environments.reshape_func

The input reshape function for states.