adaptdl.torch package

class adaptdl.torch.Accumulator(*args, **kwargs)[source]

Bases: MutableMapping

This class helps aggregate simple statistics across all replicas in the current job, and across any number of checkpoint-restarts. It can be used to compute metrics like loss and accuracy, synchronized across all replicas.

Accumulators imitate Python dictionaries, with a few key differences described below. Primarily, their usage and behavior depend on whether the accumulator is in accumulation mode or in synchronized mode.

  1. Accumulation mode: the accumulator is being updated on all replicas. Operations like accum["key"] += val or accum.update(key=val) will aggregate the updates locally on each replica, which are lazily synchronized in the background (either upon a checkpoint or a switch to synchronized mode). Each replica may make different updates, which are summed together when synchronized. While accumulation mode is enabled, all read operations on the accumulator will behave as if they were performed on an empty dict, i.e. len(accum) will always return 0. By default, all accumulators are set to accumulation mode.

  2. Synchronized mode: the accumulator contains the same data on every replica, and the application must ensure that all write operations are exactly the same across all replicas. While in synchronized mode, the accumulator may be used as if it were a native python dict, and all read/write operations are supported. Accumulator.synchronized() may be used to enter synchronized mode. Upon entering synchronized mode, the accumulator will automatically sum all updates from all replicas to ensure the same data is available to each replica.

Using accumulators, many training/validation metrics can be computed easily and correctly in an elastic distributed setting. For example, a simple validation step which calculates a loss and accuracy can be implemented as follows:

accum = Accumulator()  # New accumulator starts in accumulation mode.

for epoch in remaining_epochs_until(60):

    for batch in validloader:
        accum["loss_sum"] += <loss summed within the batch>
        accum["correct"] += <number of correct predictions>
        accum["total"] += <total number of samples in the batch>

    with accum.synchronized():  # Enter synchronized mode.
        accum["loss_avg"] = accum["loss_sum"] / accum["total"]
        accum["accuracy"] = accum["correct"] / accum["total"]
        print("Loss: {}, Accuracy: {}".format(
              accum["loss_avg"], accum["accuracy"]))
    # Back to accumulation mode.
  • args – Positional arguments same as dict.

  • kwargs – Keyword arguments same as dict.


__iadd__(other)[source]

Supports the += operation, e.g. accum += {key1: val1, key2: val2}. Behaves the same way as accum.update({key1: val1, key2: val2}).


other – Mapping object or an iterable of key-update pairs.


__isub__(other)[source]

Supports the -= operation, e.g. accum -= {key1: val1, key2: val2}. Behaves the same way as accum.subtract({key1: val1, key2: val2}).


other – Mapping object or an iterable of key-update pairs.


__getitem__(key)[source]

Supports indexing, e.g. val = accum[key] and accum[key] += 1. The former (read access) should only be used when the accumulator is in synchronized mode.


key – Key used to access a value in the accumulator.

subtract(*args, **kwargs)[source]

Apply a collection of key-update pairs. Unlike Accumulator.update(), this method subtracts the updates from the accumulated values.


synchronized()[source]

A context manager which can be used to define the code to execute in synchronized mode. Within the context manager, any code can interact with this accumulator as if it were a regular Python dict. The application must ensure that whatever operations are performed within this context block are the same across all replicas.


Entering this context manager is a distributed synchronization point! Please ensure that all replicas enter this context manager at the same point in their code.

update(*args, **kwargs)[source]

Apply a collection of key-update pairs. Unlike dict.update, this method additively applies the updates to the accumulated values.

  • args – Positional arguments same as dict.update. Can be a mapping object or an iterable of key-update pairs.

  • kwargs – Keyword arguments same as dict.update. Each keyword is the string key corresponding to the provided update.
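As a single-process illustration of these additive semantics, the sketch below uses a plain dict; the LocalAccum class is hypothetical and performs no distributed synchronization, unlike the real Accumulator:

```python
# Illustrative only: a single-process sketch of Accumulator's additive
# update/subtract semantics, built on a plain dict.
class LocalAccum(dict):
    def update(self, other=(), **kwargs):
        # Unlike dict.update, add each value to the existing one.
        items = other.items() if hasattr(other, "items") else other
        for key, val in list(items) + list(kwargs.items()):
            self[key] = self.get(key, 0) + val

    def subtract(self, other=(), **kwargs):
        # Same as update, but subtracts the updates instead.
        items = other.items() if hasattr(other, "items") else other
        for key, val in list(items) + list(kwargs.items()):
            self[key] = self.get(key, 0) - val

accum = LocalAccum()
accum.update({"correct": 3, "total": 8})
accum.update(correct=2, total=8)   # adds to existing values, not overwrite
accum.subtract({"total": 1})
# accum == {"correct": 5, "total": 15}
```

With the real Accumulator, each replica would apply its own updates locally, and the per-key sums across replicas become visible once synchronized mode is entered.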

class adaptdl.torch.AdaptiveDataLoader(dataset, batch_size=1, shuffle=False, **kwargs)[source]

Bases: DataLoader, AdaptiveDataLoaderMixin

This class is a PyTorch DataLoader that also supports adaptive batch sizes and checkpoint-restart elasticity. Applications can typically use objects of this class as direct replacements for PyTorch DataLoaders. However, some notable differences are:

  1. The batch_size argument defines the target total batch size across all replicas, rather than the local batch size on each replica.

  2. Custom sampler and batch_sampler are not supported.

  3. Iterating through the dataloader is only allowed from within an epoch loop (see adaptdl.torch.epoch), and only one dataloader loop is allowed at any given time.

  • dataset (torch.utils.data.Dataset) – Dataset from which to load the data.

  • batch_size (int) – The target total batch size across all replicas. The actual total batch size may be different due to rounding (each replica must have the same local batch size), or being scaled up using adaptive batch sizes.

  • shuffle (bool) – Whether the data is reshuffled at every epoch.

  • **kwargs – Keyword arguments passed to torch.utils.data.DataLoader.


ValueError – If sampler or batch_sampler is not None.


__iter__()[source]

Iterate over batches of data. When adaptive batch size is disabled, stops after the entire dataset has been processed once in total by all replicas. This means if there are K replicas, then this method will iterate over ~1/K of the dataset. When adaptive batch size is enabled, stops after making enough statistical progress roughly equivalent to one pass over the dataset with non-adaptive batch size. In this case, the dataset may be processed more than once.

A checkpoint-restart may be triggered in-between each batch. In this case, the current iteration state will be saved and restored after the restart, and continue where it left off.
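Because every replica must use the same local batch size, the effective total batch size can differ from the batch_size target. The rounding rule below is a hypothetical illustration of this effect, not AdaptDL's actual policy:

```python
def local_batch_size(target_total, num_replicas):
    # Hypothetical rounding rule, for illustration only: split the
    # target evenly across replicas, with at least 1 sample each.
    return max(1, round(target_total / num_replicas))

# A target total of 128 across 3 replicas cannot be met exactly:
local = local_batch_size(128, 3)   # 43 samples per replica
effective_total = local * 3        # 129, slightly above the 128 target
```

Adaptive batch sizes may additionally scale the total upward beyond the target, as noted in the batch_size parameter description above.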

class adaptdl.torch.AdaptiveDataParallel(model, optimizer, lr_scheduler=None, mp_scaler=None, scaling_rule: Optional[ScalingRuleBase] = None, name='adaptdl-dataparallel', **kwargs)[source]

Bases: DistributedDataParallel

This class extends PyTorch DistributedDataParallel with support for adaptive batch sizes and checkpoint-restart elasticity. It automatically saves the given model, optimizer, and (optionally) LR scheduler whenever a checkpoint is triggered, and restores their states after restart. The optimizer is automatically patched with the chosen scaling rule.

  • model (torch.nn.Module) – Model to be distributed.

  • optimizer (torch.optim.Optimizer) – Optimizer used to update the given model's parameters; will be patched using a subclass of adaptdl.torch.scaling_rules.ScalingRuleBase.

  • scaling_rule (ScalingRuleBase) – Scaling rule used to patch the given optimizer; defaults to AdaScale.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler) – LR scheduler used to anneal the learning rate for the given optimizer.

  • name (string) – Unique name for each instance of this class, needed only if multiple instances exist.

forward(*args, **kwargs)[source]
property gain

Current estimate of the AdaScale gain (r_t) value.

to_tensorboard(writer, global_step, tag_prefix='')[source]

Output some useful metrics to TensorBoard.

  • writer (torch.utils.tensorboard.SummaryWriter) – SummaryWriter object to output metrics to.

  • global_step (int) – Global step value to record.

  • tag_prefix (str) – Prefix added to each metric’s tag.

training: bool
zero_grad(*args, **kwargs)[source]

Sets gradients of all model parameters to zero.

class adaptdl.torch.ElasticSampler(dataset, shuffle=True)[source]

Bases: Sampler

A PyTorch Sampler which partitions data samples across multiple replicas, and supports deterministic continuing across checkpoint-restarts. Shuffling is deterministic for each epoch, and ElasticSampler.set_epoch() should be invoked to obtain different orderings in different epochs.

  • dataset (torch.utils.data.Dataset) – The dataset to sample from.

  • shuffle (bool) – Whether the data samples should be shuffled.


__iter__()[source]

Iterate through the samples in the dataset, in the order defined for a set epoch, starting at a set index. Produces only the indices for the local replica.

Returns: Iterator over data sample indices.


__len__()[source]

The total number of samples to be iterated through, starting at the set index, for the local replica.

Returns (int): Number of samples.

set_epoch(epoch, index=0)[source]

Set the epoch to derive samples from. Optional argument index can be specified to start sampling from a particular index, e.g. after a checkpoint-restart.

  • epoch (int) – The epoch to sample from.

  • index (int) – The index to start sampling from.
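The behavior described above (the same epoch-seeded shuffle on every replica, rank-based partitioning, and resuming from index after a checkpoint-restart) can be sketched in plain Python. The local_indices function below is an illustrative stand-in, not adaptdl's implementation:

```python
import random

def local_indices(dataset_size, rank, num_replicas, epoch, index=0):
    # Seed the shuffle with the epoch so every replica derives the same
    # global ordering without communicating (illustrative sketch only).
    order = list(range(dataset_size))
    random.Random(epoch).shuffle(order)
    # Skip samples already processed before a restart, then keep only
    # the positions assigned to this replica's rank.
    return order[index:][rank::num_replicas]
```

Together, the per-rank slices cover the whole dataset exactly once per epoch, and calling set_epoch with a nonzero index after a restart lets each replica continue where the job left off.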


adaptdl.torch.current_dataloader()[source]

Reference to the data loader currently being iterated.

Returns (AdaptiveDataLoaderHelper): Current data loader.


adaptdl.torch.current_epoch()[source]

Get the current epoch while iterating with remaining_epochs_until().

Returns: The current epoch number if called from within a remaining_epochs_until() iteration, None otherwise.

Return type: int or None


adaptdl.torch.finished_epochs()[source]

Get the number of epochs finished using remaining_epochs_until().

Returns: The number of finished epochs. Equal to current_epoch() if called from within a remaining_epochs_until() iteration.

Return type: int
adaptdl.torch.init_process_group(backend, init_method=None, world_size=None, rank=None)[source]

Initializes the default distributed process group and the AdaptDL collectives module.

  • backend (str or Backend) – The backend to use. Use “nccl” for multi-GPU training else “gloo”.

  • init_method (str, optional) – URL specifying how to initialize the process group.

  • world_size (int, optional) – Number of processes participating in the job

  • rank (int, optional) – Rank of the current process (it should be a number between 0 and world_size-1).

If init_method, world_size, and rank are NOT provided, typically in a Kubernetes environment, AdaptDL will try to infer them through the environment variables ADAPTDL_MASTER_ADDR, ADAPTDL_NUM_REPLICAS, and ADAPTDL_REPLICA_RANK, respectively.
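A minimal sketch of that fallback logic for world_size and rank follows; infer_init_args is a hypothetical helper (the defaults of 1 and 0 are assumptions, and adaptdl's actual inference additionally covers the master address via ADAPTDL_MASTER_ADDR):

```python
import os

def infer_init_args(world_size=None, rank=None):
    # Hypothetical helper: explicit arguments win; otherwise consult
    # the AdaptDL environment variables (defaults here are assumptions).
    if world_size is None:
        world_size = int(os.environ.get("ADAPTDL_NUM_REPLICAS", 1))
    if rank is None:
        rank = int(os.environ.get("ADAPTDL_REPLICA_RANK", 0))
    return world_size, rank

# Explicit arguments take precedence over the environment:
world_size, rank = infer_init_args(world_size=4, rank=2)
```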


adaptdl.torch.remaining_epochs_until(epoch)[source]

Iterate over epochs in a way that is consistent with checkpoint-restarts. For example:

for epoch in remaining_epochs_until(30):
    print(current_epoch())  # Should print 0 through 29

for epoch in remaining_epochs_until(60):
    print(current_epoch())  # Should print 30 through 59

If a checkpoint-restart happens during an epoch, all previous epochs will be skipped after the program restarts.


epoch (int) – The epoch number to end at (exclusive).


RuntimeError – If invoked before a previous epoch loop has ended.