adaptdl.torch.parallel module

class adaptdl.torch.parallel.AdaptiveDataParallel(model, optimizer, lr_scheduler=None, mp_scaler=None, scaling_rule: Optional[ScalingRuleBase] = None, name='adaptdl-dataparallel', **kwargs)

Bases: DistributedDataParallel

This class extends PyTorch DistributedDataParallel with support for adaptive batch sizes and checkpoint-restart elasticity. It automatically saves the given model, optimizer, and (optionally) LR scheduler whenever a checkpoint is triggered, and restores their states after restart. The optimizer is automatically patched with the chosen scaling rule.

  • model (torch.nn.Module) – Model to be distributed.

  • optimizer (torch.optim.Optimizer) – Optimizer used to update the given model's parameters. It will be patched using a subclass of adaptdl.torch.scaling_rules.ScalingRuleBase.

  • scaling_rule (ScalingRuleBase) – Scaling rule used to patch the given optimizer. Defaults to AdaScale.

  • lr_scheduler (torch.optim.lr_scheduler._LRScheduler) – LR scheduler used to anneal the learning rate for the given optimizer.

  • name (string) – Unique name for each instance of this class, needed only if multiple instances exist.
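The constructor arguments above might be wired together roughly as follows. This is a minimal sketch, not a reference implementation: the function name build_adaptive_model, the toy Linear model, the SGD and StepLR settings, and the instance name are all illustrative, and the call to adaptdl.torch.init_process_group with the "gloo" backend is an assumption about setup that this page does not describe.

```python
def build_adaptive_model():
    """Wrap a toy model in AdaptiveDataParallel (hypothetical helper)."""
    # Imports are deferred so the sketch can be loaded on a machine
    # where torch and adaptdl are not installed.
    import torch
    import adaptdl.torch
    from adaptdl.torch.parallel import AdaptiveDataParallel

    # Assumption: the process group must be initialized before wrapping
    # the model; this call is not documented on this page.
    adaptdl.torch.init_process_group("gloo")

    # Illustrative model, optimizer, and scheduler; any torch.nn.Module,
    # torch.optim.Optimizer, and LR scheduler could take their place.
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

    # scaling_rule is left unset, so the documented default (AdaScale)
    # is used to patch the optimizer.
    return AdaptiveDataParallel(
        model,
        optimizer,
        lr_scheduler=lr_scheduler,
        name="example-model",  # only needed if several instances exist
    )
```

The returned object is used in place of the original model in the training loop; the optimizer and scheduler passed in here are the ones whose state is saved and restored at checkpoint time.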

forward(*args, **kwargs)

property gain

Current estimate of the AdaScale gain (r_t) value.
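For intuition about what this value represents, the sketch below computes the gain from the AdaScale definition, r_t = (σ² + μ²) / (σ²/S + μ²), where σ² is the gradient variance, μ² the squared norm of the true gradient, and S the scale factor relative to the base batch size. The function name adascale_gain and its arguments are hypothetical; how the library estimates σ² and μ² internally is not covered on this page.

```python
def adascale_gain(grad_variance: float, grad_sqr_norm: float, scale: float) -> float:
    """Return the AdaScale gain for a given scale factor (illustrative only).

    grad_variance: estimate of the gradient variance (sigma^2).
    grad_sqr_norm: estimate of the squared gradient norm (mu^2).
    scale: batch-size scale factor S, expected to be >= 1.
    """
    if scale < 1:
        raise ValueError("scale must be at least 1")
    denominator = grad_variance / scale + grad_sqr_norm
    if denominator == 0:
        # No gradient signal and no noise: fall back to a gain of 1.
        return 1.0
    return (grad_variance + grad_sqr_norm) / denominator
```

Under this definition the gain lies between 1 and S: it approaches 1 when the gradient is noise-free (a larger batch adds little) and approaches S when noise dominates (a larger batch helps almost linearly).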

to_tensorboard(writer, global_step, tag_prefix='')

Output some useful metrics to TensorBoard.

  • writer (torch.utils.tensorboard.SummaryWriter) – SummaryWriter object to output metrics to.

  • global_step (int) – Global step value to record.

  • tag_prefix (str) – Prefix added to each metric’s tag.
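One way to call this method from a training loop is through a small wrapper that only logs every few steps. The helper below is a sketch under assumptions: the name log_adaptdl_metrics, the interval default, and the "adaptdl/" prefix are illustrative choices, and model is expected to be an AdaptiveDataParallel instance with writer being a torch.utils.tensorboard.SummaryWriter.

```python
def log_adaptdl_metrics(model, writer, global_step, interval=100,
                        tag_prefix="adaptdl/"):
    """Forward to model.to_tensorboard every `interval` steps.

    Returns True if metrics were written on this step, False otherwise.
    """
    if global_step % interval != 0:
        return False
    # Signature as documented: to_tensorboard(writer, global_step, tag_prefix='')
    model.to_tensorboard(writer, global_step, tag_prefix=tag_prefix)
    return True
```

Using a distinct tag_prefix keeps these metrics grouped apart from any loss or accuracy scalars the training script writes to the same SummaryWriter.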

training: bool

zero_grad(*args, **kwargs)

Sets gradients of all model parameters to zero.