|
|
@@ -8,15 +8,30 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import functools |
|
|
|
import multiprocessing as mp |
|
|
|
import queue |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import sync |
|
|
|
from ..logger import get_logger |
|
|
|
from .group import group_barrier, init_process_group |
|
|
|
from .helper import get_device_count_by_fork |
|
|
|
from .server import Client, Server |
|
|
|
|
|
|
|
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = ( |
|
|
|
"subprocess exited with code 0 but did not return a value" |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _run_wrapped( |
|
|
|
func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs |
|
|
|
func, |
|
|
|
is_multimachine, |
|
|
|
master_ip, |
|
|
|
port, |
|
|
|
world_size, |
|
|
|
rank, |
|
|
|
dev, |
|
|
|
args, |
|
|
|
kwargs, |
|
|
|
queue: mp.Queue, |
|
|
|
): |
|
|
|
"""Init distributed process group and run wrapped function.""" |
|
|
|
init_process_group( |
|
|
@@ -24,7 +39,8 @@ def _run_wrapped( |
|
|
|
) |
|
|
|
if is_multimachine: |
|
|
|
group_barrier() |
|
|
|
func(*args, **kwargs) |
|
|
|
ret = func(*args, **kwargs) |
|
|
|
queue.put((dev, ret)) |
|
|
|
sync() |
|
|
|
if is_multimachine: |
|
|
|
group_barrier() |
|
|
@@ -70,6 +86,8 @@ class launcher: |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
procs = [] |
|
|
|
queue = mp.Queue(self.n_gpus) |
|
|
|
results = [None] * self.n_gpus |
|
|
|
for dev in range(self.n_gpus): |
|
|
|
p = mp.Process( |
|
|
|
target=_run_wrapped, |
|
|
@@ -83,6 +101,7 @@ class launcher: |
|
|
|
dev, |
|
|
|
args, |
|
|
|
kwargs, |
|
|
|
queue, |
|
|
|
), |
|
|
|
) |
|
|
|
p.start() |
|
|
@@ -90,6 +109,11 @@ class launcher: |
|
|
|
|
|
|
|
devs = list(range(self.n_gpus)) |
|
|
|
|
|
|
|
def terminate(): |
|
|
|
for dev in devs: |
|
|
|
procs[dev].terminate() |
|
|
|
devs.clear() |
|
|
|
|
|
|
|
while len(devs) > 0: |
|
|
|
left = [] |
|
|
|
# check all processes in one second |
|
|
@@ -99,11 +123,17 @@ class launcher: |
|
|
|
code = procs[dev].exitcode |
|
|
|
# terminate processes if one of them has failed |
|
|
|
if code != 0 and code != None: |
|
|
|
for i in devs: |
|
|
|
procs[i].terminate() |
|
|
|
terminate() |
|
|
|
assert ( |
|
|
|
code == 0 or code == None |
|
|
|
), "subprocess {} exit with code {}".format(dev + self.rank_start, code) |
|
|
|
if code == None: |
|
|
|
left.append(dev) |
|
|
|
elif queue.empty(): |
|
|
|
get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN) |
|
|
|
else: |
|
|
|
dev, ret = queue.get_nowait() |
|
|
|
results[dev] = ret |
|
|
|
devs = left |
|
|
|
|
|
|
|
return results |