Source code for adaptdl.torch.accumulator

# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import collections.abc
import contextlib
import copy
import pickle

import adaptdl.checkpoint
import adaptdl.collective
from adaptdl.torch.epoch import current_epoch
from adaptdl.torch.data import current_dataloader


class Accumulator(collections.abc.MutableMapping):
    """
    This class helps aggregate simple statistics across all replicas in the
    current job, and across any number of checkpoint-restarts. It can be used
    to compute metrics like loss and accuracy, synchronized across each
    replica.

    Accumulators imitate python dictionaries, but with a few key differences
    described below. Primarily, their usage and behavior depend on whether
    the accumulator is set to *accumulation mode* or to *synchronized mode*.

    1. **Accumulation mode:** the accumulator is being updated on all
       replicas. Operations like ``accum["key"] += val`` or
       ``accum.update(key=val)`` will aggregate the updates locally on each
       replica, which are lazily synchronized in the background (either upon
       a checkpoint or a switch to synchronized mode). Each replica may make
       different updates, which are summed together when synchronized. While
       accumulation mode is enabled, all read operations on the accumulator
       will behave as if they were performed on an empty ``dict``, i.e.
       ``len(accum)`` will always return ``0``. By default, all accumulators
       are set to accumulation mode.

    2. **Synchronized mode:** the accumulator contains the same data on every
       replica, and the application must ensure that all write operations are
       exactly the same across all replicas. While in synchronized mode, the
       accumulator may be used as if it were a native python ``dict``, and
       all read/write operations are supported.
       :meth:`Accumulator.synchronized` may be used to enter synchronized
       mode. Upon entering synchronized mode, the accumulator will
       automatically sum all updates from all replicas to ensure the same
       data is available to each replica.

    Using accumulators, many training/validation metrics can be computed
    easily and correctly in an elastic distributed setting. For example, a
    simple validation step which calculates a loss and accuracy can be
    implemented as follows:

    .. code-block:: python

        accum = Accumulator()  # New accumulator starts in accumulation mode.
        for epoch in remaining_epochs_until(60):

            for batch in validloader:
                ...
                accum["loss_sum"] += <loss summed within the batch>
                accum["correct"] += <number of correct predictions>
                accum["total"] += <total number of samples in the batch>

            with accum.synchronized():  # Enter synchronized mode.
                accum["loss_avg"] = accum["loss_sum"] / accum["total"]
                accum["accuracy"] = accum["correct"] / accum["total"]
                print("Loss: {}, Accuracy: {}".format(
                    accum["loss_avg"], accum["accuracy"]))
            accum.clear()  # Back to accumulation mode.

    Arguments:
        args: Positional arguments same as ``dict``.
        kwargs: Keyword arguments same as ``dict``.

    .. automethod:: __iadd__
    .. automethod:: __isub__
    .. automethod:: __getitem__
    """

    def __init__(self, *args, **kwargs):
        self._sync_count = collections.Counter()
        self._synchronized = None
        self._state = _AccumulatorState(*args, **kwargs)
        adaptdl.checkpoint.load_state(self._state)
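
    # Example (a minimal sketch, assuming a single-replica job so the
    # cross-replica sum is trivial): in accumulation mode, updates are
    # buffered locally and reads behave like an empty ``dict``; the
    # accumulated values only become readable inside ``synchronized()``.
    #
    #     accum = Accumulator()
    #     accum["total"] += 8
    #     assert len(accum) == 0      # reads act like an empty dict
    #     with accum.synchronized():
    #         assert accum["total"] == 8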

    @contextlib.contextmanager
    def synchronized(self):
        """
        A context manager which can be used to define the code to execute in
        *synchronized* mode. Within the context manager, any code can
        interact with this accumulator as if it were a regular Python
        ``dict``. The application must ensure that whatever operations are
        performed within this context block are the same across all replicas.

        .. warning::
            Entering this context manager is a distributed synchronization
            point! Please ensure that all replicas enter this context manager
            at the same point in their code.
        """
        if self._synchronized is not None:
            # Already synchronized, don't need to do anything.
            yield self
            return
        epoch = current_epoch()
        # Remove saved results from all finished epochs. Since finished
        # epochs are never replayed, they should never be needed again.
        for key in list(self._state.results_history.keys()):
            if key is not None and key < epoch:
                self._state.results_history.pop(key)
        # Get the number of synchronizations so far in the current epoch.
        count = self._sync_count[epoch]
        self._sync_count[epoch] += 1
        results_list = self._state.results_history[epoch]
        assert count <= len(results_list)
        if count < len(results_list):
            # Results for this synchronization are saved in the history.
            self._synchronized = results_list[count]
            self._state.updates.clear()
        else:
            self._state.sync()  # Sync results and updates across replicas.
            if current_dataloader() is None:
                # Only save into results history if outside of a dataloader
                # iteration, since code inside iterations is not replayed.
                results_list.append(copy.deepcopy(self._state.results))
            self._synchronized = self._state.results
        try:
            yield self
        finally:
            self._synchronized = None
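
    # Example (a minimal sketch; ``rank`` is a hypothetical variable holding
    # this replica's index): because entering ``synchronized()`` is a
    # distributed synchronization point, every replica must reach it the same
    # number of times. Guarding it by rank can hang the job:
    #
    #     with accum.synchronized():      # correct: all replicas enter
    #         loss = accum["loss_sum"] / accum["total"]
    #
    #     if rank == 0:
    #         with accum.synchronized():  # wrong: only one replica enters
    #             ...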

    def update(self, *args, **kwargs):
        """
        Apply a collection of key-update pairs. Unlike ``dict.update``, this
        method *additively* applies the updates to the accumulated values.

        Arguments:
            args: Positional arguments same as ``dict.update``. Can be a
                mapping object or an iterable of key-update pairs.
            kwargs: Keyword arguments same as ``dict.update``. Each keyword
                is the string key corresponding to the provided update.
        """
        for key, val in dict(*args, **kwargs).items():
            self[key] += val

    def subtract(self, *args, **kwargs):
        """
        Apply a collection of key-update pairs. Unlike
        :meth:`Accumulator.update`, this method *subtracts* the updates from
        the accumulated values.

        Arguments:
            args: Positional arguments same as :meth:`Accumulator.update`.
            kwargs: Keyword arguments same as :meth:`Accumulator.update`.
        """
        for key, val in dict(*args, **kwargs).items():
            self[key] -= val

    def __iadd__(self, other):
        """
        Supports the += operation, e.g. ``accum += {key1: val1, key2: val2}``.
        Behaves the same way as ``accum.update({key1: val1, key2: val2})``.

        Arguments:
            other: Mapping object or an iterable of key-update pairs.
        """
        self.update(other)
        return self

    def __isub__(self, other):
        """
        Supports the -= operation, e.g. ``accum -= {key1: val1, key2: val2}``.
        Behaves the same way as ``accum.subtract({key1: val1, key2: val2})``.

        Arguments:
            other: Mapping object or an iterable of key-update pairs.
        """
        self.subtract(other)
        return self
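
    # Example (a minimal sketch, assuming a single replica): ``update``/``+=``
    # and ``subtract``/``-=`` apply their updates additively rather than
    # overwriting, so repeated calls accumulate.
    #
    #     accum = Accumulator()
    #     accum.update(correct=3, total=4)
    #     accum += {"correct": 2, "total": 4}   # same as update()
    #     accum -= {"total": 1}                 # same as subtract()
    #     with accum.synchronized():
    #         assert accum["correct"] == 5 and accum["total"] == 7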

    def __getitem__(self, key):
        """
        Supports indexing, e.g. ``val = accum[key]`` and ``accum[key] += 1``.
        The former (read access) should only be used when the accumulator is
        in synchronized mode.

        Arguments:
            key: Key used to access a value in the accumulator.
        """
        if self._synchronized is not None:
            return self._synchronized.__getitem__(key)
        # In accumulation mode, return a dummy object which captures all
        # updates performed on it, to be later applied by __setitem__.
        return _Value(self, key)

    def __setitem__(self, key, value):
        if self._synchronized is not None:
            return self._synchronized.__setitem__(key, value)
        # Whenever an in-place addition or subtraction is done, like
        # a[k] += v, python will essentially perform 3 steps:
        # (1) tmp = a.__getitem__(k), (2) tmp += v, (3) a.__setitem__(k, tmp).
        # In order to obtain the update v, we let a.__getitem__(k) return an
        # opaque object which implements the __add__ operator to capture the
        # update v in step (2). Then, a.__setitem__(k, tmp) can pull v out of
        # tmp and accumulate it.
        if not isinstance(value, _Value):
            raise TypeError("invalid value type: {}".format(type(value)))
        if value.accum is not self:
            raise ValueError("incompatible {}".format(
                self.__class__.__name__))
        if key != value.key:
            raise ValueError("incompatible key: {}".format(value.key))
        self._state.updates.setdefault(key, 0)
        self._state.updates[key] += value.update

    # Rest of the abstract methods needed by collections.abc.MutableMapping.

    def __contains__(self, key):
        if self._synchronized is not None:
            return self._synchronized.__contains__(key)
        return {}.__contains__(key)

    def __delitem__(self, key):
        if self._synchronized is not None:
            return self._synchronized.__delitem__(key)
        return {}.__delitem__(key)

    def __iter__(self):
        if self._synchronized is not None:
            return self._synchronized.__iter__()
        return {}.__iter__()

    def __len__(self):
        if self._synchronized is not None:
            return self._synchronized.__len__()
        return {}.__len__()

    def __repr__(self):
        if self._synchronized is not None:
            return self._synchronized.__repr__()
        return {}.__repr__()
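

# A minimal standalone sketch of the capture pattern used by _Value below
# (hypothetical names, for illustration only): __getitem__ hands out an
# opaque object in step (1), __add__ records the update in step (2), and
# __setitem__ pulls it back out in step (3).
#
#     class Capture:
#         def __init__(self):
#             self.delta = 0
#
#         def __add__(self, v):           # step (2): tmp += v records v
#             self.delta += v
#             return self
#
#     class Accum:
#         def __init__(self):
#             self.updates = {}
#
#         def __getitem__(self, k):       # step (1): return opaque object
#             return Capture()
#
#         def __setitem__(self, k, tmp):  # step (3): pull the update out
#             self.updates[k] = self.updates.get(k, 0) + tmp.delta
#
#     a = Accum()
#     a["x"] += 5
#     assert a.updates == {"x": 5}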


class _Value(object):
    __slots__ = ["accum", "key", "update"]

    def __init__(self, accum, key):
        # Initialize the opaque object used for supporting "accum[k] += v"
        # and "accum[k] -= v" operations.
        self.accum = accum
        self.key = key
        self.update = 0

    def __add__(self, update):
        if isinstance(update, _Value):
            raise TypeError("invalid update type: {}".format(type(update)))
        self.update += update
        return self

    def __sub__(self, update):
        if isinstance(update, _Value):
            raise TypeError("invalid update type: {}".format(type(update)))
        self.update -= update
        return self


class _AccumulatorState(adaptdl.checkpoint.State):
    # Assume accumulators are initialized in the same order in every replica.
    # Keep a map of epoch -> number of accumulators initialized so far in
    # that epoch, and use that count to construct a unique name for the
    # state.
    init_count = collections.Counter()

    def __init__(self, *args, **kwargs):
        if current_dataloader() is not None:
            raise RuntimeError("accumulator may not be initialized during "
                               "dataloader iteration")
        epoch = current_epoch()
        count = _AccumulatorState.init_count[epoch]
        super().__init__("adaptdl-accumulator-epoch{}-{}".format(epoch, count))
        _AccumulatorState.init_count[epoch] += 1
        self.results_history = collections.defaultdict(list)
        self.results = dict(*args, **kwargs)
        self.updates = {}

    def save(self, fileobj):
        pickle.dump((self.results_history, self.results), fileobj)

    def load(self, fileobj):
        self.results_history, self.results = pickle.load(fileobj)

    def sync(self):
        # Aggregate pending updates across all replicas and apply them.
        updates = adaptdl.collective.allreduce(self.updates, _dict_iadd)
        _dict_iadd(self.results, updates)
        self.updates.clear()


def _dict_iadd(a, b):
    for k, v in b.items():
        if k in a:
            a[k] += v
        else:
            a[k] = v
    return a
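

# Example (illustrative): ``_dict_iadd`` is the reduction used by the
# allreduce in ``_AccumulatorState.sync``; it merges one dict of updates
# into another, summing values that share a key.
#
#     >>> _dict_iadd({"a": 1, "b": 2}, {"b": 3, "c": 4})
#     {'a': 1, 'b': 5, 'c': 4}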