diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py
index 6cbe1ec2..f90683bd 100644
--- a/imperative/python/megengine/distributed/launcher.py
+++ b/imperative/python/megengine/distributed/launcher.py
@@ -6,59 +6,105 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
 import multiprocessing as mp
 
-from .group import init_process_group
+from ..core._imperative_rt import sync
+from .group import group_barrier, init_process_group
 from .helper import get_device_count_by_fork
 from .server import Server
 from .util import get_free_ports
 
 
-def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
+def _run_wrapped(
+    func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs
+):
     """Init distributed process group and run wrapped function."""
     init_process_group(
         master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
    )
+    if is_multimachine:
+        group_barrier()
     func(*args, **kwargs)
+    sync()
+    if is_multimachine:
+        group_barrier()
 
 
-def launcher(func):
-    """Decorator for launching multiple processes in single-machine multi-gpu training."""
+class launcher:
+    """Decorator for launching multiple processes in single-machine multi-gpu training.
+
+    :param func: the function you want to launch in distributed mode.
+    :param n_gpus: how many devices on each node.
+    :param world_size: how many devices in total.
+    :param rank_start: starting rank number for this node.
+    :param master_ip: IP address of the master node (where rank 0 runs).
+    :param port: server port for the distributed server.
+    """
 
-    n_gpus = get_device_count_by_fork("gpu")
+    def __new__(cls, *args, **kwargs):
+        if not args:
+            return functools.partial(cls, **kwargs)
+        return super().__new__(cls)
 
-    def wrapper(*args, **kwargs):
-        master_ip = "localhost"
-        server = Server()
-        port = server.py_server_port
+    def __init__(
+        self,
+        func,
+        n_gpus=None,
+        world_size=None,
+        rank_start=0,
+        master_ip="localhost",
+        port=0,
+    ):
+        self.func = func
+        self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
+        self.world_size = world_size if world_size is not None else self.n_gpus
+        self.rank_start = rank_start
+        self.master_ip = master_ip
+        self.port = port
+        # the master node creates the distributed server
+        if self.rank_start == 0:
+            self.server = Server(self.port)
+            self.port = self.server.py_server_port
+        else:
+            assert self.port != 0, "you have to assign a port for distributed server"
 
+    def __call__(self, *args, **kwargs):
         procs = []
-        for rank in range(n_gpus):
+        for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
-                args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs),
+                args=(
+                    self.func,
+                    self.world_size > self.n_gpus,
+                    self.master_ip,
+                    self.port,
+                    self.world_size,
+                    dev + self.rank_start,
+                    dev,
+                    args,
+                    kwargs,
+                ),
             )
             p.start()
             procs.append(p)
 
-        ranks = [rank for rank in range(n_gpus)]
+        devs = list(range(self.n_gpus))
 
-        while len(ranks) > 0:
+        while len(devs) > 0:
             left = []
             # check all processes in one second
-            time_to_wait = 1.0 / len(ranks)
-            for rank in ranks:
-                procs[rank].join(time_to_wait)
-                code = procs[rank].exitcode
+            time_to_wait = 1.0 / len(devs)
+            for dev in devs:
+                procs[dev].join(time_to_wait)
+                code = procs[dev].exitcode
                 # terminate processes if one of them has failed
                 if code != 0 and code != None:
-                    for i in ranks:
+                    for i in devs:
                         procs[i].terminate()
                 assert (
                     code == 0 or code == None
-                ), "subprocess {} exit with code {}".format(rank, code)
+                ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
                 if code == None:
-                    left.append(rank)
-            ranks = left
-
-    return wrapper
+                    left.append(dev)
+            devs = left
diff --git a/imperative/python/test/integration/test_dp_correctness.py b/imperative/python/test/integration/test_dp_correctness.py
index bf0b6206..7e39c296 100644
--- a/imperative/python/test/integration/test_dp_correctness.py
+++ b/imperative/python/test/integration/test_dp_correctness.py
@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import os
 import platform
 import re
diff --git a/imperative/python/test/integration/test_param_pack.py b/imperative/python/test/integration/test_param_pack.py
index 35ac665c..44f1afed 100644
--- a/imperative/python/test/integration/test_param_pack.py
+++ b/imperative/python/test/integration/test_param_pack.py
@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import platform
 
 import numpy as np
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index e7c59fab..dd11f3ad 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -17,7 +17,6 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
 from megengine.autodiff import GradManager
-from megengine.core._imperative_rt.imperative import sync
 from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
 
@@ -135,6 +134,5 @@ def test_remote_grad():
         for func in train_funcs:
             for i in range(3):
                 func(x)
-        sync()
 
     worker()
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index c797f6fd..0c6693f0 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -17,7 +17,6 @@ import megengine as mge
 import megengine.distributed as dist
 import megengine.functional as F
 from megengine.core._imperative_rt import TensorAttr, imperative
-from megengine.core._imperative_rt.imperative import sync
 from megengine.core.autodiff.grad import Grad
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.raw_tensor import as_raw_tensor
@@ -65,47 +64,31 @@ def save_to(self, name="grad"):
 def test_dist_grad():
     world_size = 2
     x_np = np.random.rand(10).astype("float32")
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker0():
-        dist.init_process_group("localhost", port, world_size, 0, 0)
-        mge.device.set_default_device("gpu0")
-        grad = Grad()
-
-        x = as_tensor(x_np)
-        grad.wrt(x, callback=save_to(x))
-        # need a placeholder to trace operator
-        send_x = remote_send(x, 1)
-        recv_x = remote_recv(1, x_np.shape, x_np.dtype, "gpu0")
-        y = recv_x * recv_x
-
-        grad([y], [as_tensor(np.ones_like(x_np))])
-        np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
-
-    def worker1():
-        dist.init_process_group("localhost", port, world_size, 1, 1)
-        mge.device.set_default_device("gpu1")
-        grad = Grad()
-
-        recv_x = remote_recv(0, x_np.shape, x_np.dtype, "gpu1")
-        send_x = remote_send(recv_x, 0)
-
-        grad([], [])
-
-        # sync because grad has a send operator
-        sync()
-        send_x.device._cn._sync_all()
-
-    import multiprocessing as mp
-
-    p0 = mp.Process(target=worker0)
-    p1 = mp.Process(target=worker1)
-    p0.start()
-    p1.start()
-    p0.join(10)
-    p1.join(10)
-    assert p0.exitcode == 0 and p1.exitcode == 0
+
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        if rank == 0:
+            grad = Grad()
+
+            x = as_tensor(x_np)
+            grad.wrt(x, callback=save_to(x))
+            # need a placeholder to trace operator
+            send_x = remote_send(x, 1)
+            recv_x = remote_recv(1, x_np.shape, x_np.dtype)
+            y = recv_x * recv_x
+
+            grad([y], [as_tensor(np.ones_like(x_np))])
+            np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
+        elif rank == 1:
+            grad = Grad()
+
+            recv_x = remote_recv(0, x_np.shape, x_np.dtype)
+            send_x = remote_send(recv_x, 0)
+
+            grad([], [])
+
+    worker()
 
 
 def test_grad():
diff --git a/imperative/python/test/unit/functional/test_functional_distributed.py b/imperative/python/test/unit/functional/test_functional_distributed.py
index b238de53..92aa3b17 100644
--- a/imperative/python/test/unit/functional/test_functional_distributed.py
+++ b/imperative/python/test/unit/functional/test_functional_distributed.py
@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import platform
 
 import numpy as np
@@ -16,6 +15,7 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
 from megengine.device import get_default_device, set_default_device
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -38,20 +38,16 @@ from megengine.functional.distributed import (
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_reduce_sum():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = reduce_sum(inp)
         if rank == 0:
-            assert np.allclose(output.numpy(), expect)
+            assert np.allclose(output.numpy(), expect[rank])
         else:
             assert np.allclose(output.numpy(), 0)
 
@@ -59,16 +55,9 @@ def test_reduce_sum():
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = x + y
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, None, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, None)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -80,33 +69,22 @@ def test_reduce_sum():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
device") @pytest.mark.isolated_distributed def test_broadcast(): - world_size = 2 - server = dist.Server() - port = server.py_server_port - - def worker(rank, data, expect, port): - if mge.get_device_count("gpu") < world_size: - return - dist.init_process_group("localhost", port, world_size, rank, rank) - inp = tensor(data) + @dist.launcher(n_gpus=2) + def worker(data, expect): + rank = dist.get_rank() + inp = tensor(data[rank]) output = broadcast(inp) - assert np.allclose(output.numpy(), expect) + assert np.allclose(output.numpy(), expect[rank]) def check(shape): x = np.random.rand(*shape).astype("float32") y = x + 1 - p0 = mp.Process(target=worker, args=(0, x, x, port)) - p1 = mp.Process(target=worker, args=(1, y, x, port)) - - p0.start() - p1.start() - - p0.join(10) - p1.join(10) - - assert p0.exitcode == 0 and p1.exitcode == 0 + data = (x, y) + expect = (x, x) + worker(data, expect) for shape in [(2, 3), (8, 10), (99, 77)]: check(shape) @@ -118,34 +96,23 @@ def test_broadcast(): @pytest.mark.skipif( platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" ) +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") @pytest.mark.isolated_distributed def test_all_gather(): - world_size = 2 - server = dist.Server() - port = server.py_server_port - - def worker(rank, data, expect, port): - if mge.get_device_count("gpu") < world_size: - return - dist.init_process_group("localhost", port, world_size, rank, rank) - inp = tensor(data) + @dist.launcher(n_gpus=2) + def worker(data, expect): + rank = dist.get_rank() + inp = tensor(data[rank]) output = all_gather(inp) - assert np.allclose(output.numpy(), expect) + assert np.allclose(output.numpy(), expect[rank]) def check(shape): x = np.random.rand(*shape).astype("float32") y = np.random.rand(*shape).astype("float32") z = np.concatenate((x, y)) - p0 = mp.Process(target=worker, args=(0, x, z, port)) - p1 = mp.Process(target=worker, args=(1, y, z, port)) - - p0.start() - p1.start() - - p0.join(10) - p1.join(10) - - assert p0.exitcode == 0 and p1.exitcode == 0 + data = (x, y) + expect = (z, z) + worker(data, expect) for shape in [(2, 3), (8, 10), (99, 77)]: check(shape) @@ -157,34 +124,23 @@ def test_all_gather(): @pytest.mark.skipif( platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" ) +@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") @pytest.mark.isolated_distributed def test_reduce_scatter_sum(): - world_size = 2 - server = dist.Server() - port = server.py_server_port - - def worker(rank, data, expect, port): - if mge.get_device_count("gpu") < world_size: - return - dist.init_process_group("localhost", port, world_size, rank, rank) - inp = tensor(data) + @dist.launcher(n_gpus=2) + def worker(data, expect): + rank = dist.get_rank() + inp = tensor(data[rank]) output = reduce_scatter_sum(inp) - assert np.allclose(output.numpy(), expect) + assert np.allclose(output.numpy(), expect[rank]) def check(shape): x = np.random.rand(*shape).astype("float32") y = np.random.rand(*shape).astype("float32") z = x + y - p0 = mp.Process(target=worker, args=(0, x, z[: shape[0] // 2], port)) - p1 = mp.Process(target=worker, args=(1, y, z[shape[0] // 2 :], port)) - - p0.start() - p1.start() - - p0.join(10) - p1.join(10) - - assert p0.exitcode == 0 and p1.exitcode == 0 + data = (x, y) + expect = (z[: shape[0] // 2], z[shape[0] // 2 :]) + worker(data, expect) for shape in [(2, 4), (8, 10), (88, 44)]: check(shape) @@ -196,34 +152,23 @@ def 
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_sum():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_sum(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = x + y
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -235,34 +180,23 @@ def test_all_reduce_sum():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_max():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_max(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.maximum(x, y)
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -274,34 +208,23 @@ def test_all_reduce_max():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_min():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_min(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.minimum(x, y)
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -313,20 +236,16 @@ def test_all_reduce_min():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_gather():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = gather(inp)
         if rank == 0:
-            assert np.allclose(output.numpy(), expect)
+            assert np.allclose(output.numpy(), expect[rank])
         else:
             assert np.allclose(output.numpy(), 0)
 
@@ -334,16 +253,9 @@ def test_gather():
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.concatenate((x, y))
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, None, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, None)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -355,33 +267,22 @@ def test_gather():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_scatter():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = scatter(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = x + 1
-        p0 = mp.Process(target=worker, args=(0, x, x[: shape[0] // 2], port))
-        p1 = mp.Process(target=worker, args=(1, y, x[shape[0] // 2 :], port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (100, 77)]:
         check(shape)
@@ -393,35 +294,24 @@ def test_scatter():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_to_all():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_to_all(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
         b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
-        p0 = mp.Process(target=worker, args=(0, x, a, port))
-        p1 = mp.Process(target=worker, args=(1, y, b, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (a, b)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (100, 77)]:
         check(shape)
@@ -433,33 +323,21 @@ def test_all_to_all():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_io_remote():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
     val = np.random.rand(4, 5).astype(np.float32)
 
-    def worker(rank):
-        if mge.get_device_count("gpu") < world_size:
-            return
+    @dist.launcher(n_gpus=2)
+    def worker():
+        rank = dist.get_rank()
         if rank == 0:  # remote send
-            dist.init_process_group("localhost", port, world_size, rank, rank)
             x = Tensor(val, device="gpu0")
             y = remote_send(x, 1)
             assert y.numpy()[0] == 0
         else:  # remote recv
-            dist.init_process_group("localhost", port, world_size, rank, rank)
             y = remote_recv(0, val.shape, val.dtype)
             assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())
 
-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
+    worker()
diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py
index e9ef3131..ba440820 100644
--- a/imperative/python/test/unit/module/test_batchnorm.py
+++ b/imperative/python/test/unit/module/test_batchnorm.py
@@ -7,7 +7,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
-import multiprocessing as mp
 import platform
 
 import numpy as np
@@ -17,6 +16,7 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
 from megengine.core._trace_option import use_symbolic_shape
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 
 _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
@@ -28,6 +28,7 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_syncbn():
     nr_chan = 8
@@ -41,15 +42,14 @@ def test_syncbn():
     server = dist.Server()
     port = server.py_server_port
 
-    def worker(rank, data, yv_expect, running_mean, running_var):
-        if mge.get_device_count("gpu") < nr_ranks:
-            return
-        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
+    @dist.launcher(n_gpus=2)
+    def worker(data, yv_expect, running_mean, running_var):
+        rank = dist.get_rank()
         bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
         for i in range(steps):
-            yv = bn(Tensor(data[i]))
+            yv = bn(Tensor(data[rank][i]))
 
-        _assert_allclose(yv.numpy(), yv_expect)
+        _assert_allclose(yv.numpy(), yv_expect[rank])
         _assert_allclose(bn.running_mean.numpy(), running_mean)
         _assert_allclose(bn.running_var.numpy(), running_var)
 
@@ -77,24 +77,9 @@ def test_syncbn():
         for j in range(steps):
             data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
 
-    procs = []
-    for rank in range(nr_ranks):
-        p = mp.Process(
-            target=worker,
-            args=(
-                rank,
-                data[rank],
-                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
-                running_mean,
-                running_var,
-            ),
-        )
-        p.start()
-        procs.append(p)
+    yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]
 
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
+    worker(data, yv_expect, running_mean, running_var)
 
 
 def test_batchnorm():
diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py
index aa962264..e1a0a091 100644
--- a/imperative/python/test/unit/quantization/test_observer.py
+++ b/imperative/python/test/unit/quantization/test_observer.py
@@ -1,4 +1,3 @@
-import multiprocessing as mp
 import platform
 
 import numpy as np
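
Usage note (reviewer sketch, not part of the patch): with this change, `dist.launcher` works both as a bare decorator and with arguments. A minimal sketch, assuming a hypothetical `main` function and placeholder address/port values; on a single machine the local GPU count is picked up automatically, while a multi-machine run passes `world_size`, `rank_start`, and the shared `master_ip`/`port` explicitly:

    import megengine.distributed as dist

    # Single machine: one worker process per local GPU; rank 0 hosts the server.
    @dist.launcher
    def main():
        print("hello from rank", dist.get_rank())

    main()

    # Multi-machine sketch for the second of two 8-GPU nodes: world_size covers all
    # devices, rank_start offsets this node's ranks, and master_ip/port must point at
    # the node running rank 0 (the values below are placeholders).
    @dist.launcher(
        n_gpus=8, world_size=16, rank_start=8, master_ip="10.0.0.1", port=23456
    )
    def main_node1():
        print("hello from rank", dist.get_rank())

    main_node1()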