adaptdl.torch.accumulator module
- class adaptdl.torch.accumulator.Accumulator(*args, **kwargs)
Bases: MutableMapping
This class helps aggregate simple statistics across all replicas in the current job, and across any number of checkpoint-restarts. It can be used to compute metrics such as loss and accuracy, synchronized across all replicas.
Accumulators imitate Python dictionaries, with a few key differences described below. Primarily, an accumulator's usage and behavior depend on whether it is set to accumulation mode or to synchronized mode.
Accumulation mode: the accumulator is being updated on all replicas. Operations like accum["key"] += val or accum.update(key=val) will aggregate the updates locally on each replica; these updates are lazily synchronized in the background (either upon a checkpoint or a switch to synchronized mode). Each replica may make different updates, which are summed together when synchronized. While accumulation mode is enabled, all read operations on the accumulator will behave as if they were performed on an empty dict, i.e. len(accum) will always return 0. By default, all accumulators are set to accumulation mode.
Synchronized mode: the accumulator contains the same data on every replica, and the application must ensure that all write operations are exactly the same across all replicas. While in synchronized mode, the accumulator may be used as if it were a native Python dict, and all read/write operations are supported. Accumulator.synchronized() may be used to enter synchronized mode. Upon entering synchronized mode, the accumulator will automatically sum all updates from all replicas to ensure the same data is available to each replica.
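The two modes can be pictured with a small single-process model. This is only a sketch of the semantics described above, not AdaptDL's implementation: the names ToyReplica and synchronize are hypothetical, and no real communication between processes takes place.

```python
from collections import Counter


class ToyReplica:
    """Hypothetical stand-in for one replica's accumulator."""

    def __init__(self):
        self.pending = Counter()   # local updates not yet synchronized
        self.synced = {}           # data visible in synchronized mode
        self.in_sync_mode = False  # new accumulators start in accumulation mode

    def add(self, key, value):
        # Accumulation mode: record the update locally only.
        self.pending[key] += value

    def __len__(self):
        # Reads in accumulation mode behave like reads on an empty dict.
        return len(self.synced) if self.in_sync_mode else 0


def synchronize(replicas):
    """Sum pending updates over all replicas and give each the same data."""
    total = Counter()
    for replica in replicas:
        total.update(replica.pending)  # Counter.update adds values
        replica.pending.clear()
    for replica in replicas:
        merged = Counter(replica.synced)
        merged.update(total)
        replica.synced = dict(merged)
        replica.in_sync_mode = True


replicas = [ToyReplica(), ToyReplica()]
replicas[0].add("correct", 3)   # each replica may make different updates
replicas[1].add("correct", 5)
len_before = len(replicas[0])   # still in accumulation mode
synchronize(replicas)
```

After synchronize returns, both toy replicas hold the summed value for "correct", which mirrors the guarantee that synchronized mode exposes the same data everywhere.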
Using accumulators, many training/validation metrics can be computed easily and correctly in an elastic distributed setting. For example, a simple validation step which calculates a loss and accuracy can be implemented as follows:
    accum = Accumulator()  # New accumulator starts in accumulation mode.
    for epoch in remaining_epochs_until(60):
        for batch in validloader:
            ...
            accum["loss_sum"] += <loss summed within the batch>
            accum["correct"] += <number of correct predictions>
            accum["total"] += <total number of samples in the batch>

        with accum.synchronized():  # Enter synchronized mode.
            accum["loss_avg"] = accum["loss_sum"] / accum["total"]
            accum["accuracy"] = accum["correct"] / accum["total"]
            print("Loss: {}, Accuracy: {}".format(
                  accum["loss_avg"], accum["accuracy"]))
            accum.clear()
        # Back to accumulation mode.
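Because synchronization sums the updates from every replica, ratios such as accuracy are best derived from accumulated sums after synchronization rather than from per-replica averages. The arithmetic below is a sketch with made-up numbers; the per_replica values are hypothetical and no AdaptDL calls are involved.

```python
# Hypothetical per-replica validation statistics before synchronization.
per_replica = [
    {"loss_sum": 12.0, "correct": 30, "total": 40},
    {"loss_sum": 3.0, "correct": 9, "total": 10},
]

# What synchronized mode would expose: key-wise sums over all replicas.
summed = {key: sum(r[key] for r in per_replica) for key in per_replica[0]}
global_loss = summed["loss_sum"] / summed["total"]      # 15.0 / 50
global_accuracy = summed["correct"] / summed["total"]   # 39 / 50

# Averaging per-replica accuracies weights both replicas equally,
# even though one of them saw four times as many samples.
naive_accuracy = sum(
    r["correct"] / r["total"] for r in per_replica
) / len(per_replica)
```

Here the sum-based accuracy is 0.78, while the average of per-replica accuracies is 0.825, so the two only agree when every replica processes the same number of samples.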
- Parameters
args – Positional arguments same as dict.
kwargs – Keyword arguments same as dict.
- __iadd__(other)
Supports the += operation, e.g. accum += {key1: val1, key2: val2}. Behaves the same way as accum.update({key1: val1, key2: val2}).
- Parameters
other – Mapping object or an iterable of key-update pairs.
- __isub__(other)
Supports the -= operation, e.g. accum -= {key1: val1, key2: val2}. Behaves the same way as accum.subtract({key1: val1, key2: val2}).
- Parameters
other – Mapping object or an iterable of key-update pairs.
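The relationship between the in-place operators and the update/subtract methods can be sketched with a small MutableMapping subclass. AdditiveDict is a hypothetical single-process stand-in: it has no replicas and no modes, and it is not how AdaptDL implements Accumulator. It only shows one way += and -= can delegate to additive update methods.

```python
from collections.abc import MutableMapping


class AdditiveDict(MutableMapping):
    """Hypothetical stand-in; not the AdaptDL implementation."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def update(self, *args, **kwargs):
        # dict(...) accepts a mapping or an iterable of pairs, plus keywords.
        for key, val in dict(*args, **kwargs).items():
            self._data[key] = self._data.get(key, 0) + val

    def subtract(self, *args, **kwargs):
        for key, val in dict(*args, **kwargs).items():
            self._data[key] = self._data.get(key, 0) - val

    def __iadd__(self, other):
        self.update(other)
        return self  # in-place operators must return the resulting object

    def __isub__(self, other):
        self.subtract(other)
        return self


accum = AdditiveDict()
accum += {"a": 1, "b": 2}   # mapping form
accum += [("a", 10)]        # iterable of key-update pairs
accum -= {"b": 5}
```

Returning self from __iadd__ and __isub__ matters: Python rebinds the name on the left-hand side to whatever the method returns.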
- __getitem__(key)
Supports indexing, e.g. val = accum[key] and accum[key] += 1. The former (read access) should only be used when the accumulator is in synchronized mode.
- Parameters
key – Key used to access a value in the accumulator.
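The reason accum[key] += 1 involves __getitem__ at all is that Python expands an augmented item assignment into a read followed by a write. The sketch below demonstrates that protocol with a hypothetical TracingDict; how Accumulator treats the intermediate value while in accumulation mode is not described here, so this shows only the language behavior.

```python
class TracingDict(dict):
    """Hypothetical dict subclass that records item reads and writes."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def __getitem__(self, key):
        self.calls.append(("get", key))
        # Treat missing keys as 0 so that += works on a fresh key.
        return dict.get(self, key, 0)

    def __setitem__(self, key, value):
        self.calls.append(("set", key, value))
        super().__setitem__(key, value)


d = TracingDict()
d["n"] += 1   # expands to: tmp = d["n"]; tmp += 1; d["n"] = tmp
```

The recorded calls show one read of "n" followed by one write of the incremented value.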
- subtract(*args, **kwargs)
Apply a collection of key-update pairs. Unlike Accumulator.update(), this method subtracts the updates from the accumulated values.
- Parameters
args – Positional arguments same as Accumulator.update().
kwargs – Keyword arguments same as Accumulator.update().
- synchronized()
A context manager which can be used to define the code to execute in synchronized mode. Within the context manager, any code can interact with this accumulator as if it were a regular Python dict. The application must ensure that whatever operations are performed within this context block are the same across all replicas.
Warning
Entering this context manager is a distributed synchronization point! Please ensure that all replicas enter this context manager at the same point in their code.
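Why a synchronization point must be reached by every participant can be illustrated with a local analogy. The sketch below uses threading.Barrier from the standard library as a stand-in for the distributed synchronization; the replica function and the timeout are assumptions made for the demonstration, and a real job may simply hang rather than time out.

```python
import threading

barrier = threading.Barrier(2)   # both "replicas" are expected to arrive
results = []


def replica(enters_sync):
    """Hypothetical replica that may or may not reach the sync point."""
    if not enters_sync:
        results.append("skipped")   # e.g. a branch taken on one replica only
        return
    try:
        barrier.wait(timeout=0.5)   # stand-in for entering synchronized mode
        results.append("synchronized")
    except threading.BrokenBarrierError:
        results.append("timed out")  # the other participant never arrived


threads = [
    threading.Thread(target=replica, args=(flag,)) for flag in (True, False)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

The thread that waits never gets past the barrier because its peer skipped it, which is the situation the warning above asks applications to avoid.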
- update(*args, **kwargs)
Apply a collection of key-update pairs. Unlike dict.update, this method additively applies the updates to the accumulated values.
- Parameters
args – Positional arguments same as dict.update. Can be a mapping object or an iterable of key-update pairs.
kwargs – Keyword arguments same as dict.update. Each keyword is the string key corresponding to the provided update.
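The difference between overwriting and additive updates can be seen with standard-library types. In this sketch collections.Counter stands in for the additive behavior; it is an analogy rather than the Accumulator itself, and unlike the description above, Counter counts the elements of a bare iterable, so key-update pairs are converted to a dict first.

```python
from collections import Counter

plain = {"total": 10}
plain.update({"total": 5})       # dict.update overwrites the old value

additive = Counter({"total": 10})
additive.update({"total": 5})    # mapping form: values are added
additive.update(total=1)         # keyword form: values are added too

pairs = [("correct", 4), ("total", 4)]
additive.update(dict(pairs))     # convert key-update pairs to a mapping first

additive.subtract({"total": 2})  # subtracts instead of adding
```

After these calls the plain dict holds only the last value written, while the additive container holds the running sum of every update.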