zero.training.learn

zero.training.learn(model, optimizer, loss_fn, step, batch, star=False)

The “default” training step.

The function does the following (a sketch follows the list):

  1. Switches the model to training mode and sets its gradients to zero.

  2. Performs the call step(batch).

  3. Passes the output of step to loss_fn: loss_fn(step_output), or loss_fn(*step_output) if star is True.

  4. Applies torch.Tensor.backward to the obtained loss tensor.

  5. Performs the optimization step.

  6. Returns the loss’s value (float) and step’s output.
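
In other words, the behavior is roughly equivalent to the following sketch (illustrative only, not the library’s actual source; learn_sketch is a made-up name):

def learn_sketch(model, optimizer, loss_fn, step, batch, star=False):
    model.train()              # 1. switch to training mode
    optimizer.zero_grad()      # 1. zero the gradients
    step_output = step(batch)  # 2. run the user-provided step
    # 3. pass step’s output to loss_fn, unpacking it if star is True
    loss = loss_fn(*step_output) if star else loss_fn(step_output)
    loss.backward()            # 4. backpropagate
    optimizer.step()           # 5. optimization step
    return loss.item(), step_output  # 6. (loss_value, step_output)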

Parameters
  • model (torch.nn.modules.module.Module) – the model to train

  • optimizer (torch.optim.optimizer.Optimizer) – the optimizer for model

  • loss_fn (Callable[..., torch.Tensor]) – the function that takes step’s output as input and returns a loss tensor

  • step (Callable[[T], Any]) – the function that takes batch as input and produces input for loss_fn

  • batch (T) – input for step

  • star (bool) – if True, then the output of step is unpacked when passed to loss_fn, i.e. loss_fn(*step_output) is performed instead of loss_fn(step_output)
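
To illustrate star (hypothetical tensors, using mse_loss as an example loss_fn; see also the full examples below):

import torch
import torch.nn.functional as F

y_pred, y = torch.randn(8, 1), torch.randn(8, 1)  # hypothetical tensors
step_output = (y_pred, y)     # suppose step returns this tuple
loss_fn = F.mse_loss          # a loss_fn taking two arguments
loss = loss_fn(*step_output)  # what star=True does: loss_fn(*step_output)
# with star=False, loss_fn(step_output) would receive the whole tuple instead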

Returns

(loss_value, step_output)

Return type

Tuple[float, Any]

Note

After the function returns:

  • model’s gradients (produced by backward) are preserved

  • model’s mode (training or evaluation) is undefined
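
A sketch of what this implies for the calling code (illustrative; the gradient-norm computation is just one way to inspect the preserved gradients):

loss_value, step_output = learn(model, optimizer, loss_fn, step, batch)

# the gradients produced by backward are preserved, so they can still be inspected
grad_norm = sum(
    p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None
) ** 0.5

# the mode is undefined after learn, so set it explicitly before evaluation
model.eval()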

Examples

import torch
from zero.training import learn

model = ...
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

def step(batch):
    X, y = batch
    return model(X), y

for epoch in epochs:
    for batch in batches:
        # step returns a tuple (y_pred, y), so star=True unpacks it into loss_fn
        learn(model, optimizer, loss_fn, step, batch, star=True)

The same training loop, but step returns a dict, so star is left as the default False:

model = ...
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def step(batch):
    X, y = batch
    return {'y_pred': model(X), 'y': y}

loss_fn = lambda out: torch.nn.functional.mse_loss(out['y_pred'], out['y'])

for epoch in epochs:
    for batch in batches:
        # step returns a dict; loss_fn receives it as a single argument (star=False)
        learn(model, optimizer, loss_fn, step, batch)
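
The returned values can be used for logging; a minimal sketch (epochs and batches as in the examples above):

for epoch in epochs:
    epoch_losses = []
    for batch in batches:
        loss_value, step_output = learn(model, optimizer, loss_fn, step, batch)
        epoch_losses.append(loss_value)
    print(f'epoch mean loss: {sum(epoch_losses) / len(epoch_losses):.4f}')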