adaptdl.torch.scaling_rules module
- class adaptdl.torch.scaling_rules.AdaScale
Bases: adaptdl.torch.scaling_rules.ScalingRuleBase
Implements the AdaScale algorithm for scaling the learning rate in distributed and large-batch training.
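To illustrate the underlying idea, here is a minimal sketch of the gain rule from the AdaScale paper (adascale_gain, mu_sq, and sigma_sq are hypothetical names, not this module's API): given estimates of the squared gradient norm and the gradient variance, the gain is a learning-rate multiplier between 1 and the scale S.

    def adascale_gain(mu_sq, sigma_sq, scale):
        # Gain r = (sigma^2 + mu^2) / (sigma^2 / S + mu^2):
        # approaches S when gradient noise dominates (sigma^2 >> mu^2)
        # and approaches 1 when the gradient signal dominates.
        return (sigma_sq + mu_sq) / (sigma_sq / scale + mu_sq)

    # Noisy gradients at scale S=8 give a gain close to 8.
    print(adascale_gain(mu_sq=0.01, sigma_sq=1.0, scale=8))  # ~7.5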
- class adaptdl.torch.scaling_rules.LEGWScale(base_warmup_epochs, data_size)
Bases: adaptdl.torch.scaling_rules.ScalingRuleBase
Implements the LEGWScale algorithm for scaling the learning rate.
With LEGWScale, lr_factor is calculated based on training progress as follows:
- when current_step < base_warmup_epochs * scale * steps_per_epoch:
  lr_factor = sqrt(scale) * progress_ratio, where
  progress_ratio = current_step / (scale * base_warmup_epochs * steps_per_epoch)
- when current_step >= base_warmup_epochs * scale * steps_per_epoch:
  lr_factor = sqrt(scale)
To adapt LEGWScale to AdaptDL, progress_ratio is instead calculated as progress / (scale * base_warmup_epochs * steps_per_epoch), where progress is the effective number of steps trained, based on AdaptDL's estimation (see the sketch after the argument list below).
- Arguments:
  base_warmup_epochs: base number of warmup epochs
  data_size: total number of samples in the dataset
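To make the schedule concrete, here is a minimal sketch (legw_lr_factor is a hypothetical helper, not part of this module; steps_per_epoch would be derived from data_size and the batch size):

    import math

    def legw_lr_factor(progress, scale, base_warmup_epochs, steps_per_epoch):
        # Warmup lasts scale * base_warmup_epochs epochs; afterwards
        # the factor plateaus at sqrt(scale).
        warmup_steps = scale * base_warmup_epochs * steps_per_epoch
        if progress < warmup_steps:
            return math.sqrt(scale) * (progress / warmup_steps)
        return math.sqrt(scale)

    # Halfway through warmup at scale 4: sqrt(4) * 0.5 = 1.0
    print(legw_lr_factor(progress=1000, scale=4,
                         base_warmup_epochs=5, steps_per_epoch=100))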
- class adaptdl.torch.scaling_rules.ScalingRuleBase
Bases: object
Base class for scaling rules, with the ability to track gradient noise scale calculations. Its subclasses can be used in combination with adaptdl.torch.parallel.AdaptiveDataParallel and torch.optim.SGD:

    optim = torch.optim.SGD(model.parameters(), lr=0.001)
    adascale = AdaScale()
    model = AdaptiveDataParallel(model, optim, adascale)

    for epoch in ...:
        for batch in ...:
            optim.zero_grad()
            loss = ...
            loss.backward()
            adascale.step()