# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import collections
import functools
import logging
import math
import numpy as np
import pickle
import random
import torch
from torch.utils.data import DataLoader, Sampler
import adaptdl.checkpoint
import adaptdl.collective
import adaptdl.env
from adaptdl.torch.epoch import current_epoch
from adaptdl.torch._metrics import (
profile_step_start, profile_step_commit,
set_batch_size, get_goodput_fn, get_progress)
from adaptdl._signal import get_exit_flag
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
[docs]class ElasticSampler(Sampler):
"""
A PyTorch Sampler which partitions data samples across multiple replicas,
and supports deterministic continuing across checkpoint-restarts. Shuffling
is deterministic for each epoch, and :meth:`ElasticSampler.set_epoch`
should be invoked to obtain different orderings in different epochs.
Arguments:
dataset (torch.util.data.Dataset): The dataset to sample from.
shuffle (bool): Whether the data samples should be shuffled.
.. automethod:: __iter__
.. automethod:: __len__
"""
def __init__(self, dataset, shuffle=True):
self.dataset = dataset
self.shuffle = shuffle
self.num_replicas = adaptdl.env.num_replicas()
self.rank = adaptdl.env.replica_rank()
self.epoch = 0
self.index = 0
[docs] def __iter__(self):
"""
Iterate through the samples in the dataset, in the order defined for a
set epoch, starting at a set index. Produces only the indices for the
local replica.
Returns: Iterator over data sample indices.
"""
if self.shuffle:
# Deterministically shuffle based on epoch.
g = torch.Generator()
g.manual_seed(hash((self.epoch, self.index // len(self.dataset))))
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
base_index = self.index % len(self.dataset)
# Subsample.
local_indices = indices[base_index + self.rank::self.num_replicas]
# Add extra samples to make it evenly divisible.
if len(local_indices) < len(self):
local_indices.append(indices[self.rank])
assert len(local_indices) == len(self)
return iter(local_indices)
[docs] def __len__(self):
"""
The total number of samples to be iterated through, starting at the set
index, for the local replica.
Returns (int): Number of samples.
"""
base_index = self.index % len(self.dataset)
return math.ceil((len(self.dataset) - base_index) / self.num_replicas)
[docs] def set_epoch(self, epoch, index=0):
"""
Set the epoch to derive samples from. Optional argument ``index`` can
be specified to start sampling from a particular index, e.g. after a
checkpoint-restart.
Arguments:
epoch (int): The epoch to sample from.
index (int): The index to start sampling from.
"""
self.epoch = epoch
self.index = index
[docs]def current_dataloader():
"""
Reference to the data loader currently being iterated.
Returns (AdaptiveDataLoaderHelper): Current data loader.
"""
return AdaptiveDataLoaderHelper._current
[docs]class AdaptiveDataLoaderHelper(object):
"""
This class provides fine-grained control over adaptive training loops. It
can be used for building more user-friendly custom data loaders, such as
:class:`AdaptiveDataLoader`.
Arguments:
batch_size (int): The target total batch size across all replicas. The
actual total batch size may be different due to rounding (each
replica must have the same local batch size), or being scaled up
using adaptive batch sizes.
"""
# Epoch -> the number of dataloader loops completed so far in that epoch,
# across all AdaptiveDataLoader objects.
_position = collections.Counter()
_training = None # The AdaptiveDataLoader which loads training data.
_current = None # The AdaptiveDataLoader which is currently iterating.
def __init__(self, batch_size=1):
# Autoscale batch size fields.
self._max_batch_size = None
self._local_bsz_bounds = None
# Create and load state.
self._state = _AdaptiveDataLoaderState()
adaptdl.checkpoint.load_state(self._state)
self.batch_size = batch_size
self.future_exit = None
self._gradient_accumulation = False
self._speedup_threshold = 1.05
self._accum_count = 0
@property
def current_index(self):
"""
The total number of data samples processed so far in the current loop.
Includes the data processed by all replicas. ``None`` if this data
loader is not currently being iterated.
"""
if AdaptiveDataLoaderHelper._current is not self:
return None
return self._state.current_index
@current_index.setter
def current_index(self, index):
if AdaptiveDataLoaderHelper._current is not self:
return
self._state.current_index = index
@property
def end_index(self):
"""
(Optional) Can be used to track the end index of dataset across
restarts.
"""
return self._state.end_index
@end_index.setter
def end_index(self, index):
"""
(Optional) Supports mutations of end_index
"""
self._state.end_index = index
@property
def max_batch_size(self):
"""
The maximum total batch size allowed for adaptive batch size. ``None``
if adaptive batch size is disabled.
"""
return self._max_batch_size
@property
def local_bsz_bounds(self):
"""
The local batch size bounds on each replica. A pair of integers,
(min_local_bsz, max_local_bsz).
"""
return self._local_bsz_bounds
@property
def current_local_bsz(self):
"""
The current logical local batch size used by the dataloader.
The batch size returned by the dataloader may be smaller if
gradient accumulation is used
"""
return self._state.current_local_bsz
@property
def accumulation_steps(self):
"""
The number of batches returned by the dataloader before a
step is taken.
"""
return self._state.accumulation_steps
[docs] def is_accum_step(self):
"""
Whether the current step's gradient will be accumulated.
"""
return self._accum_count < self._state.accumulation_steps
[docs] def is_optim_step(self):
"""
Whether the optimizer step will be invoked in this step.
"""
return not self.is_accum_step()
[docs] def train(self):
"""
Set this data loader to be the one used for training. Only one data
loader may be used for training.
"""
if AdaptiveDataLoaderHelper._training is None:
AdaptiveDataLoaderHelper._training = self
set_batch_size(self.batch_size, self.max_batch_size,
self.local_bsz_bounds, self._gradient_accumulation)
[docs] def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
gradient_accumulation=False):
"""
Enables adaptive batch size. Should be invoked once after the data
loader object is created.
Arguments:
max_batch_size (int): Maximum total batch size allowed.
local_bsz_bounds (tuple): A pair of (min_local_bsz, max_local_bsz),
the min and max local batch sizes allowed on each replica.
Raises:
ValueError: If any of the provided batch size bounds are invalid.
"""
if not isinstance(max_batch_size, int) or \
max_batch_size < self.batch_size:
raise ValueError("invalid max_batch_size")
if local_bsz_bounds is not None and (
local_bsz_bounds[0] is not None and
local_bsz_bounds[0] > self.batch_size or
local_bsz_bounds[1] is not None and
local_bsz_bounds[1] < self.batch_size):
raise ValueError("invalid local_bsz_bounds")
self._max_batch_size = max_batch_size
self._local_bsz_bounds = local_bsz_bounds
self._gradient_accumulation = gradient_accumulation
self.train()
def _sync_local_bsz(self):
goodput_fn = get_goodput_fn()
if self.max_batch_size is None or goodput_fn is None:
# No autoscale batch size, just divide batch size evenly.
self._state.current_local_bsz = math.ceil(
self.batch_size / adaptdl.env.num_replicas())
self._state.accumulation_steps = 0
elif not self._state.current_local_bsz:
# if init, use the batch size suggested
_, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
else:
# if not first time, we check against the relative speedup
suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
max_batch_size=self._max_batch_size,
atomic_bsz_range=self._local_bsz_bounds,
accumulation=self._gradient_accumulation)
# get current goodput
current_goodput = goodput_fn(
adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
self.current_local_bsz, self.accumulation_steps)
# use only if speedup is significant
speedup = suggest_goodput / max(current_goodput, 1e-8)
if speedup > self._speedup_threshold:
self._state.current_local_bsz = atomic_bsz
self._state.accumulation_steps = accum_steps
self._state.current_local_bsz, self._state.accumulation_steps = \
adaptdl.collective.broadcast((self._state.current_local_bsz,
self._state.accumulation_steps))
return self.current_local_bsz
@property
def training(self):
return self is AdaptiveDataLoaderHelper._training
[docs] @contextmanager
def profile(self, commit):
"""
Every iteration of every epoch should be profiled under this context.
Note that, custom DataLoader writers should make sure that it gets
called equal number of times on each replica.
Arguments:
commit (bool): Whether to commit the profiled results.
"""
# Synchronize the exit signal so all replicas exit after
# the same iteration. Do this asynchronously to prevent
# unnecessary blocking on the network.
if self.future_exit is not None and self.future_exit.result():
adaptdl.checkpoint.save_all_states()
exit(143) # Standard exit code response to SIGTERM.
self.future_exit = adaptdl.collective.allreduce_async(
get_exit_flag(), lambda a, b: a or b)
profile_step_start(self.current_local_bsz)
yield
if commit:
profile_step_commit(self.is_accum_step())
self._accum_count = (0 if self.is_optim_step()
else self._accum_count + 1)
[docs] @contextmanager
def context(self):
"""
All iterators should be iterated under this context. It ensures
proper cleanup of elastic context at the end of each epoch.
"""
epoch = current_epoch()
try:
if AdaptiveDataLoaderHelper._current is not None:
raise RuntimeError("overlapping dataloader \
iterations detected")
AdaptiveDataLoaderHelper._current = self
yield
finally:
self._state.current_index = 0
self._state.end_index = 0
self._state.last_position[epoch] = self._position[epoch]
self._position[epoch] += 1
AdaptiveDataLoaderHelper._current = None
@property
def current_batch_size(self):
return (self.current_local_bsz * (self.accumulation_steps + 1) *
adaptdl.env.num_replicas())
[docs] def skipdone(self):
"""
Should be called just after entering the `_elastic` context to make
sure that the dataloader loop is not replayed if has already finished
before a restart.
"""
epoch = current_epoch()
position = self._position[epoch]
if position <= self._state.last_position.get(epoch, -1):
# Already completed the dataloader loop at the current
# position, skip this loop and keep replaying the application
# code.
LOG.info("skipping %s loop at position %s in epoch %s",
self.__class__.__name__, position, epoch)
self._position[epoch] += 1
return True
else:
return False
[docs] def to_tensorboard(self, writer, global_step, tag_prefix=""):
"""
Output some useful metrics to TensorBoard.
Arguments:
writer (torch.utils.tensorboard.SummaryWriter): ``SummaryWriter``
object to output metrics to.
global_step (int): Global step value to record.
tag_prefix (str): Prefix added to each metric's tag.
"""
if tag_prefix and not tag_prefix.endswith("/"):
tag_prefix += "/"
writer.add_scalar(tag_prefix + "Total_Batch_Size",
self.current_batch_size, global_step)
writer.add_scalar(tag_prefix + "Local_Batch_Size",
self.current_local_bsz, global_step)
writer.add_scalar(tag_prefix + "Accumulation_Steps",
self.accumulation_steps, global_step)
[docs]class AdaptiveDataLoaderMixin(object):
"""
This class provides elastic functionality to any custom DataLoader which
inherits it. It defines a member _elastic of type
:class:`AdaptiveDataLoaderHelper` which has useful methods and members to
implement restart-safe, elastic DataLoaders. It also exposes public methods
which can be used inside training loops directly from
:class:`AdaptiveDataLoader`.
"""
def __init__(self, batch_size):
self._elastic = AdaptiveDataLoaderHelper(batch_size)
[docs] def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
gradient_accumulation=False):
self._elastic.autoscale_batch_size(max_batch_size, local_bsz_bounds,
gradient_accumulation)
@property
def current_local_bsz(self):
if AdaptiveDataLoaderHelper._current is not self._elastic:
return None
return self._elastic.current_local_bsz
@property
def accumulation_steps(self):
"""
The number of batches returned by the dataloader before a
step is taken.
"""
return self._elastic.accumulation_steps
@property
def training(self):
return self._elastic.training
@property
def current_batch_size(self):
if AdaptiveDataLoaderHelper._current is not self._elastic:
return None
return self._elastic.current_batch_size
[docs] def to_tensorboard(self, writer, global_step, tag_prefix=""):
self._elastic.to_tensorboard(writer, global_step, tag_prefix)
to_tensorboard.__doc__ = AdaptiveDataLoaderHelper.to_tensorboard.__doc__
def _worker_init_wrapper(worker_init_fn, num_workers):
# Set globally-unique python and numpy seeds for each worker.
@functools.wraps(worker_init_fn)
def wrapper(worker_id):
nonlocal num_workers
num_workers = num_workers or 1
# https://pytorch.org/docs/master/data.html#randomness-in-multi-process-data-loading.
seed = torch.initial_seed() + adaptdl.env.replica_rank() * num_workers
torch.manual_seed(seed)
np.random.seed(seed % 2 ** 32)
random.seed(seed)
if worker_init_fn is not None:
return worker_init_fn(worker_id)
return wrapper
[docs]class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin):
"""
This class is a PyTorch DataLoader that also supports adaptive batch sizes
and checkpoint-restart elasticity. Applications can typically use objects
of this class as direct replacements for PyTorch DataLoaders. However, some
notable differences are:
1. The ``batch_size`` argument defines the target total batch size across
all replicas, rather than the local batch size on each replica.
2. Custom ``sampler`` and ``batch_sampler`` are not supported.
3. Iterating through the dataloader is only allowed from within an epoch
loop (see :mod:`adaptdl.torch.epoch`), and only one dataloader loop is
allowed at any given time.
Arguments:
dataset (torch.util.data.Dataset): Dataset from which to load the data.
batch_size (int): The target total batch size across all replicas. The
actual total batch size may be different due to rounding (each
replica must have the same local batch size), or being scaled up
using adaptive batch sizes.
shuffle (bool): Whether the data is reshuffled at every epoch.
**kwargs: Keyword arguments passed to ``torch.util.data.Dataloader``.
Raises:
ValueError: If ``sampler`` or ``batch_sampler`` are not ``None``.
.. automethod:: __iter__
"""
def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
if kwargs.get("batch_sampler") is not None \
or kwargs.get("sampler") is not None:
raise ValueError("AdaptiveDataLoader does not support "
"custom 'sampler' or 'batch_sampler'")
# Custom sampler is incompatible with shuffle=True, so we always set
# shuffle=False in __init__ and let our own sampler do the shuffling.
kwargs["sampler"] = ElasticSampler(dataset, shuffle=shuffle)
kwargs["worker_init_fn"] = _worker_init_wrapper(
kwargs.get("worker_init_fn"), kwargs.get("num_workers"))
super().__init__(dataset, batch_size, shuffle=False, **kwargs)
AdaptiveDataLoaderMixin.__init__(self, batch_size)
[docs] def __iter__(self):
"""
Iterate over batches of data. When adaptive batch size is disabled,
stops after the entire dataset has been processed once in total by all
replicas. This means if there are K replicas, then this method will
iterate over ~1/K of the dataset. When adaptive batch size is enabled,
stops after making enough statistical progress roughly equivalent to
one pass over the dataset with non-adaptive batch size. In this case,
the dataset may be processed more than once.
A checkpoint-restart may be triggered in-between each batch. In this
case, the current iteration state will be saved and restored after the
restart, and continue where it left off.
"""
epoch = current_epoch()
num_replicas = adaptdl.env.num_replicas()
with self._elastic.context():
if self._elastic.skipdone():
return
done = False
while not done:
self.sampler.set_epoch(
epoch, index=self._elastic.current_index)
self.batch_sampler.batch_size = self._elastic._sync_local_bsz()
for idx, batch in enumerate(super().__iter__()):
with self._elastic.profile(self.training and idx >= 1):
yield batch
# Increment by the number of data samples processed
self._elastic.current_index += \
num_replicas * self.batch_sampler.batch_size
if self._elastic.max_batch_size is not None and \
get_progress() >= len(self.dataset) * \
(epoch + 1) / self.batch_size:
done = True
break
if self._elastic.max_batch_size is None:
done = True
self._elastic.current_index -= \
self._elastic.current_index % -len(self.dataset)
class _AdaptiveDataLoaderState(adaptdl.checkpoint.State):
# Assume dataloaders are initialized in the same order in every replica.
# Keep a map of epoch -> number of dataloaders initialized so far in that
# epoch, and use that count to construct a unique name for the state.
init_count = collections.Counter()
def __init__(self):
if current_dataloader() is not None:
raise RuntimeError("dataloader may not be initialized during "
"dataloader iteration")
epoch = current_epoch()
count = _AdaptiveDataLoaderState.init_count[epoch]
super().__init__("adaptdl-dataloader-epoch{}-{}".format(epoch, count))
_AdaptiveDataLoaderState.init_count[epoch] += 1
self.current_index = 0 # Index within the current dataloader loop.
self.end_index = 0 # End index of the current DataLoader loop.
self.last_position = {} # Epoch -> position of last completed loop.
self.current_local_bsz = 0
self.accumulation_steps = 0
def save(self, fileobj):
pickle.dump((self.current_index, self.end_index,
self.last_position), fileobj)
def load(self, fileobj):
self.current_index, self.end_index, self.last_position = \
pickle.load(fileobj)