
refactor(mge/distributed): sync interpreter for distributed launcher

GitOrigin-RevId: 8a88c272a1
release-1.2
Megvii Engine Team, 4 years ago
commit 9d928e7f83
8 changed files with 199 additions and 312 deletions

  1. imperative/python/megengine/distributed/launcher.py (+69, -23)
  2. imperative/python/test/integration/test_dp_correctness.py (+0, -1)
  3. imperative/python/test/integration/test_param_pack.py (+0, -1)
  4. imperative/python/test/unit/autodiff/test_grad_manger.py (+0, -2)
  5. imperative/python/test/unit/core/test_autodiff.py (+25, -42)
  6. imperative/python/test/unit/functional/test_functional_distributed.py (+96, -218)
  7. imperative/python/test/unit/module/test_batchnorm.py (+9, -24)
  8. imperative/python/test/unit/quantization/test_observer.py (+0, -1)

imperative/python/megengine/distributed/launcher.py (+69, -23)

@@ -6,59 +6,105 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
 import multiprocessing as mp
 
-from .group import init_process_group
+from ..core._imperative_rt import sync
+from .group import group_barrier, init_process_group
 from .helper import get_device_count_by_fork
 from .server import Server
 from .util import get_free_ports
 
 
-def _run_wrapped(func, master_ip, port, world_size, rank, dev, args, kwargs):
+def _run_wrapped(
+    func, is_multimachine, master_ip, port, world_size, rank, dev, args, kwargs
+):
     """Init distributed process group and run wrapped function."""
     init_process_group(
         master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
     )
+    if is_multimachine:
+        group_barrier()
     func(*args, **kwargs)
+    sync()
+    if is_multimachine:
+        group_barrier()
 
 
-def launcher(func):
-    """Decorator for launching multiple processes in single-machine multi-gpu training."""
+class launcher:
+    """Decorator for launching multiple processes in single-machine multi-gpu training.
+    :param func: the function you want to launch in distributed mode.
+    :param n_gpus: how many devices each node.
+    :param world_size: how many devices totally.
+    :param rank_start: start number for rank.
+    :param master_ip: ip address for master node (where the rank 0 is).
+    :param port: server port for distributed server.
+    """
 
-    n_gpus = get_device_count_by_fork("gpu")
+    def __new__(cls, *args, **kwargs):
+        if not args:
+            return functools.partial(cls, **kwargs)
+        return super().__new__(cls)
 
-    def wrapper(*args, **kwargs):
-        master_ip = "localhost"
-        server = Server()
-        port = server.py_server_port
+    def __init__(
+        self,
+        func,
+        n_gpus=None,
+        world_size=None,
+        rank_start=0,
+        master_ip="localhost",
+        port=0,
+    ):
+        self.func = func
+        self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
+        self.world_size = world_size if world_size is not None else self.n_gpus
+        self.rank_start = rank_start
+        self.master_ip = master_ip
+        self.port = port
+        # master node create server
+        if self.rank_start == 0:
+            self.server = Server(self.port)
+            self.port = self.server.py_server_port
+        else:
+            assert self.port != 0, "you have to assign a port for distributed server"
+
+    def __call__(self, *args, **kwargs):
         procs = []
-        for rank in range(n_gpus):
+        for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
-                args=(func, master_ip, port, n_gpus, rank, rank, args, kwargs),
+                args=(
+                    self.func,
+                    self.world_size > self.n_gpus,
+                    self.master_ip,
+                    self.port,
+                    self.world_size,
+                    dev + self.rank_start,
+                    dev,
+                    args,
+                    kwargs,
+                ),
             )
             p.start()
             procs.append(p)
 
-        ranks = [rank for rank in range(n_gpus)]
+        devs = list(range(self.n_gpus))
 
-        while len(ranks) > 0:
+        while len(devs) > 0:
             left = []
             # check all processes in one second
-            time_to_wait = 1.0 / len(ranks)
-            for rank in ranks:
-                procs[rank].join(time_to_wait)
-                code = procs[rank].exitcode
+            time_to_wait = 1.0 / len(devs)
+            for dev in devs:
+                procs[dev].join(time_to_wait)
+                code = procs[dev].exitcode
                 # terminate processes if one of them has failed
                 if code != 0 and code != None:
-                    for i in ranks:
+                    for i in devs:
                         procs[i].terminate()
                 assert (
                     code == 0 or code == None
-                ), "subprocess {} exit with code {}".format(rank, code)
+                ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
                 if code == None:
-                    left.append(rank)
-            ranks = left
-
-    return wrapper
+                    left.append(dev)
+            devs = left
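The class-based launcher keeps the old bare-decorator form working while adding a parameterized form: __new__ returns a functools.partial when no positional argument is supplied, so @dist.launcher and @dist.launcher(n_gpus=2) both end up constructing the same object, and calling the decorated function spawns one subprocess per device and joins them. A minimal usage sketch, assuming a machine with two visible GPUs; the function name, worker body, and data are illustrative only, not part of the patch:

import numpy as np

import megengine.distributed as dist
from megengine import tensor
from megengine.functional.distributed import all_reduce_sum


@dist.launcher(n_gpus=2)  # bare @dist.launcher would use every GPU visible to the process
def step(data):
    rank = dist.get_rank()                        # rank assigned inside the spawned subprocess
    total = all_reduce_sum(tensor(data[rank]))    # collective across the launched processes
    assert np.allclose(total.numpy(), data[0] + data[1])


if __name__ == "__main__":
    data = (np.ones(4, dtype="float32"), 2 * np.ones(4, dtype="float32"))
    step(data)                                    # spawns one process per GPU and joins them all

For a multi-machine run, each node would pass the same world_size, master_ip, and port but its own rank_start, and the added group_barrier() calls keep the nodes aligned before and after the wrapped function.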

imperative/python/test/integration/test_dp_correctness.py (+0, -1)

@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import os
 import platform
 import re


imperative/python/test/integration/test_param_pack.py (+0, -1)

@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import platform
 
 import numpy as np


imperative/python/test/unit/autodiff/test_grad_manger.py (+0, -2)

@@ -17,7 +17,6 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
 from megengine.autodiff import GradManager
-from megengine.core._imperative_rt.imperative import sync
 from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
 
@@ -135,6 +134,5 @@ def test_remote_grad():
     for func in train_funcs:
         for i in range(3):
             func(x)
-            sync()
 
     worker()

imperative/python/test/unit/core/test_autodiff.py (+25, -42)

@@ -17,7 +17,6 @@ import megengine as mge
 import megengine.distributed as dist
 import megengine.functional as F
 from megengine.core._imperative_rt import TensorAttr, imperative
-from megengine.core._imperative_rt.imperative import sync
 from megengine.core.autodiff.grad import Grad
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.raw_tensor import as_raw_tensor
@@ -65,47 +64,31 @@ def save_to(self, name="grad"):
 def test_dist_grad():
     world_size = 2
     x_np = np.random.rand(10).astype("float32")
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker0():
-        dist.init_process_group("localhost", port, world_size, 0, 0)
-        mge.device.set_default_device("gpu0")
-        grad = Grad()
-
-        x = as_tensor(x_np)
-        grad.wrt(x, callback=save_to(x))
-        # need a placeholder to trace operator
-        send_x = remote_send(x, 1)
-        recv_x = remote_recv(1, x_np.shape, x_np.dtype, "gpu0")
-        y = recv_x * recv_x
-
-        grad([y], [as_tensor(np.ones_like(x_np))])
-        np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
-
-    def worker1():
-        dist.init_process_group("localhost", port, world_size, 1, 1)
-        mge.device.set_default_device("gpu1")
-        grad = Grad()
-
-        recv_x = remote_recv(0, x_np.shape, x_np.dtype, "gpu1")
-        send_x = remote_send(recv_x, 0)
-
-        grad([], [])
-
-        # sync because grad has a send operator
-        sync()
-        send_x.device._cn._sync_all()
-
-    import multiprocessing as mp
-
-    p0 = mp.Process(target=worker0)
-    p1 = mp.Process(target=worker1)
-    p0.start()
-    p1.start()
-    p0.join(10)
-    p1.join(10)
-    assert p0.exitcode == 0 and p1.exitcode == 0
+
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        if rank == 0:
+            grad = Grad()
+
+            x = as_tensor(x_np)
+            grad.wrt(x, callback=save_to(x))
+            # need a placeholder to trace operator
+            send_x = remote_send(x, 1)
+            recv_x = remote_recv(1, x_np.shape, x_np.dtype)
+            y = recv_x * recv_x
+
+            grad([y], [as_tensor(np.ones_like(x_np))])
+            np.testing.assert_almost_equal(x.grad.numpy(), x.numpy() * 2)
+        elif rank == 1:
+            grad = Grad()
+
+            recv_x = remote_recv(0, x_np.shape, x_np.dtype)
+            send_x = remote_send(recv_x, 0)
+
+            grad([], [])
+
+    worker()
 
 
 def test_grad():


imperative/python/test/unit/functional/test_functional_distributed.py (+96, -218)

@@ -6,7 +6,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import multiprocessing as mp
 import platform
 
 import numpy as np
@@ -16,6 +15,7 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
 from megengine.device import get_default_device, set_default_device
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -38,20 +38,16 @@ from megengine.functional.distributed import (
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_reduce_sum():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = reduce_sum(inp)
         if rank == 0:
-            assert np.allclose(output.numpy(), expect)
+            assert np.allclose(output.numpy(), expect[rank])
         else:
             assert np.allclose(output.numpy(), 0)
 
@@ -59,16 +55,9 @@ def test_reduce_sum():
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = x + y
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, None, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, None)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -80,33 +69,22 @@ def test_reduce_sum():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_broadcast():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = broadcast(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = x + 1
-        p0 = mp.Process(target=worker, args=(0, x, x, port))
-        p1 = mp.Process(target=worker, args=(1, y, x, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (x, x)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -118,34 +96,23 @@ def test_broadcast():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_gather():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_gather(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.concatenate((x, y))
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -157,34 +124,23 @@ def test_all_gather():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_reduce_scatter_sum():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = reduce_scatter_sum(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = x + y
-        p0 = mp.Process(target=worker, args=(0, x, z[: shape[0] // 2], port))
-        p1 = mp.Process(target=worker, args=(1, y, z[shape[0] // 2 :], port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z[: shape[0] // 2], z[shape[0] // 2 :])
+        worker(data, expect)
 
     for shape in [(2, 4), (8, 10), (88, 44)]:
         check(shape)
@@ -196,34 +152,23 @@ def test_reduce_scatter_sum():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_sum():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_sum(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = x + y
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -235,34 +180,23 @@ def test_all_reduce_sum():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_max():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_max(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.maximum(x, y)
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -274,34 +208,23 @@ def test_all_reduce_max():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_reduce_min():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_reduce_min(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         z = np.minimum(x, y)
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, z, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, z)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -313,20 +236,16 @@ def test_all_reduce_min():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_gather():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = gather(inp)
         if rank == 0:
-            assert np.allclose(output.numpy(), expect)
+            assert np.allclose(output.numpy(), expect[rank])
         else:
             assert np.allclose(output.numpy(), 0)
 
@@ -334,16 +253,9 @@ def test_gather():
         x = np.random.rand(*shape).astype("float32")
        y = np.random.rand(*shape).astype("float32")
         z = np.concatenate((x, y))
-        p0 = mp.Process(target=worker, args=(0, x, z, port))
-        p1 = mp.Process(target=worker, args=(1, y, None, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (z, None)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (99, 77)]:
         check(shape)
@@ -355,33 +267,22 @@ def test_gather():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_scatter():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = scatter(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = x + 1
-        p0 = mp.Process(target=worker, args=(0, x, x[: shape[0] // 2], port))
-        p1 = mp.Process(target=worker, args=(1, y, x[shape[0] // 2 :], port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (100, 77)]:
         check(shape)
@@ -393,35 +294,24 @@ def test_scatter():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_all_to_all():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
-
-    def worker(rank, data, expect, port):
-        if mge.get_device_count("gpu") < world_size:
-            return
-        dist.init_process_group("localhost", port, world_size, rank, rank)
-        inp = tensor(data)
+    @dist.launcher(n_gpus=2)
+    def worker(data, expect):
+        rank = dist.get_rank()
+        inp = tensor(data[rank])
         output = all_to_all(inp)
-        assert np.allclose(output.numpy(), expect)
+        assert np.allclose(output.numpy(), expect[rank])
 
     def check(shape):
         x = np.random.rand(*shape).astype("float32")
         y = np.random.rand(*shape).astype("float32")
         a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
         b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
-        p0 = mp.Process(target=worker, args=(0, x, a, port))
-        p1 = mp.Process(target=worker, args=(1, y, b, port))
-
-        p0.start()
-        p1.start()
-
-        p0.join(10)
-        p1.join(10)
-
-        assert p0.exitcode == 0 and p1.exitcode == 0
+        data = (x, y)
+        expect = (a, b)
+        worker(data, expect)
 
     for shape in [(2, 3), (8, 10), (100, 77)]:
         check(shape)
@@ -433,33 +323,21 @@ def test_all_to_all():
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_io_remote():
-    world_size = 2
-    server = dist.Server()
-    port = server.py_server_port
     val = np.random.rand(4, 5).astype(np.float32)
 
-    def worker(rank):
-        if mge.get_device_count("gpu") < world_size:
-            return
+    @dist.launcher(n_gpus=2)
+    def worker():
+        rank = dist.get_rank()
         if rank == 0:  # remote send
-            dist.init_process_group("localhost", port, world_size, rank, rank)
             x = Tensor(val, device="gpu0")
             y = remote_send(x, 1)
            assert y.numpy()[0] == 0
         else:  # remote recv
-            dist.init_process_group("localhost", port, world_size, rank, rank)
             y = remote_recv(0, val.shape, val.dtype)
             assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())
 
-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
+    worker()
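Every test in this file is converted the same way: the per-test Server/port plumbing and the mp.Process start/join/exitcode checks disappear, and the worker, decorated with @dist.launcher(n_gpus=2), picks its slice of the packed inputs through dist.get_rank(). A condensed sketch of the resulting pattern, using all_reduce_max as the example and assuming two GPUs are available (the helper name run_case is hypothetical):

import numpy as np

import megengine.distributed as dist
from megengine import tensor
from megengine.functional.distributed import all_reduce_max


def run_case(shape=(8, 10)):
    x = np.random.rand(*shape).astype("float32")
    y = np.random.rand(*shape).astype("float32")
    data = (x, y)                       # one input per rank
    expect = (np.maximum(x, y),) * 2    # every rank sees the same reduced value

    @dist.launcher(n_gpus=2)
    def worker(data, expect):
        rank = dist.get_rank()          # rank comes from the launcher, not from the test
        out = all_reduce_max(tensor(data[rank]))
        assert np.allclose(out.numpy(), expect[rank])

    worker(data, expect)                # one subprocess per GPU, joined before returning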

imperative/python/test/unit/module/test_batchnorm.py (+9, -24)

@@ -7,7 +7,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
-import multiprocessing as mp
 import platform
 
 import numpy as np
@@ -17,6 +16,7 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
 from megengine.core._trace_option import use_symbolic_shape
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 
 _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
@@ -28,6 +28,7 @@ _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_syncbn():
     nr_chan = 8
@@ -41,15 +42,14 @@ def test_syncbn():
     server = dist.Server()
     port = server.py_server_port
 
-    def worker(rank, data, yv_expect, running_mean, running_var):
-        if mge.get_device_count("gpu") < nr_ranks:
-            return
-        dist.init_process_group("localhost", port, nr_ranks, rank, rank)
+    @dist.launcher(n_gpus=2)
+    def worker(data, yv_expect, running_mean, running_var):
+        rank = dist.get_rank()
         bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
         for i in range(steps):
-            yv = bn(Tensor(data[i]))
+            yv = bn(Tensor(data[rank][i]))
 
-        _assert_allclose(yv.numpy(), yv_expect)
+        _assert_allclose(yv.numpy(), yv_expect[rank])
         _assert_allclose(bn.running_mean.numpy(), running_mean)
         _assert_allclose(bn.running_var.numpy(), running_var)
 
@@ -77,24 +77,9 @@ def test_syncbn():
         for j in range(steps):
             data[i].append(xv[j][:, :, :, i * 8 : i * 8 + 8])
 
-    procs = []
-    for rank in range(nr_ranks):
-        p = mp.Process(
-            target=worker,
-            args=(
-                rank,
-                data[rank],
-                yv_expect[:, :, :, rank * 8 : rank * 8 + 8],
-                running_mean,
-                running_var,
-            ),
-        )
-        p.start()
-        procs.append(p)
+    yv_expect = [yv_expect[:, :, :, i * 8 : i * 8 + 8] for i in range(nr_ranks)]
 
-    for p in procs:
-        p.join(10)
-        assert p.exitcode == 0
+    worker(data, yv_expect, running_mean, running_var)
 
 
 def test_batchnorm():


imperative/python/test/unit/quantization/test_observer.py (+0, -1)

@@ -1,4 +1,3 @@
-import multiprocessing as mp
 import platform
 
 import numpy as np

