Source code for adaptdl.torch.gradient_noise_scale

import functools
import logging
import math
import numpy as np
import torch.distributed
import torch.optim

from torch.autograd import Variable

import adaptdl.utils

__all__ = ["GradientNoiseScale"]

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)


def _average_groups(grads1, grads2):
    # Element-wise average of two nested lists of per-group gradients,
    # falling back to whichever gradient is present when the other is None.
    ret = []
    for group1, group2 in zip(grads1, grads2):
        ret.append([])
        for g1, g2 in zip(group1, group2):
            if g1 is None:
                ret[-1].append(g2)
            elif g2 is None:
                ret[-1].append(g1)
            else:
                ret[-1].append((g1 + g2) / 2)
    return ret


def _normsqr_groups(grads, pinvs):
    # Squared l2-norm of the preconditioned gradients, computed per
    # parameter group (parameters with no gradient are skipped).
    ret = []
    for group, pinv_group in zip(grads, pinvs):
        normsqr = [(g / pinv).pow(2).sum(dtype=torch.float64)
                   for g, pinv in zip(group, pinv_group) if g is not None]
        ret.append(sum(normsqr).item() if normsqr else 0.0)
    return np.array(ret)


class GradientNoiseScale(object):
    """This class tracks gradient related stats and takes care of gradient
    accumulation."""

    def __init__(self, adp, optimizer, mp_scaler=None,
                 num_replicas=None, accum_scale=None):
        self._adp = adp
        self._optimizer = optimizer
        self._orig_optimizer_zero_grad = optimizer.zero_grad
        self._should_zero_grad = True
        self._mp_scaler = mp_scaler
        self._local_sqr = None
        self._num_replicas = (num_replicas if num_replicas is not None
                              else torch.distributed.get_world_size())
        self._accum_scale = accum_scale or self._num_replicas
        self._prev_grads = None
        self.reset_accumulation()

        self._optimizer.state.setdefault("gns", {
            "progress": 0.0,
            "prev_scale": 0.0,
            # Averages of n and v
            "sqr_avg": np.ones(len(optimizer.param_groups)),
            "var_avg": np.zeros(len(optimizer.param_groups)),
            # Whether estimates are biased (using differenced estimator).
            "biased": False,
        })

        for idx, param_group in enumerate(self._optimizer.param_groups):
            for param in param_group["params"]:
                param.register_hook(
                    functools.partial(self._backward_hook, idx, param))
        self._callback_queued = False
        self._smoothing = 0.999

    @property
    def _state(self):
        return self._optimizer.state["gns"]
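
    # Illustrative usage sketch (added comment, not part of the original
    # module). ``adp`` stands in for the AdaptDL data-parallel wrapper that
    # exposes ``require_backward_grad_sync``; the surrounding model and
    # training loop are assumed, not shown:
    #
    #     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    #     gns = GradientNoiseScale(adp, optimizer)
    #     # ... run forward/backward passes; the registered hooks update
    #     # the gradient statistics automatically ...
    #     print(gns.sqr_avg(), gns.var_avg(), gns.gain(scale=8.0))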

    def reset_accumulation(self):
        """Reset accumulation calculations and gradients."""
        self._orig_optimizer_zero_grad()
        self._local_sqr = None
        self._accum_count = 0

    @property
    def should_zero_grad(self):
        return self._should_zero_grad

    @property
    def accum_scale(self):
        return self._accum_scale

    @property
    def accum_count(self):
        return self._accum_count

    def set_accum_scale(self, accum_scale):
        if not np.isclose(self._accum_scale, accum_scale):
            self.reset_accumulation()
            self._accum_scale = accum_scale

    @property
    def raw_sqr_avg(self):
        view = self._state["sqr_avg"].view()
        view.flags.writeable = False
        return view

    def sqr_avg(self):
        """
        Current estimate of the squared l2-norm of the true gradient
        (sigma squared).

        Returns (float): Estimate of squared l2-norm.
        """
        return float(np.sum(np.maximum(self._state["sqr_avg"], 0.0)))

    @property
    def raw_var_avg(self):
        view = self._state["var_avg"].view()
        view.flags.writeable = False
        return view

    def var_avg(self):
        """
        Current estimate of the trace of the covariance of the true gradient
        (mu squared).

        Returns (float): Estimate of trace of the covariance.
        """
        return float(np.sum(np.maximum(self._state["var_avg"], 1e-6)))

    def get_progress(self):
        return self._state["progress"]

    def set_progress(self, progress):
        self._state["progress"] = progress

    def gain(self, scale):
        """
        Current estimate of the GradientNoiseScale gain ratio.

        Arguments:
            scale (float): The total scale to estimate the gain ratio for.

        Returns (float): Estimate of gain ratio.
        """
        var = self.var_avg()
        norm = self.sqr_avg()
        return (var + norm) / (var / scale + norm)
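
    # Added note (not part of the original source): ``gain`` computes the
    # AdaScale-style gain ratio
    #
    #     r(scale) = (var + norm) / (var / scale + norm)
    #
    # where ``var`` estimates the trace of the gradient covariance and
    # ``norm`` the squared norm of the true gradient. For example, with
    # var = 4.0, norm = 1.0 and scale = 8, the gain is 5.0 / 1.5 ~= 3.33,
    # i.e. one update at scale 8 is expected to make roughly 3.3x the
    # progress of a single-replica update.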

    def _update_avg(self, param_name, value, factor):
        # Bias-corrected exponential moving average (same debiasing trick
        # as in Adam).
        biased = self._state.get(param_name + "_biased", 0.0)
        unbias = self._state.get(param_name + "_unbias", 0.0)
        biased = factor * biased + (1.0 - factor) * value
        unbias = factor * unbias + (1.0 - factor)
        self._state[param_name + "_biased"] = biased
        self._state[param_name + "_unbias"] = unbias
        self._state[param_name] = biased / unbias

    def _reset_avg(self, param_name):
        self._state.pop(param_name + "_biased", None)
        self._state.pop(param_name + "_unbias", None)

    @adaptdl.utils.print_exc
    def _backward_hook(self, idx, param, grad):
        # This method should be invoked once for each parameter during the
        # backward pass, before gradients are synchronized between replicas.
        if self._local_sqr is None:
            self._local_sqr = torch.zeros(len(self._optimizer.param_groups),
                                          device=grad.device,
                                          dtype=torch.float64)
        # Get the preconditioning matrix for the optimizer.
        preconditioner = self._calculate_preconditioner(idx, param)
        # Update the local gradient square sum.
        self._local_sqr[idx] += \
            (grad.detach() / preconditioner).pow(2).sum(dtype=torch.float64)
        if not self._callback_queued:
            Variable._execution_engine.queue_callback(self._queue_callback)
            self._callback_queued = True

    @adaptdl.utils.print_exc
    def _queue_callback(self):
        # This method should be invoked after the entire backward pass. We
        # want to make sure self._final_callback is invoked once, only after
        # all gradients have been synchronized between each replica. However,
        # the synchronization code in DistributedDataParallel is also done in
        # a callback, which might not yet be executed. Therefore, we enqueue
        # self._final_callback from this method, which should ensure it is
        # invoked after the gradient synchronization callback.
        self._callback_queued = False
        self._accum_count += 1
        if self._adp.require_backward_grad_sync:
            # Asynchronously sum the local squared-gradient statistics. The
            # actual gradient averaging should also be happening at the same
            # time, until self._final_callback is invoked.
            if self._num_replicas > 1:
                self._async_op = torch.distributed.all_reduce(
                    self._local_sqr, async_op=True)
            Variable._execution_engine.queue_callback(self._final_callback)
            self._should_zero_grad = True
        else:
            # Keep on accumulating gradients, should not zero grad.
            self._should_zero_grad = False

    @adaptdl.utils.print_exc
    def _final_callback(self):
        # This method should be invoked once the gradients have been
        # synchronized between all replicas and accumulation steps.
        if self._num_replicas > 1:
            self._async_op.wait()
        grads = []
        if self._mp_scaler is not None:
            mixed_precision_scale = self._mp_scaler.get_scale()
        else:
            mixed_precision_scale = 1.0
        for group in self._optimizer.param_groups:
            grads.append([])
            for param in group["params"]:
                if param.grad is None:
                    grads[-1].append(None)
                    continue
                param.grad.div_(self._accum_count)
                grads[-1].append(
                    param.grad.detach().float() / mixed_precision_scale)
        preconditioner = self._get_preconditioner()
        # Note: mixed precision can result in nan/inf gradients, which
        # propagate into our norm and variance estimates. Mixed precision
        # autoscaling skips the step when there are nan/inf, so we also
        # skip the update here.
        grads_normsqr = _normsqr_groups(grads, preconditioner)
        if not np.all(np.isfinite(grads_normsqr)):
            LOG.warning(f"GradientNoiseScale detected an invalid gradient "
                        f"at scale {mixed_precision_scale}; skipping update.")
            return
        count = self._num_replicas * self._accum_count
        scale = self._accum_scale * self._accum_count
        if count > 1:
            # Average local squared-norm samples.
            local_sqr = self._local_sqr.cpu().numpy() / count
            # Gradient is squared in local_sqr, so need to square the
            # mixed precision scale as well.
            local_sqr = (local_sqr / mixed_precision_scale ** 2)
            total_sqr = grads_normsqr
            if self._state["biased"]:
                self._reset_avg("sqr_avg")
                self._reset_avg("var_avg")
            self._state["biased"] = False
            self._prev_grads = None
        else:
            # Single gradient datapoint, use difference estimation.
            if self._prev_grads is not None:
                local_sqr = (_normsqr_groups(self._prev_grads,
                                             preconditioner)
                             + grads_normsqr) / 2
                avg_grads = _average_groups(grads, self._prev_grads)
                total_sqr = _normsqr_groups(avg_grads, preconditioner)
                count = 2
                scale = 2 * self._accum_scale
            self._state["biased"] = True
            self._prev_grads = [[g.clone() if g is not None else None
                                 for g in group] for group in grads]
        if count > 1:
            # Unbiased estimates of the squared gradient norm and the
            # (scaled) gradient variance from `count` gradient samples.
            grad_sqr = (count * total_sqr - local_sqr) / (count - 1)
            grad_var = (local_sqr - total_sqr) * scale / (count - 1)
            theta = self._smoothing ** scale
            self._update_avg('sqr_avg', grad_sqr, theta)
            self._update_avg('var_avg', grad_var, theta)

    def _get_preconditioner(self):
        out = []
        for idx, group in enumerate(self._optimizer.param_groups):
            pinvs = []
            for param in group["params"]:
                pinv = self._calculate_preconditioner(idx, param)
                pinvs.append(pinv)
            out.append(pinvs)
        return out

    def _calculate_preconditioner(self, idx, param):
        return torch.ones_like(param, memory_format=torch.preserve_format)


class AdamGradientNoiseScale(GradientNoiseScale):
    def __init__(self, adp, optimizer, mp_scaler=None,
                 num_replicas=None, accum_scale=None):
        self._adam_param_group = {'beta': [], 'eps': []}
        super().__init__(adp, optimizer, mp_scaler, num_replicas, accum_scale)
        for idx, param_group in enumerate(self._optimizer.param_groups):
            self._adam_param_group['beta'].append(param_group['betas'][1])
            self._adam_param_group['eps'].append(param_group['eps'])

    def _calculate_preconditioner(self, idx, param):
        state = self._optimizer.state[param]
        if state.get('step', 0) < 5:
            return torch.ones_like(param,
                                   memory_format=torch.preserve_format)
        exp_avg_sq = state["exp_avg_sq"].clone()  # not sure if clone is needed
        beta2 = self._adam_param_group['beta'][idx]
        eps = self._adam_param_group['eps'][idx]
        correction = 1 - beta2 ** state['step']
        pinv = (exp_avg_sq.sqrt() / math.sqrt(correction)).add_(eps)
        return pinv.to(param.device)

    def _reset_adam_state(self, step=0):
        for group in self._optimizer.param_groups:
            beta1, beta2 = group["betas"]
            for param in group["params"]:
                state = self._optimizer.state[param]
                if state.get("step", 0) > 0:
                    state["exp_avg"].mul_(
                        (1 - beta1 ** step) / (1 - beta1 ** state["step"]))
                    state["exp_avg_sq"].mul_(
                        (1 - beta2 ** step) / (1 - beta2 ** state["step"]))
                    state["step"] = step

    def _final_callback(self):
        scale = self._accum_scale * self._accum_count
        if not np.isclose(scale, self._state["prev_scale"]):
            # reset Adam states when scale is changed
            self._reset_adam_state()
            self._state["prev_scale"] = scale
        return super()._final_callback()
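

# Added note (not part of the original source): AdamGradientNoiseScale uses
# Adam's bias-corrected second-moment denominator as a diagonal
# preconditioner. For each parameter, once at least 5 optimizer steps have
# been taken, the preconditioner is
#
#     pinv = sqrt(exp_avg_sq) / sqrt(1 - beta2 ** step) + eps
#
# i.e. the same quantity Adam divides its update by, so the tracked gradient
# statistics are measured in Adam's preconditioned geometry rather than in
# the raw gradient space.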