zero.training.learn

zero.training.learn(model, optimizer, loss_fn, step, batch, star=False)

The “default” training step.
The function does the following:

- Switches the model to training mode and sets its gradients to zero.
- Performs the call step(batch) or step(*batch).
- Passes the output of the previous step to loss_fn.
- Applies torch.Tensor.backward to the obtained loss tensor.
- Performs the optimization step.
- Returns the loss’s value (float) and step’s output.
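For orientation, here is a minimal sketch of how these steps could be implemented by hand. It is only an illustration of the description above, not the library’s actual source (in particular, the step(*batch) variant and edge cases are omitted):

    def learn_sketch(model, optimizer, loss_fn, step, batch, star=False):
        # Switch to training mode and reset gradients.
        model.train()
        optimizer.zero_grad()
        # Forward pass through the user-provided step.
        step_output = step(batch)
        # Compute the loss, unpacking the output if star=True.
        loss = loss_fn(*step_output) if star else loss_fn(step_output)
        # Backward pass and optimization step.
        loss.backward()
        optimizer.step()
        # Return the loss as a Python float together with step's raw output.
        return loss.item(), step_output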
- Parameters
    - model (torch.nn.modules.module.Module) – the model to train
    - optimizer (torch.optim.optimizer.Optimizer) – the optimizer for model
    - loss_fn (Callable[..., torch.Tensor]) – the function that takes step’s output as input and returns a loss tensor
    - step (Callable[T, Any]) – the function that takes batch as input and produces input for loss_fn
    - batch (T) – input for step
    - star (bool) – if True, then the output of step is unpacked when passed to loss_fn, i.e. loss_fn(*step_output) is performed instead of loss_fn(step_output)
- Returns
    (loss_value, step_output)
- Return type
    Tuple[float, Any]
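For instance, the returned pair could be consumed like this (a small usage sketch reusing the names from the examples below):

    loss_value, step_output = learn(model, optimizer, loss_fn, step, batch)
    print(f'batch loss: {loss_value:.4f}')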
Note

After the function returns:

- model’s gradients (caused by backward) are preserved
- model’s state (training or not) is undefined
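In practice this means the caller can still inspect the gradients, but should set the model’s mode explicitly before the next forward pass. A hedged usage sketch (the layer and tensor names are hypothetical):

    # gradients from the backward pass are still stored on the parameters
    print(model.linear.weight.grad.norm())  # 'linear' is a hypothetical layer name

    # the training/eval mode is undefined after learn(...), so set it explicitly
    model.eval()
    with torch.no_grad():
        predictions = model(X_val)  # X_val is a hypothetical validation batch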
Examples
    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    def step(batch):
        X, y = batch
        return model(X), y

    for epoch in epoches:
        for batch in batches:
            learn(model, optimizer, loss_fn, step, batch, True)

    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    def step(batch):
        X, y = batch
        return {'y_pred': model(X), 'y': y}

    loss_fn = lambda out: torch.nn.functional.mse_loss(out['y_pred'], out['y'])

    for epoch in epoches:
        for batch in batches:
            learn(model, optimizer, loss_fn, step, batch)