adaptdl.torch.parallel module
- class adaptdl.torch.parallel.AdaptiveDataParallel(model, optimizer, lr_scheduler=None, mp_scaler=None, scaling_rule: Optional[ScalingRuleBase] = None, name='adaptdl-dataparallel', **kwargs)[source]
Bases: torch.nn.parallel.DistributedDataParallel
This class extends PyTorch DistributedDataParallel with support for adaptive batch sizes and checkpoint-restart elasticity. It automatically saves the given model, optimizer, and (optionally) LR scheduler whenever a checkpoint is triggered, and restores their states after restart. The optimizer is automatically patched with the chosen scaling rule.
- Parameters
model (torch.nn.Module) – Model to be distributed.
optimizer (torch.optim.Optimizer) – Optimizer used to update the given model's parameters; it will be patched using a subclass of adaptdl.torch.scaling_rules.ScalingRuleBase.
scaling_rule (ScalingRuleBase) – Scaling rule used to patch the given optimizer; defaults to AdaScale.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler) – LR scheduler used to anneal the learning rate for the given optimizer.
name (string) – Unique name for each instance of this class, needed only if multiple instances exist.
- property gain
Current estimate of the AdaScale gain (r_t) value.
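A minimal usage sketch is shown below. The model, optimizer, and scheduler are illustrative placeholders, and the call to adaptdl.torch.init_process_group is assumed to be the usual AdaptDL setup step before wrapping the model; this is not a complete training loop.

import torch
import adaptdl.torch

# Assumed setup step: initialize AdaptDL's process group before wrapping.
adaptdl.torch.init_process_group("nccl" if torch.cuda.is_available() else "gloo")

# Illustrative model/optimizer/scheduler, not part of the API.
model = torch.nn.Linear(28 * 28, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)

# Wrap the components so their states are checkpointed and restored across
# restarts; the optimizer is patched with the default AdaScale scaling rule.
model = adaptdl.torch.parallel.AdaptiveDataParallel(model, optimizer, lr_scheduler)

print(model.gain)  # current estimate of the AdaScale gain r_t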
- to_tensorboard(writer, global_step, tag_prefix='')[source]
Output some useful metrics to TensorBoard.
- Parameters
writer (torch.utils.tensorboard.SummaryWriter) – SummaryWriter object to output metrics to.
global_step (int) – Global step value to record.
tag_prefix (str) – Prefix added to each metric’s tag.
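A short sketch of logging these metrics, assuming model is an AdaptiveDataParallel instance and epoch is the caller's own step counter (both placeholders here):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("/tmp/adaptdl-logs")  # illustrative log directory
# Write AdaptDL's metrics for this step under the given tag prefix.
model.to_tensorboard(writer, epoch, tag_prefix="adaptdl/")
writer.close()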
- training: bool