adaptdl.torch.epoch module

This module provides tools for the top-level loop over epochs during training. AdaptDL expects the training program to be implemented as a loop over several epochs, each containing a series of loops over datasets (e.g. one loop over the training set followed by one loop over the validation set). The program can be interrupted between iterations of any dataset loop, at which point a checkpoint is taken, and then restarted using a different set of replicas.

Due to checkpoint-restarts, parts of the training program may be executed multiple times (e.g. once after each restart)! To avoid incorrect execution, ensure that your code is idempotent in the following locations:

  1. Immediately before any epoch loop (using remaining_epochs_until()).

  2. Immediately before any dataset loop (using adaptdl.torch.data.AdaptiveDataLoader).

Your code may be non-idempotent in other locations.

### IDEMPOTENT CODE ONLY ###

for epoch in remaining_epochs_until(30):

    ### IDEMPOTENT CODE ONLY ###

    for batch in train_loader:
        # ... any code ...

    ### IDEMPOTENT CODE ONLY ###

    for batch in valid_loader:
        # ... any code ...

    # ... any code ...

# ... any code ...

### END PROGRAM ###

For example, a common non-idempotent operation is learning-rate annealing:

for epoch in remaining_epochs_until(30):

    lr_scheduler.step()  # (A) WRONG!

    for batch in train_loader:
        # ...

    lr_scheduler.step()  # (B) WRONG!

    for batch in valid_loader:
        # ...

    lr_scheduler.step()  # (C) OK!

Location (A) will be executed again after any checkpoint-restart during either the training or validation loop, resulting in the learning rate being annealed several times in one epoch! The same applies to location (B) if a checkpoint-restart happens during the validation loop.

Location (C) results in the correct behavior, because (1) an epoch will not be repeated once it has finished, and (2) no checkpoint-restarts can occur between the learning rate annealing and the end of the epoch.
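Another option is to keep the annealing inside the idempotent region by making the learning rate a pure function of the epoch number, so re-executing the assignment after a restart produces the same value. The sketch below is only illustrative; base_lr, the decay schedule, optimizer, and the data loaders are placeholders, not part of the AdaptDL API:

from adaptdl.torch.epoch import current_epoch, remaining_epochs_until

base_lr = 0.1  # placeholder base learning rate

for epoch in remaining_epochs_until(30):
    # Idempotent: the learning rate depends only on the epoch number, so
    # repeating this assignment after a checkpoint-restart has no extra effect.
    lr = base_lr * (0.1 ** (current_epoch() // 10))
    for group in optimizer.param_groups:  # placeholder optimizer
        group["lr"] = lr

    for batch in train_loader:
        # ... training step ...
        pass

    for batch in valid_loader:
        # ... validation step ...
        pass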

adaptdl.torch.epoch.current_epoch()[source]

Get the current epoch while iterating with remaining_epochs_until().

Returns

The current epoch number if called from within a remaining_epochs_until() iteration, None otherwise.

Return type

int or None
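A small usage sketch (the epoch bound and prints are illustrative):

from adaptdl.torch.epoch import current_epoch, remaining_epochs_until

print(current_epoch())  # None: not inside an epoch loop yet

for epoch in remaining_epochs_until(3):
    print(current_epoch())  # 0, 1, 2 on a fresh run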

adaptdl.torch.epoch.finished_epochs()[source]

Get the number of epochs finished using remaining_epochs_until().

Returns

The number of finished epochs. Equal to current_epoch() if called from within a remaining_epochs_until() iteration.

Return type

int
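For example, finished_epochs() can be queried before the loop to distinguish a fresh start from a restart (a hedged sketch; the messages are illustrative):

from adaptdl.torch.epoch import finished_epochs, remaining_epochs_until

if finished_epochs() == 0:
    print("starting training from scratch")
else:
    print("resuming after", finished_epochs(), "finished epochs")

for epoch in remaining_epochs_until(30):
    # ... training and validation loops ...
    pass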

adaptdl.torch.epoch.remaining_epochs_until(epoch)[source]

Iterate over epochs in a way that is consistent with checkpoint-restarts. For example:

for epoch in remaining_epochs_until(30):
    print(current_epoch())  # Should print 0 through 29

for epoch in remaining_epochs_until(60):
    print(current_epoch())  # Should print 30 through 59

If a checkpoint-restart happens during an epoch, all previous epochs will be skipped after the program restarts.

Parameters

epoch (int) – The epoch number to end at (exclusive).

Raises

RuntimeError – If invoked before a previous epoch loop has ended.
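One way this can happen is starting a nested epoch loop while an outer one is still running (illustrative sketch of the error case):

from adaptdl.torch.epoch import remaining_epochs_until

for epoch in remaining_epochs_until(30):
    for inner in remaining_epochs_until(60):  # expected to raise RuntimeError
        pass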