adaptdl.torch.scaling_rules module

class adaptdl.torch.scaling_rules.AdaScale

Bases: ScalingRuleBase

Implements the AdaScale algorithm for scaling the learning rate for distributed and large batch size training.


Calculate factors to be applied to lr for each parameter group.
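As a rough illustration of what such a factor can look like, the AdaScale paper defines a gain ratio from the gradient variance and the squared gradient norm. The sketch below is a standalone version of that ratio. The function name `adascale_gain`, its arguments, and the small clamping constants are assumptions made for this example and are not part of adaptdl's API.

```python
def adascale_gain(grad_var, grad_sqr, scale):
    """Hypothetical helper: AdaScale-style gain ratio.

    grad_var: estimate of the gradient variance at the base batch size.
    grad_sqr: estimate of the squared norm of the true gradient.
    scale:    ratio of the current batch size to the base batch size.
    """
    # Clamp the estimates so the denominator stays positive (assumed values).
    var = max(grad_var, 1e-6)
    sqr = max(grad_sqr, 0.0)
    # The gain lies between 1 (noise-free gradients) and scale (pure noise).
    return (var + sqr) / (var / scale + sqr)


if __name__ == "__main__":
    print(adascale_gain(grad_var=2.0, grad_sqr=3.0, scale=8))
```

When the gradient is dominated by noise the gain approaches `scale`, which matches linear scaling; when the noise is negligible the gain approaches 1, so a larger batch brings little benefit.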

class adaptdl.torch.scaling_rules.LEGWScale(base_warmup_epochs, data_size)

Bases: ScalingRuleBase

Implements the LEGWScale algorithm for scaling the learning rate.

With LEGWScale, lr_factor is calculated from training progress as follows:

  • when current_step < base_warmup_epochs * scale * steps_per_epoch: lr_factor = sqrt(scale) * progress_ratio, where progress_ratio = current_step / (scale * base_warmup_epochs * steps_per_epoch)

  • when current_step >= base_warmup_epochs * scale * steps_per_epoch: lr_factor = sqrt(scale)

To adapt LEGWScale to AdaptDL, progress_ratio is instead calculated as progress / (scale * base_warmup_epochs * steps_per_epoch), where progress is the effective number of steps trained, as estimated by AdaptDL.


  • base_warmup_epochs – Base warmup epochs.

  • data_size – Total number of samples in the dataset.
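The warmup schedule described for LEGWScale can be sketched as a small pure-Python function. This is a minimal illustration of the formula only, not adaptdl's implementation: the function name `legw_lr_factor` is hypothetical, and `steps_per_epoch` is passed in directly rather than derived from `data_size`.

```python
import math


def legw_lr_factor(progress, scale, base_warmup_epochs, steps_per_epoch):
    """Hypothetical helper: LEGW-style learning-rate factor.

    progress: effective number of steps trained so far.
    scale:    ratio of the current batch size to the base batch size.
    """
    # The warmup period grows linearly with the scale.
    warmup_steps = scale * base_warmup_epochs * steps_per_epoch
    if progress < warmup_steps:
        # During warmup, ramp linearly from 0 up to sqrt(scale).
        progress_ratio = progress / warmup_steps
        return math.sqrt(scale) * progress_ratio
    # After warmup, the factor stays at sqrt(scale).
    return math.sqrt(scale)


if __name__ == "__main__":
    for step in (0, 1000, 2000, 5000):
        print(step, legw_lr_factor(step, scale=4, base_warmup_epochs=5,
                                   steps_per_epoch=100))
```

With `scale=4`, `base_warmup_epochs=5`, and `steps_per_epoch=100`, warmup lasts 2000 steps, after which the factor is held at `sqrt(4)`.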

class adaptdl.torch.scaling_rules.LinearScale

Bases: ScalingRuleBase

class adaptdl.torch.scaling_rules.ScalingRuleBase

Bases: object

Base class for scaling rules that has the ability to track gradient noise scale calculations. Its subclasses can be used in combination with adaptdl.torch.parallel.AdaptiveDataParallel and torch.optim.SGD.

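import torch
from adaptdl.torch.parallel import AdaptiveDataParallel
from adaptdl.torch.scaling_rules import AdaScale

optim = torch.optim.SGD(model.parameters(), lr=0.001)
adascale = AdaScale()
model = AdaptiveDataParallel(model, optim, adascale)

for epoch in ...:
    for batch in ...:
        optim.zero_grad()
        loss = ...
        loss.backward()
        adascale.step()  # optimizer step with a scaled learning rate

initialize(adp, optimizer, patch_optimizer=False)

step(*args, **kwargs)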

Run one optimizer step. Essentially just invokes optimizer.step(*args, **kwargs) with a scaled learning rate.

  • args – Positional arguments passed to optimizer.step.

  • kwargs – Keyword arguments passed to optimizer.step.
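The general pattern of "invoke optimizer.step with a scaled learning rate" can be sketched as follows. This is an illustrative assumption about the approach, not adaptdl's actual code: `ToyOptimizer` is a hypothetical stand-in that only mimics the `param_groups` attribute of a torch optimizer, and `step_with_scaled_lr` is a made-up helper name.

```python
class ToyOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]
        self.used_lrs = []  # records the lr values seen by each step

    def step(self):
        self.used_lrs.append([g["lr"] for g in self.param_groups])


def step_with_scaled_lr(optimizer, lr_factors, *args, **kwargs):
    """Run optimizer.step with each group's lr multiplied by its factor."""
    initial_lrs = [group["lr"] for group in optimizer.param_groups]
    try:
        # Temporarily scale the learning rate of every parameter group.
        for group, factor in zip(optimizer.param_groups, lr_factors):
            group["lr"] *= factor
        return optimizer.step(*args, **kwargs)
    finally:
        # Restore the original learning rates, even if step raised.
        for group, lr in zip(optimizer.param_groups, initial_lrs):
            group["lr"] = lr


if __name__ == "__main__":
    opt = ToyOptimizer(lr=0.1)
    step_with_scaled_lr(opt, lr_factors=[2.0])
    print(opt.used_lrs, opt.param_groups)
```

Restoring the base learning rate after each step keeps the configured `lr` unchanged, so any scheduler or later scaling factor still sees the original value.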

zero_grad(*args, **kwargs)

class adaptdl.torch.scaling_rules.SqrtScale

Bases: ScalingRuleBase