zero.training
Easier training process.
Eval
class zero.training.Eval(*models)
Context manager for model evaluation.
Switches one or more models to evaluation mode and turns off gradients when entering the context (not when constructed!), and restores the previous state (the models' modes and the gradient mode) when exiting the context.
Before:
model.eval()
with torch.no_grad():
    ...
After:
with Eval(model):
    ...
Parameters
*models (torch.nn.Module) – the models to switch to evaluation mode.
Examples
import torch
from zero.training import Eval

a = torch.nn.Linear(1, 1)
b = torch.nn.Linear(2, 2)
with Eval(a):
    ...
with Eval(a, b):
    ...
Tutorial
import torch
from zero.training import Eval

model = torch.nn.Linear(1, 1)
grad_before_context = torch.is_grad_enabled()
for training_before_context in False, True:
    model.train(training_before_context)
    with Eval(model):
        assert not model.training
        assert not torch.is_grad_enabled()
    assert model.training == training_before_context
    assert torch.is_grad_enabled() == grad_before_context
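To make the enter/exit behavior concrete, here is a minimal sketch of a context manager with the same semantics, written with contextlib.contextmanager. The name eval_mode is hypothetical and this is not zero's implementation; it relies only on PyTorch's Module.train, Module.eval and torch.no_grad. Because a generator-based context manager runs its body only on entry, the models' previous modes are recorded when the with block starts, not when eval_mode(...) is called.

```python
import contextlib
from typing import Iterator

import torch


@contextlib.contextmanager
def eval_mode(*models: torch.nn.Module) -> Iterator[None]:
    """Hypothetical stand-in for Eval: evaluation mode + no gradients, undone on exit."""
    # This line runs on __enter__, so the modes are captured at entry time.
    previous_modes = [model.training for model in models]
    try:
        # torch.no_grad restores the previous gradient mode when it exits.
        with torch.no_grad():
            for model in models:
                model.eval()
            yield
    finally:
        # Restore each model's mode, even if the block raised an exception.
        for model, mode in zip(models, previous_modes):
            model.train(mode)
```

One consequence of this design: Module.train(mode) applies the mode recursively, so a model whose submodules were in mixed modes comes back with a single mode on exit.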
ProgressTracker
class zero.training.ProgressTracker(patience, min_delta=0.0)
Tracks the best score and facilitates early stopping.
For ProgressTracker, a greater score is a better score. At any moment, the tracker is in one of the following states:
- success: the last score updated the best score;
- fail: the last n > patience updates were not better than the best score;
- neutral: neither success nor fail.
Parameters
patience – the allowed number of bad updates. For example, if patience is 2, then 2 bad updates are not a fail, but 3 bad updates are.
min_delta – the minimal improvement over the current best score required to count an update as a success.
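The state rules above can be sketched in plain Python. SimpleProgressTracker is a hypothetical name and this is not zero's source; in particular, the rule that an update is a success only when it exceeds the best score by more than min_delta is an assumption consistent with the parameter description. Under that assumption, the sketch reproduces the Tutorial sequence for ProgressTracker.

```python
from typing import Optional


class SimpleProgressTracker:
    """Hypothetical sketch of the success / fail / neutral rules (greater is better)."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.reset()

    def reset(self) -> None:
        # Forget everything, including the best score.
        self.best_score: Optional[float] = None
        self._bad_updates = 0
        self._status = "neutral"

    def forget_bad_updates(self) -> None:
        # Clear the bad-update counter and the status, but keep the best score.
        self._bad_updates = 0
        self._status = "neutral"

    def update(self, score: float) -> None:
        # Assumption: an improvement must exceed best_score by more than min_delta.
        if self.best_score is None or score > self.best_score + self.min_delta:
            self.best_score = score
            self._bad_updates = 0
            self._status = "success"
        else:
            self._bad_updates += 1
            # More than `patience` consecutive bad updates means "fail".
            self._status = "fail" if self._bad_updates > self.patience else "neutral"

    @property
    def success(self) -> bool:
        return self._status == "success"

    @property
    def fail(self) -> bool:
        return self._status == "fail"
```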
Examples
from zero.training import ProgressTracker

progress = ProgressTracker(2)
progress = ProgressTracker(3, 0.1)
Tutorial
from zero.training import ProgressTracker

progress = ProgressTracker(2)
progress.update(-999999999)
assert progress.success  # the first update always updates the best score
progress.update(123)
assert progress.success
assert progress.best_score == 123
progress.update(0)
assert not progress.success and not progress.fail
progress.update(123)
assert not progress.success and not progress.fail
progress.update(123)  # patience is 2 and the best score is not updated for more than 2 steps
assert progress.fail
assert progress.best_score == 123  # fail doesn't affect the best score
progress.update(123)
assert progress.fail  # still no improvements
progress.forget_bad_updates()
assert not progress.fail and not progress.success
assert progress.best_score == 123
progress.update(0)
assert not progress.fail  # just 1 bad update (the patience is 2)
progress.reset()
assert not progress.fail and not progress.success
assert progress.best_score is None
best_score – The best score so far.
success – Check if the tracker is in the ‘success’ state.
fail – Check if the tracker is in the ‘fail’ state.
update – Update the tracker’s state.
forget_bad_updates – Reset the bad-update counter and the status, but not the best score.
reset – Reset everything.
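In a training loop, early stopping usually reduces to "stop as soon as the tracker is in the fail state". The hypothetical helper early_stopping_epoch replays that rule over a precomputed list of validation scores and returns the epoch at which training would stop (or None); in a real loop each score would come from evaluating the model after an epoch. As above, treating min_delta as a strict margin over the best score is an assumption.

```python
from typing import Optional, Sequence


def early_stopping_epoch(
    scores: Sequence[float], patience: int, min_delta: float = 0.0
) -> Optional[int]:
    """Hypothetical helper: index of the epoch that first reaches 'fail', else None."""
    best: Optional[float] = None
    bad_updates = 0
    for epoch, score in enumerate(scores):
        if best is None or score > best + min_delta:
            # success: a new best score resets the bad-update counter
            best = score
            bad_updates = 0
        else:
            bad_updates += 1
            if bad_updates > patience:
                # fail: more than `patience` bad updates in a row, stop here
                return epoch
    return None
```

For example, with scores [0.5, 0.6, 0.59, 0.6, 0.58] and patience=2, epochs 2, 3 and 4 are bad updates (0.6 does not beat 0.6), so training stops at epoch 4.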
Functions
The “default” training step.
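The name and signature of this function are missing from the extracted page. Purely as an illustration, a conventional PyTorch training step (forward pass, loss, backward pass, optimizer update) can be sketched as follows; training_step and its parameters are hypothetical and are not zero's API.

```python
from typing import Callable

import torch


def training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    x: torch.Tensor,
    y: torch.Tensor,
) -> float:
    """Hypothetical 'default' training step; returns the loss as a Python float."""
    model.train()  # ensure layers such as dropout run in training mode
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = loss_fn(model(x), y)
    loss.backward()  # compute gradients of the loss w.r.t. the parameters
    optimizer.step()  # update the parameters using those gradients
    return loss.item()
```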