adaptdl.torch package¶
- class adaptdl.torch.Accumulator(*args, **kwargs)[source]¶
Bases:
collections.abc.MutableMapping
This class helps aggregate simple statistics across all replicas in the current job, and across any number of checkpoint-restarts. Can be used to compute metrics like loss and accuracy, synchronized across each replica.
Accumulators imitate Python dictionaries, with a few key differences described below. Primarily, their usage and behavior depend on whether the accumulator is in accumulation mode or synchronized mode.
Accumulation mode: the accumulator is being updated on all replicas. Operations like accum["key"] += val or accum.update(key=val) will aggregate the updates locally on each replica, which are lazily synchronized in the background (either upon a checkpoint or a switch to synchronized mode). Each replica may make different updates, which are summed together when synchronized. While accumulation mode is enabled, all read operations on the accumulator behave as if they were performed on an empty dict, i.e. len(accum) will always return 0. By default, all accumulators start in accumulation mode.
Synchronized mode: the accumulator contains the same data on every replica, and the application must ensure that all write operations are exactly the same across all replicas. While in synchronized mode, the accumulator may be used as if it were a native Python dict, and all read/write operations are supported. Accumulator.synchronized() may be used to enter synchronized mode. Upon entering synchronized mode, the accumulator will automatically sum all updates from all replicas to ensure the same data is available to each replica.
Using accumulators, many training/validation metrics can be computed easily and correctly in an elastic distributed setting. For example, a simple validation step which calculates a loss and accuracy can be implemented as follows:
accum = Accumulator()  # New accumulator starts in accumulation mode.
for epoch in remaining_epochs_until(60):
    for batch in validloader:
        ...
        accum["loss_sum"] += <loss summed within the batch>
        accum["correct"] += <number of correct predictions>
        accum["total"] += <total number of samples in the batch>
    with accum.synchronized():  # Enter synchronized mode.
        accum["loss_avg"] = accum["loss_sum"] / accum["total"]
        accum["accuracy"] = accum["correct"] / accum["total"]
        print("Loss: {}, Accuracy: {}".format(
            accum["loss_avg"], accum["accuracy"]))
        accum.clear()  # Back to accumulation mode.
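The two modes can be understood with a simplified single-process model. The sketch below simulates two replicas accumulating partial sums locally, which are then summed on synchronization; the `synchronize` helper is hypothetical and stands in for the collective communication the real class performs:

```python
# Simplified model of Accumulator semantics: each replica keeps local
# partial updates; entering synchronized mode sums them across replicas.
# (Illustrative only -- the real class synchronizes via collectives.)

def synchronize(replica_updates):
    """Sum per-replica update dicts, as if entering synchronized mode."""
    total = {}
    for updates in replica_updates:
        for key, val in updates.items():
            total[key] = total.get(key, 0) + val
    return total

# Two replicas validate on different shards and accumulate locally.
replica0 = {"correct": 45, "total": 50}
replica1 = {"correct": 40, "total": 50}

merged = synchronize([replica0, replica1])
accuracy = merged["correct"] / merged["total"]  # 0.85
```

This is why read access is only meaningful in synchronized mode: before synchronization, each replica holds only its own partial sums.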
- Parameters
args – Positional arguments, same as dict.
kwargs – Keyword arguments, same as dict.
- __iadd__(other)[source]¶
Supports the += operation, e.g. accum += {key1: val1, key2: val2}. Behaves the same way as accum.update({key1: val1, key2: val2}).
- Parameters
other – Mapping object or an iterable of key-update pairs.
- __isub__(other)[source]¶
Supports the -= operation, e.g. accum -= {key1: val1, key2: val2}. Behaves the same way as accum.subtract({key1: val1, key2: val2}).
- Parameters
other – Mapping object or an iterable of key-update pairs.
- __getitem__(key)[source]¶
Supports indexing, e.g. val = accum[key] and accum[key] += 1. The former (read access) should only be used when the accumulator is in synchronized mode.
- Parameters
key – Key used to access a value in the accumulator.
- subtract(*args, **kwargs)[source]¶
Apply a collection of key-update pairs. Unlike Accumulator.update(), this method subtracts the updates from the accumulated values.
- Parameters
args – Positional arguments, same as Accumulator.update().
kwargs – Keyword arguments, same as Accumulator.update().
- synchronized()[source]¶
A context manager which can be used to define the code to execute in synchronized mode. Within the context manager, any code can interact with this accumulator as if it were a regular Python dict. The application must ensure that whatever operations are performed within this context block are the same across all replicas.
Warning
Entering this context manager is a distributed synchronization point! Please ensure that all replicas enter this context manager at the same point in their code.
- update(*args, **kwargs)[source]¶
Apply a collection of key-update pairs. Unlike dict.update, this method additively applies the updates to the accumulated values.
- Parameters
args – Positional arguments, same as dict.update. Can be a mapping object or an iterable of key-update pairs.
kwargs – Keyword arguments, same as dict.update. Each keyword is the string key corresponding to the provided update.
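The additive behavior, as opposed to the replacement behavior of dict.update, can be sketched in plain Python. The `additive_update` helper below is hypothetical, not the actual implementation:

```python
# Additive update semantics: values are summed into existing entries,
# unlike dict.update, which replaces them. (Sketch, not the real code.)

def additive_update(store, *args, **kwargs):
    # dict(*args, **kwargs) accepts the same arguments as dict.update:
    # a mapping or iterable of pairs, plus keyword arguments.
    for key, val in dict(*args, **kwargs).items():
        store[key] = store.get(key, 0) + val

store = {"loss_sum": 1.5}
additive_update(store, {"loss_sum": 0.5}, total=32)
# store is now {"loss_sum": 2.0, "total": 32}
```

A plain dict.update with the same arguments would instead overwrite "loss_sum" with 0.5.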
- class adaptdl.torch.AdaptiveDataLoader(dataset, batch_size=1, shuffle=False, **kwargs)[source]¶
Bases:
torch.utils.data.dataloader.DataLoader, adaptdl.torch.data.AdaptiveDataLoaderMixin
This class is a PyTorch DataLoader that also supports adaptive batch sizes and checkpoint-restart elasticity. Applications can typically use objects of this class as direct replacements for PyTorch DataLoaders. However, some notable differences are:
- The batch_size argument defines the target total batch size across all replicas, rather than the local batch size on each replica.
- Custom sampler and batch_sampler are not supported.
- Iterating through the dataloader is only allowed from within an epoch loop (see adaptdl.torch.epoch), and only one dataloader loop is allowed at any given time.
- Parameters
dataset (torch.utils.data.Dataset) – Dataset from which to load the data.
batch_size (int) – The target total batch size across all replicas. The actual total batch size may be different due to rounding (each replica must have the same local batch size), or being scaled up using adaptive batch sizes.
shuffle (bool) – Whether the data is reshuffled at every epoch.
**kwargs – Keyword arguments passed to torch.utils.data.DataLoader.
- Raises
ValueError – If sampler or batch_sampler are not None.
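The relationship between the target total batch size and each replica's local batch size can be sketched as follows. The rounding scheme shown is an assumption for illustration, not necessarily AdaptDL's exact logic:

```python
# Each replica must use the same local batch size, so the target total
# batch size is divided by the replica count and rounded; the actual
# total becomes local_bs * num_replicas, which may differ slightly
# from the requested target. (Hypothetical rounding scheme.)

def local_batch_size(target_total, num_replicas):
    return max(1, round(target_total / num_replicas))

local = local_batch_size(128, 3)
actual_total = local * 3
# local == 43, actual_total == 129 (not exactly the requested 128)
```

This is the rounding difference the batch_size parameter documentation refers to; adaptive batch scaling can change the total further at runtime.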
- __iter__()[source]¶
Iterate over batches of data. When adaptive batch size is disabled, stops after the entire dataset has been processed once in total by all replicas. This means if there are K replicas, then this method will iterate over ~1/K of the dataset. When adaptive batch size is enabled, stops after making enough statistical progress roughly equivalent to one pass over the dataset with non-adaptive batch size. In this case, the dataset may be processed more than once.
A checkpoint-restart may be triggered in-between each batch. In this case, the current iteration state will be saved and restored after the restart, and continue where it left off.
- class adaptdl.torch.AdaptiveDataParallel(model, optimizer, lr_scheduler=None, mp_scaler=None, scaling_rule: Optional[adaptdl.torch.scaling_rules.ScalingRuleBase] = None, name='adaptdl-dataparallel', **kwargs)[source]¶
Bases:
torch.nn.parallel.distributed.DistributedDataParallel
This class extends PyTorch DistributedDataParallel with support for adaptive batch sizes and checkpoint-restart elasticity. It automatically saves the given model, optimizer, and (optionally) LR scheduler whenever a checkpoint is triggered, and restores their states after restart. The optimizer is automatically patched with the chosen scaling rule.
- Parameters
model (torch.nn.Module) – Model to be distributed.
optimizer (torch.optim.Optimizer) – Optimizer used to update the model's parameters; will be patched using a subclass of adaptdl.torch.scaling_rules.ScalingRuleBase.
scaling_rule (ScalingRuleBase) – Scaling rule used to patch the given optimizer; defaults to AdaScale.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler) – LR scheduler used to anneal the learning rate for the given optimizer.
name (string) – Unique name for each instance of this class, needed only if multiple instances exist.
- property gain¶
Current estimate of the AdaScale gain (r_t) value.
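For intuition, the AdaScale gain from the AdaScale paper (Johnson et al., 2020) can be sketched as a function of the estimated gradient variance, the squared gradient norm, and the batch size scale S. This is a sketch of the published formula, not AdaptDL's internal estimator:

```python
def adascale_gain(var, sqr, scale):
    """AdaScale gain r_t = (var + sqr) / (var / scale + sqr).

    var: estimated gradient variance (sigma^2) at the base batch size.
    sqr: estimated squared norm of the true gradient (mu^2).
    scale: batch size scale S relative to the base batch size.
    """
    return (var + sqr) / (var / scale + sqr)

# Limiting behavior: with no gradient noise the gain is 1 (larger
# batches give no benefit); with pure noise the gain approaches S
# (averaging S times more samples fully reduces the variance).
no_noise = adascale_gain(0.0, 1.0, 8)   # 1.0
all_noise = adascale_gain(1.0, 0.0, 8)  # 8.0
```

The gain is used to scale the effective learning-rate schedule, so that r_t steps at scale S make roughly the same progress as one step at the base batch size.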
- to_tensorboard(writer, global_step, tag_prefix='')[source]¶
Output some useful metrics to TensorBoard.
- Parameters
writer (torch.utils.tensorboard.SummaryWriter) – SummaryWriter object to output metrics to.
global_step (int) – Global step value to record.
tag_prefix (str) – Prefix added to each metric's tag.
- training: bool¶
- class adaptdl.torch.ElasticSampler(dataset, shuffle=True)[source]¶
Bases:
torch.utils.data.sampler.Sampler
A PyTorch Sampler which partitions data samples across multiple replicas, and supports deterministic continuation across checkpoint-restarts. Shuffling is deterministic for each epoch, and ElasticSampler.set_epoch() should be invoked to obtain different orderings in different epochs.
- Parameters
dataset (torch.utils.data.Dataset) – The dataset to sample from.
shuffle (bool) – Whether the data samples should be shuffled.
- __iter__()[source]¶
Iterate through the samples in the dataset, in the order defined for a set epoch, starting at a set index. Produces only the indices for the local replica.
Returns: Iterator over data sample indices.
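The deterministic, epoch-seeded shuffle and per-replica partitioning can be sketched like this. The scheme below (epoch as RNG seed, strided partitioning) is illustrative; the real sampler also handles resuming from a saved index after a restart:

```python
import random

def replica_indices(dataset_len, epoch, rank, num_replicas, shuffle=True):
    """Return this replica's sample indices for the given epoch."""
    indices = list(range(dataset_len))
    if shuffle:
        # Seeding by epoch makes the ordering deterministic per epoch
        # and identical on every replica before partitioning, so the
        # replicas always see disjoint shards of the same permutation.
        random.Random(epoch).shuffle(indices)
    # Strided partition: replica r takes indices r, r+K, r+2K, ...
    return indices[rank::num_replicas]

part0 = replica_indices(10, epoch=0, rank=0, num_replicas=2)
part1 = replica_indices(10, epoch=0, rank=1, num_replicas=2)
# Together the two replicas cover the whole dataset exactly once.
```

Determinism is what makes checkpoint-restart work: after a restart with the same epoch, each replica can recompute the same ordering and skip ahead to where it left off.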
- adaptdl.torch.current_dataloader()[source]¶
Reference to the data loader currently being iterated.
Returns (AdaptiveDataLoaderHelper): Current data loader.
- adaptdl.torch.current_epoch()[source]¶
Get the current epoch while iterating with
remaining_epochs_until()
.- Returns
The current epoch number if called from within a
remaining_epochs_until()
iteration,None
otherwise.- Return type
int or None
- adaptdl.torch.finished_epochs()[source]¶
Get the number of epochs finished using
remaining_epochs_until()
.- Returns
The number of finished epochs. Equal to
current_epoch()
if called from within aremaining_epochs_until()
iteration.- Return type
int
- adaptdl.torch.init_process_group(backend, init_method=None, world_size=None, rank=None)[source]¶
Initializes the default distributed process group and the AdaptDL collectives module.
- Parameters
backend (str or Backend) – The backend to use. Use "nccl" for multi-GPU training, otherwise "gloo".
init_method (str, optional) – URL specifying how to initialize the process group.
world_size (int, optional) – Number of processes participating in the job.
rank (int, optional) – Rank of the current process (a number between 0 and world_size - 1).
If init_method, world_size, and rank are not provided, as is typical in a Kubernetes environment, AdaptDL will try to infer them from the environment variables ADAPTDL_MASTER_ADDR, ADAPTDL_NUM_REPLICAS, and ADAPTDL_REPLICA_RANK, respectively.
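The environment-variable fallback can be sketched as follows. The helper, its default values, and the port number are hypothetical, chosen only to illustrate the documented behavior:

```python
import os

def infer_init_args(init_method=None, world_size=None, rank=None):
    """Fall back to AdaptDL environment variables for missing args.

    Hypothetical sketch: defaults and the port are illustrative only.
    """
    if init_method is None:
        addr = os.environ.get("ADAPTDL_MASTER_ADDR", "127.0.0.1")
        init_method = "tcp://{}:23456".format(addr)  # port is arbitrary
    if world_size is None:
        world_size = int(os.environ.get("ADAPTDL_NUM_REPLICAS", "1"))
    if rank is None:
        rank = int(os.environ.get("ADAPTDL_REPLICA_RANK", "0"))
    return init_method, world_size, rank

# Explicitly provided arguments are passed through unchanged.
explicit = infer_init_args("env://", 2, 1)
```

On Kubernetes, the AdaptDL scheduler sets these variables for each replica, so applications typically call init_process_group with only the backend argument.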
- adaptdl.torch.remaining_epochs_until(epoch)[source]¶
Iterate over epochs in a way that is consistent with checkpoint-restarts. For example:
for epoch in remaining_epochs_until(30):
    print(current_epoch())  # Should print 0 through 29
for epoch in remaining_epochs_until(60):
    print(current_epoch())  # Should print 30 through 59
If a checkpoint-restart happens during an epoch, all previous epochs will be skipped after the program restarts.
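The skip-on-restart behavior can be modeled with a generator that consults a persisted epoch counter. This is a simplified single-process sketch; in the real implementation the counter lives in AdaptDL's checkpoint state and survives process restarts:

```python
# Simplified model: `state` stands in for checkpoint state that
# survives a restart, so already-finished epochs are never repeated.
state = {"epoch": 0}

def remaining_epochs_until_sketch(epoch):
    while state["epoch"] < epoch:
        yield state["epoch"]
        state["epoch"] += 1

first = list(remaining_epochs_until_sketch(3))
# first == [0, 1, 2]; state["epoch"] is now 3.
# After a simulated restart, the persisted counter means the next
# loop resumes where the previous one ended rather than at 0.
second = list(remaining_epochs_until_sketch(5))
# second == [3, 4]
```

This also explains the RuntimeError below: nesting a second epoch loop inside an unfinished one would corrupt the single shared epoch counter.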
- Parameters
epoch (int) – The epoch number to end at (exclusive).
- Raises
RuntimeError – If invoked before a previous epoch loop has ended.