|
|
@@ -8,26 +8,12 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import multiprocessing as mp |
|
|
|
|
|
|
|
from ..device import get_device_count |
|
|
|
from .group import init_process_group |
|
|
|
from .helper import get_device_count_by_fork |
|
|
|
from .server import Server |
|
|
|
from .util import get_free_ports |
|
|
|
|
|
|
|
|
|
|
|
def _get_device_count():
    """Return the result of ``get_device_count("gpu")`` evaluated in a child process.

    The query runs in a subprocess to avoid cuda environment initialization
    in the main process.

    Returns:
        The value the child obtained from ``get_device_count("gpu")``.

    Raises:
        RuntimeError: if the child process exits with a non-zero code, which
            means it terminated without reporting a count.
    """

    def run(q):
        # Executed in the child: query the device count and send it back.
        count = get_device_count("gpu")
        q.put(count)

    q = mp.Queue()
    p = mp.Process(target=run, args=(q,))
    p.start()
    p.join()

    # A child that fails before reaching q.put() leaves the queue empty, so an
    # unconditional q.get() would block the parent forever. Surface the
    # failure instead of hanging.
    if p.exitcode != 0:
        raise RuntimeError(
            "device count subprocess exit with code {}".format(p.exitcode)
        )
    return q.get()
|
|
|
|
|
|
|
|
|
|
|
def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): |
|
|
|
"""init distributed process group and run wrapped function""" |
|
|
|
init_process_group( |
|
|
@@ -36,33 +22,28 @@ def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs): |
|
|
|
func(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
def launcher(n_gpus): |
|
|
|
def launcher(func): |
|
|
|
"""decorator for launching multiple processes in single-machine multi-gpu training""" |
|
|
|
|
|
|
|
count = _get_device_count() |
|
|
|
assert isinstance(n_gpus, int) and n_gpus > 1, "invalid n_gpus" |
|
|
|
assert n_gpus <= count, "{} gpus required, {} gpus provided".format(n_gpus, count) |
|
|
|
|
|
|
|
def decorator(func): |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
master_ip = "localhost" |
|
|
|
port = get_free_ports(1)[0] |
|
|
|
server = Server(port) |
|
|
|
n_gpus = get_device_count_by_fork("gpu") |
|
|
|
|
|
|
|
procs = [] |
|
|
|
for rank in range(n_gpus): |
|
|
|
p = mp.Process( |
|
|
|
target=_run_wrapped, |
|
|
|
args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs), |
|
|
|
) |
|
|
|
p.start() |
|
|
|
procs.append(p) |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
master_ip = "localhost" |
|
|
|
port = get_free_ports(1)[0] |
|
|
|
server = Server(port) |
|
|
|
|
|
|
|
for rank in range(n_gpus): |
|
|
|
procs[rank].join() |
|
|
|
code = procs[rank].exitcode |
|
|
|
assert code == 0, "subprocess {} exit with code {}".format(rank, code) |
|
|
|
procs = [] |
|
|
|
for rank in range(n_gpus): |
|
|
|
p = mp.Process( |
|
|
|
target=_run_wrapped, |
|
|
|
args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs), |
|
|
|
) |
|
|
|
p.start() |
|
|
|
procs.append(p) |
|
|
|
|
|
|
|
return wrapper |
|
|
|
for rank in range(n_gpus): |
|
|
|
procs[rank].join() |
|
|
|
code = procs[rank].exitcode |
|
|
|
assert code == 0, "subprocess {} exit with code {}".format(rank, code) |
|
|
|
|
|
|
|
return decorator |
|
|
|
return wrapper |