adaptdl.torch.accumulator module

class adaptdl.torch.accumulator.Accumulator(*args, **kwargs)

Bases: MutableMapping

This class helps aggregate simple statistics across all replicas in the current job, and across any number of checkpoint-restarts. It can be used to compute metrics such as loss and accuracy, synchronized across all replicas.

Accumulators imitate Python dictionaries, but with a few key differences described below. Primarily, their usage and behavior depend on whether they are in accumulation mode or in synchronized mode.

  1. Accumulation mode: the accumulator is being updated on all replicas. Operations such as accum["key"] += val or accum.update(key=val) aggregate the updates locally on each replica, and these local updates are lazily synchronized in the background (either upon a checkpoint or upon a switch to synchronized mode). Each replica may make different updates, which are summed together when synchronized. While accumulation mode is enabled, all read operations on the accumulator behave as if they were performed on an empty dict, i.e., len(accum) always returns 0. By default, all accumulators start in accumulation mode.

  2. Synchronized mode: the accumulator contains the same data on every replica, and the application must ensure that all write operations are exactly the same across all replicas. While in synchronized mode, the accumulator may be used as if it were a native Python dict, and all read/write operations are supported. Use the Accumulator.synchronized() context manager to enter synchronized mode. Upon entering synchronized mode, the accumulator automatically sums all updates from all replicas so that the same data is available to each replica.
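To make the two modes concrete, here is a toy, single-process model. It is a sketch under stated assumptions, not adaptdl's implementation: the class ToyAccumulator and its methods add, read and sync are hypothetical, replicas are simulated as a list of local dicts, and no real communication takes place. It only illustrates that each replica records its own updates and that synchronizing sums them.

```python
from collections import defaultdict


class ToyAccumulator:
    """Hypothetical single-process model of the two Accumulator modes.

    Each simulated replica keeps its own local updates; "synchronizing"
    sums them into one shared dict. No real communication happens here.
    """

    def __init__(self, num_replicas):
        self.local = [defaultdict(int) for _ in range(num_replicas)]
        self.results = {}          # Shared data, identical on every replica.
        self.synchronized = False  # Start in accumulation mode.

    def add(self, replica, key, val):
        # Accumulation mode: record the update locally on one replica.
        self.local[replica][key] += val

    def read(self, key):
        # In accumulation mode, reads behave as if the dict were empty.
        if not self.synchronized:
            raise KeyError(key)
        return self.results[key]

    def sync(self):
        # Sum every replica's pending updates into the shared results.
        for updates in self.local:
            for key, val in updates.items():
                self.results[key] = self.results.get(key, 0) + val
            updates.clear()
        self.synchronized = True


acc = ToyAccumulator(num_replicas=2)
acc.add(0, "correct", 7)    # replica 0
acc.add(1, "correct", 5)    # replica 1 makes a different update
acc.sync()
print(acc.read("correct"))  # 12
```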

Using accumulators, many training/validation metrics can be computed easily and correctly in an elastic distributed setting. For example, a simple validation step which calculates a loss and accuracy can be implemented as follows:

from adaptdl.torch import Accumulator, remaining_epochs_until

accum = Accumulator()  # New accumulator starts in accumulation mode.

for epoch in remaining_epochs_until(60):

    for batch in validloader:
        ...  # Compute batch_loss_sum, num_correct and batch_size for this batch.
        accum["loss_sum"] += batch_loss_sum  # Loss summed within the batch.
        accum["correct"] += num_correct      # Number of correct predictions.
        accum["total"] += batch_size         # Total number of samples in the batch.

    with accum.synchronized():  # Enter synchronized mode.
        accum["loss_avg"] = accum["loss_sum"] / accum["total"]
        accum["accuracy"] = accum["correct"] / accum["total"]
        print("Loss: {}, Accuracy: {}".format(
              accum["loss_avg"], accum["accuracy"]))
        accum.clear()
    # Back to accumulation mode.

Parameters
  • args – Positional arguments same as dict.

  • kwargs – Keyword arguments same as dict.
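Note that the validation example accumulates sums (loss_sum, correct, total) and only divides after synchronizing. The sketch below, using made-up numbers for two hypothetical replicas with uneven shards, shows why: averaging each replica's own average gives a biased result, whereas dividing the summed totals gives the true average over all samples.

```python
# Hypothetical per-replica validation results (uneven shards).
replicas = [
    {"loss_sum": 50.0, "correct": 90, "total": 100},  # replica 0
    {"loss_sum": 20.0, "correct": 12, "total": 20},   # replica 1
]

# What synchronized mode does conceptually: sum each key across replicas.
summed = {}
for stats in replicas:
    for key, val in stats.items():
        summed[key] = summed.get(key, 0) + val

loss_avg = summed["loss_sum"] / summed["total"]  # 70 / 120
accuracy = summed["correct"] / summed["total"]   # 102 / 120

# Naive alternative: average the per-replica averages (biased here).
naive_loss = sum(s["loss_sum"] / s["total"] for s in replicas) / len(replicas)

print(round(loss_avg, 4), accuracy, naive_loss)  # 0.5833 0.85 0.75
```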

__iadd__(other)

Supports the += operation, e.g. accum += {key1: val1, key2: val2}. Behaves the same way as accum.update({key1: val1, key2: val2}).

Parameters

other – Mapping object or an iterable of key-update pairs.
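A minimal sketch of the documented += semantics, assuming a plain dict subclass stands in for the accumulator. AdditiveDict is hypothetical, accepts only a mapping, and does none of the mode handling or synchronization described above.

```python
class AdditiveDict(dict):
    """Hypothetical stand-in showing '+=' as an additive update."""

    def update(self, other):
        # Add each value to the existing one; missing keys start at 0.
        for key, val in other.items():
            self[key] = self.get(key, 0) + val

    def __iadd__(self, other):
        # 'd += other' behaves exactly like 'd.update(other)'.
        self.update(other)
        return self  # __iadd__ must return the object the name is rebound to.


d = AdditiveDict(loss_sum=1.5)
d += {"loss_sum": 2.0, "total": 4}
print(d)  # {'loss_sum': 3.5, 'total': 4}
```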

__isub__(other)

Supports the -= operation, e.g. accum -= {key1: val1, key2: val2}. Behaves the same way as accum.subtract({key1: val1, key2: val2}).

Parameters

other – Mapping object or an iterable of key-update pairs.
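Similarly for -=, here is a sketch assuming subtraction is plain arithmetic with no clamping, so values may become zero or negative. AdditiveDict is again hypothetical. For contrast, collections.Counter, which also offers update() and subtract(), discards non-positive counts on in-place -=, so it is not an exact model of this behavior.

```python
from collections import Counter


class AdditiveDict(dict):
    """Hypothetical stand-in: '-=' subtracts values, like subtract()."""

    def subtract(self, other):
        # Assumption: plain arithmetic, so results may be zero or negative.
        for key, val in other.items():
            self[key] = self.get(key, 0) - val

    def __isub__(self, other):
        # 'd -= other' behaves exactly like 'd.subtract(other)'.
        self.subtract(other)
        return self


d = AdditiveDict(total=3)
d -= {"total": 5, "skipped": 1}
print(d)  # {'total': -2, 'skipped': -1}

# Contrast: Counter's in-place '-=' drops keys whose count is not positive.
c = Counter(total=3)
c -= Counter(total=5, skipped=1)
print(c)  # Counter()
```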

__getitem__(key)

Supports indexing, e.g. val = accum[key] and accum[key] += 1. The former (read access) should only be used when the accumulator is in synchronized mode.

Parameters

key – Key used to access a value in the accumulator.
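This page does not describe how accum[key] += 1 can work when reads are not meaningful in accumulation mode. One plausible mechanism, sketched below purely as an assumption, relies on how Python desugars augmented assignment on a subscript: it calls __getitem__, then __iadd__ on the result, then __setitem__ with whatever __iadd__ returned. LazyAccum and _PendingUpdate are hypothetical names; the real class may work differently.

```python
class _PendingUpdate:
    """Hypothetical proxy returned by reads in accumulation mode."""

    def __init__(self):
        self.delta = 0

    def __iadd__(self, val):
        # Record the increment instead of computing a real value.
        self.delta += val
        return self


class LazyAccum:
    """Hypothetical sketch: how 'accum[key] += val' can work without reads."""

    def __init__(self):
        self.pending = {}  # Local, not-yet-synchronized updates.

    def __getitem__(self, key):
        # Accumulation mode: hand back a proxy instead of a real value.
        return _PendingUpdate()

    def __setitem__(self, key, proxy):
        # Python calls this with the proxy after running proxy.__iadd__(val).
        self.pending[key] = self.pending.get(key, 0) + proxy.delta


accum = LazyAccum()
accum["correct"] += 3  # __getitem__, then __iadd__, then __setitem__
accum["correct"] += 4
print(accum.pending)  # {'correct': 7}
```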

subtract(*args, **kwargs)

Apply a collection of key-update pairs. Unlike Accumulator.update(), this method subtracts the updates from the accumulated values.

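Parameters
  • args – Positional arguments same as Accumulator.update.

  • kwargs – Keyword arguments same as Accumulator.update.
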
synchronized()

A context manager which defines the code to execute in synchronized mode. Within the context manager, any code can interact with this accumulator as if it were a regular Python dict. The application must ensure that whatever operations are performed within this context block are the same across all replicas.

Warning

Entering this context manager is a distributed synchronization point! Please ensure that all replicas enter this context manager at the same point in their code.
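The following single-process sketch shows one way such a context manager could be structured with contextlib.contextmanager, assuming that entering it sums the pending updates and that leaving it, even via an exception, restores accumulation mode. ModeSwitchingAccum and _all_reduce are hypothetical; _all_reduce stands in for the cross-replica sum, which is where a real multi-replica implementation would wait for the other replicas.

```python
from contextlib import contextmanager


class ModeSwitchingAccum:
    """Hypothetical single-process sketch of a synchronized() context."""

    def __init__(self):
        self.pending = {}             # Local updates (accumulation mode).
        self.results = {}             # Data visible in synchronized mode.
        self.is_synchronized = False

    def add(self, key, val):
        self.pending[key] = self.pending.get(key, 0) + val

    def __getitem__(self, key):
        # Outside synchronized mode, reads act as if the dict were empty.
        if not self.is_synchronized:
            raise KeyError(key)
        return self.results[key]

    def _all_reduce(self, updates):
        # Placeholder for the cross-replica sum; with one process it is
        # just a copy. A real version would block on every replica here.
        return dict(updates)

    @contextmanager
    def synchronized(self):
        for key, val in self._all_reduce(self.pending).items():
            self.results[key] = self.results.get(key, 0) + val
        self.pending = {}
        self.is_synchronized = True
        try:
            yield self
        finally:
            self.is_synchronized = False  # Always return to accumulation mode.


acc = ModeSwitchingAccum()
acc.add("total", 10)
acc.add("total", 5)
with acc.synchronized():
    print(acc["total"])  # 15
print(acc.is_synchronized)  # False
```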

update(*args, **kwargs)

Apply a collection of key-update pairs. Unlike dict.update, this method additively applies the updates to the accumulated values.

Parameters
  • args – Positional arguments same as dict.update. Can be a mapping object or an iterable of key-update pairs.

  • kwargs – Keyword arguments same as dict.update. Each keyword is the string key corresponding to the provided update.
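The accepted argument forms mirror dict.update: one optional positional argument that is either a mapping (anything with a keys() method) or an iterable of key-update pairs, plus keyword updates. The hypothetical helper apply_updates below sketches that dispatch on a plain dict, adding values instead of overwriting them; it is not adaptdl code.

```python
def apply_updates(accumulated, *args, **kwargs):
    """Hypothetical helper mirroring the argument forms of update().

    Like dict.update, it takes at most one positional argument (a mapping
    or an iterable of key-update pairs) plus keyword updates, but it adds
    each update to the existing value instead of overwriting it.
    """
    if len(args) > 1:
        raise TypeError(f"expected at most 1 positional argument, got {len(args)}")
    pairs = []
    if args:
        other = args[0]
        if hasattr(other, "keys"):
            # Mapping: the same duck-typing check that dict.update uses.
            pairs.extend((key, other[key]) for key in other.keys())
        else:
            # Otherwise: an iterable of (key, update) pairs.
            pairs.extend(other)
    pairs.extend(kwargs.items())
    for key, val in pairs:
        accumulated[key] = accumulated.get(key, 0) + val
    return accumulated


stats = {}
apply_updates(stats, {"correct": 3})                  # mapping
apply_updates(stats, [("correct", 2), ("total", 8)])  # iterable of pairs
apply_updates(stats, total=2)                         # keyword arguments
print(stats)  # {'correct': 5, 'total': 10}
```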