Source code for adaptdl.torch

# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys
import os
if "darwin" in sys.platform.lower():
    # To avoid multiple runs of the model code
    # https://pythonspeed.com/articles/python-multiprocessing/
    import multiprocessing
    multiprocessing.set_start_method('fork')

import logging
import portpicker
import requests
import torch.distributed
import pkg_resources

import adaptdl.collective
import adaptdl.env
import semver
from .epoch import current_epoch, finished_epochs, remaining_epochs_until
from .data import current_dataloader, AdaptiveDataLoader, ElasticSampler
from .parallel import AdaptiveDataParallel
from .accumulator import Accumulator

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)


def version_check(version):
    """Return True if ``version`` is a valid semver string other than "0.0.0"."""
    return semver.VersionInfo.isvalid(version) and version != "0.0.0"
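

# For example, version_check("1.2.3") returns True, while version_check("0.0.0")
# (treated as an unset/placeholder version) and version_check("not-semver")
# return False.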


def init_process_group(backend, init_method=None, world_size=None, rank=None):
    """
    Initializes the default distributed process group and the AdaptDL
    collectives module.

    Args:
        backend (str or Backend): The backend to use. Use "nccl" for
            multi-GPU training, otherwise "gloo".
        init_method (str, optional): URL specifying how to initialize the
            process group.
        world_size (int, optional): Number of processes participating in
            the job.
        rank (int, optional): Rank of the current process (it should be a
            number between 0 and ``world_size`` - 1).

    If init_method, world_size, and rank are NOT provided, typically in the
    Kubernetes environment, AdaptDL will try to infer them through the
    environment variables ADAPTDL_MASTER_ADDR, ADAPTDL_NUM_REPLICAS, and
    ADAPTDL_REPLICA_RANK, respectively.
    """
    if adaptdl.env.from_ray():
        from adaptdl_ray.adaptdl.utils import unique_nodes_pg
        assert init_method is not None
        assert world_size is not None
        assert rank is not None
        os.environ["ADAPTDL_NUM_NODES"] = str(unique_nodes_pg())
        os.environ["ADAPTDL_REPLICA_RANK"] = str(rank)
        os.environ["ADAPTDL_NUM_REPLICAS"] = str(world_size)

    url = adaptdl.env.supervisor_url()
    master_port = adaptdl.env.master_port()

    if rank is None:
        rank = adaptdl.env.replica_rank()
    if world_size is None:
        world_size = adaptdl.env.num_replicas()

    if init_method is not None:
        _, master_addr, master_port = init_method.split(":")
        master_addr = master_addr[2:]
        master_port = int(master_port)
    elif url:
        key = adaptdl.env.job_id()
        group = adaptdl.env.num_restarts()
        while True:
            response = requests.get(url=f"{url}/discover/{key}/{group}")
            if response.status_code != 408:  # Timeout.
                break
        response.raise_for_status()
        master_addr = response.json()[0]
        sched_version = adaptdl.env.adaptdl_sched_version()
        trainer_version = pkg_resources.get_distribution("adaptdl").version
        if version_check(sched_version) and version_check(trainer_version):
            trainer_ver_maj = semver.VersionInfo.parse(trainer_version).major
            sched_ver_maj = semver.VersionInfo.parse(sched_version).major
            if trainer_ver_maj != sched_ver_maj:
                raise Exception('adaptdl version {} is incompatible with '
                                'scheduler version {}'.format(trainer_version,
                                                              sched_version))
    else:
        master_addr = adaptdl.env.master_addr()

    # Initialize collective module.
    adaptdl.collective.initialize(master_addr, master_port, rank, world_size)

    # Initialize torch.distributed.
    torch_port = adaptdl.collective.broadcast(portpicker.pick_unused_port())
    init_method = "tcp://{}:{}?rank={}&world_size={}".format(
        master_addr, torch_port, rank, world_size)
    LOG.info("Initializing torch.distributed using %s", init_method)
    torch.distributed.init_process_group(backend, init_method)
    LOG.info("torch.distributed initialized")
__all__ = [ "init_process_group", "current_epoch", "finished_epochs", "remaining_epochs_until", "current_dataloader", "AdaptiveDataLoader", "ElasticSampler", "AdaptiveDataParallel", "Accumulator", ]