Source code for adaptdl.torch.scaling_rules

# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import math
import numpy as np
import warnings

from types import MethodType

from adaptdl.torch.data import current_dataloader


__all__ = ["ScalingRuleBase", "AdaScale", "LinearScale", "SqrtScale",
           "LEGWScale"]


class ScalingRuleBase(object):
    """
    Base class for scaling rules that have the ability to track gradient
    noise scale calculations. Its subclasses can be used in combination with
    ``adaptdl.torch.parallel.AdaptiveDataParallel`` and ``torch.optim.SGD``.

    .. code-block:: python

        optim = torch.optim.SGD(model.parameters(), lr=0.001)
        adascale = AdaScale()
        model = AdaptiveDataParallel(model, optim, adascale)

        for epoch in ...:
            for batch in ...:
                optim.zero_grad()
                loss = ...
                loss.backward()
                adascale.step()
    """

    def __init__(self):
        # Instance of AdaptiveDataParallel, needs to be set before any of the
        # methods can be used.
        self.adp = None
        self._optimizer = None
        self._orig_optimizer_step = None

    def scale_lr(self, scale):
        raise NotImplementedError

    def zero_grad(self, *args, **kwargs):
        if self.adp.gns.should_zero_grad:
            self.adp.gns.reset_accumulation(*args, **kwargs)
        else:
            warnings.warn("skipping zero_grad for accumulated gradient")

    def step(self, *args, **kwargs):
        """
        Run one optimizer step. Essentially just invokes
        ``optimizer.step(*args, **kwargs)`` with a scaled learning rate.

        Arguments:
            args: Positional arguments passed to ``optimizer.step``.
            kwargs: Keyword arguments passed to ``optimizer.step``.
        """
        if not self.adp:
            raise ValueError("AdaptiveDataParallel instance is not set!")
        if not self.adp.require_backward_grad_sync:
            return
        scale = self.adp.gns.accum_scale * self.adp.gns.accum_count
        initial_lr = [pg["lr"] for pg in self._optimizer.param_groups]
        scaled_lr = np.multiply(self.scale_lr(scale), initial_lr)
        for lr, pg in zip(scaled_lr, self._optimizer.param_groups):
            pg["lr"] = lr
        self._orig_optimizer_step(*args, **kwargs)
        for lr, pg in zip(initial_lr, self._optimizer.param_groups):
            pg["lr"] = lr
        self.adp.gns.set_progress(self.adp.gns.get_progress() +
                                  self.adp.gns.gain(scale))

    def _patch_optimizer(self):
        """
        Monkey-patch the optimizer's ``step`` and ``zero_grad`` methods so
        they invoke :meth:`ScalingRuleBase.step` and
        :meth:`ScalingRuleBase.zero_grad`.
        """
        @functools.wraps(self._optimizer.step)
        def step_wrapper(optim, *args, **kwargs):
            return self.step(*args, **kwargs)

        @functools.wraps(self._optimizer.zero_grad)
        def zero_wrapper(optim, *args, **kwargs):
            return self.zero_grad(*args, **kwargs)

        self._optimizer.step = MethodType(step_wrapper, self._optimizer)
        self._optimizer.zero_grad = MethodType(zero_wrapper, self._optimizer)

    def initialize(self, adp, optimizer, patch_optimizer=False):
        self.adp = adp
        self._optimizer = optimizer
        self._orig_optimizer_step = optimizer.step
        if patch_optimizer:
            self._patch_optimizer()
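
# A minimal usage sketch (not part of the library): wiring a scaling rule to
# an optimizer by hand. In typical use ``AdaptiveDataParallel`` is assumed to
# call ``initialize`` on the rule itself, as in the class docstring above;
# ``model`` and ``adp`` below are placeholders.
#
#     optim = torch.optim.SGD(model.parameters(), lr=0.1)
#     rule = LinearScale()
#     rule.initialize(adp, optim, patch_optimizer=True)
#     # After patching, optim.step() dispatches to rule.step(), which
#     # multiplies each param group's lr by rule.scale_lr(scale) for that
#     # step only and then restores the original values; optim.zero_grad()
#     # likewise dispatches to rule.zero_grad().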


class AdaScale(ScalingRuleBase):
    """
    Implements the AdaScale_ algorithm for scaling the learning rate for
    distributed and large batch size training.

    .. _AdaScale: https://proceedings.icml.cc/static/paper_files/icml/2020/4682-Supplemental.pdf
    """  # noqa: E501

    def scale_lr(self, scale):
        """Calculate factors to be applied to lr for each parameter group."""
        var = self.adp.gns.raw_var_avg
        sqr = self.adp.gns.raw_sqr_avg
        var = np.maximum(var, 1e-6)
        sqr = np.maximum(sqr, 0.0)
        return (var + sqr) / (var / scale + sqr)
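
# Worked example (illustrative numbers only): if the tracked gradient
# statistics for a parameter group are raw_var_avg = 1.0 and
# raw_sqr_avg = 0.25, then at an effective scale of 8
#
#     scale_lr(8) = (1.0 + 0.25) / (1.0 / 8 + 0.25) = 1.25 / 0.375 ≈ 3.33
#
# so that group's learning rate is multiplied by roughly 3.33 for the step.
# The factor always lies between 1 and the linear-scaling factor (here 8).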


class AdamScale(AdaScale):
    """
    Implements the variant of AdaScale_ that supports Adam, AdamW and
    RMSProp.
    """

    def scale_lr(self, scale, power=0.5):
        return np.power(super().scale_lr(scale=scale), power)


class LinearScale(ScalingRuleBase):
    """Scales the learning rate linearly with the scale factor."""

    def scale_lr(self, scale):
        return scale


class SqrtScale(ScalingRuleBase):
    """Scales the learning rate with the square root of the scale factor."""

    def scale_lr(self, scale):
        return math.sqrt(scale)
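
# Quick comparison (illustrative): neither rule looks at gradient statistics,
# only at the scale itself. At an effective scale of 16:
#
#     LinearScale().scale_lr(16)  # -> 16
#     SqrtScale().scale_lr(16)    # -> 4.0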


class LEGWScale(ScalingRuleBase):
    """
    Implements the LEGWScale_ algorithm for scaling the learning rate.

    Essentially, with LEGWScale, lr_factor is calculated based on training
    progress as follows:

    - when current_step < base_warmup_epochs * scale * steps_per_epoch:
      ``lr_factor = sqrt(scale) * progress_ratio`` where
      ``progress_ratio = current_step /
      (scale * base_warmup_epochs * steps_per_epoch)``
    - when current_step >= base_warmup_epochs * scale * steps_per_epoch:
      ``lr_factor = sqrt(scale)``

    In order to adapt LEGWScale to AdaptDL, ``progress_ratio`` is calculated
    differently, as ``progress / (scale * base_warmup_epochs *
    steps_per_epoch)``, where ``progress`` is the effective number of steps
    trained according to AdaptDL's estimation.

    Arguments:
        base_warmup_epochs: base number of warmup epochs
        data_size: total number of samples in the dataset

    .. _LEGWScale: https://arxiv.org/pdf/1901.08256.pdf
    """

    def __init__(self, base_warmup_epochs, data_size):
        super().__init__()
        self._base_warmup_epochs = base_warmup_epochs
        self._data_size = data_size

    def scale_lr(self, scale):
        dataloader = current_dataloader()
        # Total number of training steps in the warmup period.
        total_steps = self._base_warmup_epochs * scale * \
            self._data_size / dataloader.batch_size
        max_lr_multiplier = math.sqrt(scale)
        # Effective number of training steps taken so far.
        progress = self.adp.gns.get_progress()
        if progress < total_steps:
            lr_factor = max_lr_multiplier * (progress / total_steps)
        else:
            lr_factor = max_lr_multiplier
        return lr_factor
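
# Worked example (illustrative numbers only): with base_warmup_epochs=5,
# data_size=50_000, a dataloader batch size of 100, and scale=4, the warmup
# period spans
#
#     total_steps = 5 * 4 * 50_000 / 100 = 10_000 steps
#
# and max_lr_multiplier = sqrt(4) = 2. At an effective progress of 2_500
# steps the returned factor is 2 * (2_500 / 10_000) = 0.5; once progress
# reaches 10_000 steps the factor plateaus at 2.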