adaptdl.torch.epoch module
This module provides tools for the top-level loop over epochs during training. AdaptDL expects the training program to be implemented as a loop over several epochs, each containing a series of loops over datasets (e.g. one loop over the training set followed by one loop over the validation set). The program can be interrupted between any two iterations of a dataset loop, at which point a checkpoint is taken, and then restarted using a different set of replicas.
Due to checkpoint-restarts, parts of the training program may be executed multiple times (e.g. once after each restart)! To avoid incorrect execution, ensure that your code is idempotent in the following locations:
- Immediately before any epoch loop (using remaining_epochs_until()).
- Immediately before any dataset loop (using adaptdl.torch.data.AdaptiveDataLoader).
Your code may be non-idempotent in other locations.
```python
### IDEMPOTENT CODE ONLY ###

for epoch in remaining_epochs_until(30):

    ### IDEMPOTENT CODE ONLY ###

    for batch in train_loader:
        # ... any code ...

    ### IDEMPOTENT CODE ONLY ###

    for batch in valid_loader:
        # ... any code ...

    # ... any code ...

# ... any code ...

### END PROGRAM ###
```
For example, a common non-idempotent operation is learning-rate annealing:
```python
for epoch in remaining_epochs_until(30):
    lr_scheduler.step()  # (A) WRONG!
    for batch in train_loader:
        # ...
    lr_scheduler.step()  # (B) WRONG!
    for batch in valid_loader:
        # ...
    lr_scheduler.step()  # (C) OK!
```
Location (A) will be executed again after any checkpoint-restart during either the training or validation loop, causing the learning rate to be annealed several times in one epoch! Similarly, location (B) will be executed again if a checkpoint-restart happens during the validation loop.
Location (C) results in the correct behavior, because (1) an epoch will not be repeated once it has finished, and (2) no checkpoint-restarts can occur between the learning rate annealing and the end of the epoch.
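The effect of placement (A) versus (C) can be demonstrated with a toy, self-contained sketch. It does not use AdaptDL at all: an InterruptedError simulates a checkpoint-restart during the validation loop, and for simplicity the epoch replays from its top on restart. The counters show how many times each placement of the scheduler step runs for a single finished epoch.

```python
lr_steps_a = 0  # scheduler stepped at the top of the epoch (location A)
lr_steps_c = 0  # scheduler stepped at the end of the epoch (location C)

restarted = False
while True:
    try:
        lr_steps_a += 1           # (A): re-executed after the restart
        for batch in range(4):    # training loop (stand-in)
            pass
        if not restarted:
            restarted = True      # simulate one checkpoint-restart here,
            raise InterruptedError("restart during validation loop")
        for batch in range(2):    # validation loop (stand-in)
            pass
        lr_steps_c += 1           # (C): runs exactly once per finished epoch
        break
    except InterruptedError:
        continue                  # program restarts from the top of the epoch

print(lr_steps_a, lr_steps_c)     # prints: 2 1
```

With placement (A) the learning rate would have been annealed twice for one epoch, while placement (C) anneals it exactly once.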
- adaptdl.torch.epoch.current_epoch()[source]
  Get the current epoch while iterating with remaining_epochs_until().
  - Returns
    The current epoch number if called from within a remaining_epochs_until() iteration, None otherwise.
  - Return type
    int or None
- adaptdl.torch.epoch.finished_epochs()[source]
  Get the number of epochs finished using remaining_epochs_until().
  - Returns
    The number of finished epochs. Equal to current_epoch() if called from within a remaining_epochs_until() iteration.
  - Return type
    int
- adaptdl.torch.epoch.remaining_epochs_until(epoch)[source]
  Iterate over epochs in a way that is consistent with checkpoint-restarts. For example:

  ```python
  for epoch in remaining_epochs_until(30):
      print(current_epoch())  # Should print 0 through 29

  for epoch in remaining_epochs_until(60):
      print(current_epoch())  # Should print 30 through 59
  ```

  If a checkpoint-restart happens during an epoch, all previous epochs will be skipped after the program restarts.
  - Parameters
    epoch (int) – The epoch number to end at (exclusive).
  - Raises
    RuntimeError – If invoked before a previous epoch loop has ended.
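The documented semantics of these three functions can be modeled in a few lines of plain Python. The sketch below is an illustrative stand-in, not the real adaptdl implementation: the real library persists the epoch state in checkpoints across restarts, whereas here a module-level dict plays that role.

```python
# Illustrative model of current_epoch() / finished_epochs() /
# remaining_epochs_until(). In real AdaptDL, `finished` survives
# checkpoint-restarts; here it only lives for the process lifetime.
_state = {"epoch": None, "finished": 0}

def current_epoch():
    """Epoch number inside a remaining_epochs_until() loop, else None."""
    return _state["epoch"]

def finished_epochs():
    """Number of fully completed epochs so far."""
    return _state["finished"]

def remaining_epochs_until(epoch):
    """Yield only the epochs that have not finished yet, up to `epoch`."""
    if _state["epoch"] is not None:
        raise RuntimeError("previous epoch loop has not ended")
    for e in range(_state["finished"], epoch):
        _state["epoch"] = e
        yield e
        _state["finished"] = e + 1  # epoch e is done, never repeat it
    _state["epoch"] = None
```

Under this model, a second loop such as `remaining_epochs_until(60)` naturally resumes at epoch 30 because the first 30 epochs are already counted as finished, which matches the skipping behavior described above.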