# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp

from ..device import get_device_count
from .group import init_process_group
from .server import Server
from .util import get_free_ports


def _get_device_count():
- """use subprocess to avoid cuda environment initialization in the main process"""
-
- def run(q):
- count = get_device_count("gpu")
- q.put(count)
-
- q = mp.Queue()
- p = mp.Process(target=run, args=(q,))
- p.start()
- p.join()
- return q.get()
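

# The subprocess indirection above keeps the parent process free of a CUDA
# context (an initialized CUDA context is not safe to reuse across fork), so
# worker processes started later begin from a clean state. A minimal,
# hypothetical usage sketch (not part of this module's public API):
#
#     if _get_device_count() >= 2:
#         ...  # safe to launch multi-GPU work; the parent still has no CUDA context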


def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
    """Initialize the distributed process group, then run the wrapped function."""
    init_process_group(
        master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
    )
    func(*args, **kwargs)
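

# _run_wrapped is the entry point of every worker process: `rank` is the
# process's position in the group and `dev` is the GPU it drives. In
# `launcher` below both receive the same loop index, so worker i effectively
# executes:
#
#     _run_wrapped(func, "localhost", port, n_gpus, i, i, args, kwargs)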


def launcher(n_gpus):
    """Decorator for launching multiple processes in single-machine multi-GPU training."""

    count = _get_device_count()
    assert isinstance(n_gpus, int) and n_gpus > 1, "invalid n_gpus"
    assert n_gpus <= count, "{} gpus required, {} gpus provided".format(n_gpus, count)

    def decorator(func):
        def wrapper(*args, **kwargs):
            master_ip = "localhost"
            port = get_free_ports(1)[0]
            # per-launch helper server on the chosen free port
            server = Server(port)

            # spawn one process per GPU; rank and device share the same index,
            # so process i runs on GPU i
            procs = []
            for rank in range(n_gpus):
                p = mp.Process(
                    target=_run_wrapped,
                    args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs),
                )
                p.start()
                procs.append(p)

            # wait for every worker and surface any non-zero exit code
            for rank in range(n_gpus):
                procs[rank].join()
                code = procs[rank].exitcode
                assert code == 0, "subprocess {} exit with code {}".format(rank, code)

        return wrapper

    return decorator
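

# A minimal usage sketch (illustrative only; `train_step` and its argument are
# hypothetical): the decorated function body runs once in every worker process,
# after that worker has joined the process group on its own GPU.
#
#     @launcher(n_gpus=2)
#     def train_step(num_iters):
#         for _ in range(num_iters):
#             ...  # build the model and run one iteration on this GPU
#
#     train_step(num_iters=100)  # blocks until both workers exit cleanly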