GitOrigin-RevId: 085fd1dcfd
release-1.2
@@ -45,9 +45,15 @@ def launcher(func):
         while len(ranks) > 0:
             left = []
             # check all processes in one second
+            time_to_wait = 1.0 / len(ranks)
             for rank in ranks:
-                procs[rank].join(1)
+                procs[rank].join(time_to_wait)
                 code = procs[rank].exitcode
+                # terminate processes if one of them has failed
+                if code != 0 and code != None:
+                    for i in ranks:
+                        procs[i].terminate()
                 assert (
                     code == 0 or code == None
                 ), "subprocess {} exit with code {}".format(rank, code)
@@ -133,18 +133,22 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass


-def start_server(py_server_port, mm_server_port, queue):
+def _start_server(py_server_port, mm_server_port, queue):
     """
     Start python distributed server and multiple machine server.

     :param py_server_port: python server port.
     :param mm_server_port: multiple machine server port.
     :param queue: server port will put in this queue, puts exception when process fails.
     """
-    server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
-    server.register_instance(Methods(mm_server_port))
-    _, port = server.server_address
-    queue.put(port)
-    server.serve_forever()
+    try:
+        server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
+        server.register_instance(Methods(mm_server_port))
+        _, port = server.server_address
+        queue.put(port)
+        server.serve_forever()
+    except Exception as e:
+        queue.put(e)


 class Server:
@@ -159,10 +163,14 @@ class Server:
         self.mm_server_port = create_mm_server("0.0.0.0", 0)
         q = Queue()
         self.proc = threading.Thread(
-            target=start_server, args=(port, self.mm_server_port, q), daemon=True,
+            target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
         )
         self.proc.start()
-        self.py_server_port = q.get()
+        ret = q.get()
+        if isinstance(ret, Exception):
+            raise ret
+        else:
+            self.py_server_port = ret


 class Client:
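The two hunks above pair up: the serving thread now puts either the bound port or the exception it hit onto the queue, and Server.__init__ re-raises it instead of blocking forever on q.get() when startup fails. A minimal, MegEngine-free sketch of that hand-off (the raw socket below is only a stand-in for ThreadXMLRPCServer):

    import socket
    import threading
    from queue import Queue

    def _serve(requested_port, q):
        try:
            sock = socket.socket()
            sock.bind(("0.0.0.0", requested_port))   # port 0 lets the OS pick a free port
            sock.listen()
            q.put(sock.getsockname()[1])             # report the port actually bound
            sock.accept()                            # stand-in for server.serve_forever()
        except Exception as e:
            q.put(e)                                 # surface startup failures to the parent

    q = Queue()
    threading.Thread(target=_serve, args=(0, q), daemon=True).start()
    ret = q.get()
    if isinstance(ret, Exception):
        raise ret                                    # fail loudly in the constructor
    py_server_port = ret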
@@ -159,11 +159,9 @@ def run_test(
     checkpoint = mge.load(model_path)
     data = checkpoint["data"]
     label = checkpoint["label"]

-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank, max_err):
-        dist.init_process_group("localhost", port, p_num, rank, rank)
+    @dist.launcher
+    def worker(max_err):
         net = MnistNet(has_bn=True)
         net.load_state_dict(checkpoint["net_init"])
         lr = checkpoint["sgd_lr"]
@@ -194,15 +192,7 @@ def run_test(
         else:
             np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)

-    procs = []
-    for rank in range(p_num):
-        p = mp.Process(target=worker, args=(rank, max_err,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker(max_err)


 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
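From here on the tests drop the manual mp.Process fork/join loop: the worker body is decorated with @dist.launcher and called exactly once in the parent, and the launcher (see the first hunk) forks one subprocess per device, sets up the process group, and asserts that every rank exits cleanly. The resulting test shape, roughly (values and the worker body are illustrative):

    import megengine.distributed as dist

    @dist.launcher
    def worker(max_err):
        # runs once per GPU; rank comes from the launcher-initialized group
        rank = dist.get_rank()
        assert max_err >= 0

    worker(1e-5)   # a single call replaces the explicit mp.Process loop and joins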
@@ -23,6 +23,7 @@ from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.raw_tensor import as_raw_tensor
 from megengine.core.tensor.tensor import Tensor, apply
 from megengine.core.tensor.tensor_wrapper import TensorWrapper
+from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import remote_recv, remote_send
@@ -53,15 +54,19 @@ def save_to(self, name="grad"):
     return callback


-@pytest.mark.isolated_distributed
 @pytest.mark.skipif(
     platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
 )
 @pytest.mark.skipif(
     platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
 )
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
 def test_dist_grad():
     world_size = 2
     x_np = np.random.rand(10).astype("float32")
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker0():
         dist.init_process_group("localhost", port, world_size, 0, 0)
@@ -47,8 +47,8 @@ def _assert_q_val(q, val):
 @pytest.mark.isolated_distributed
 def test_init_process_group():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, backend):
         dist.init_process_group("localhost", port, world_size, rank, rank, backend)
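The same two-line substitution repeats in every test below: rather than asking dist.get_free_ports() for a port and hoping it is still free when dist.Server(port) later binds it, the server is constructed without an explicit port (port 0) and the port it actually bound is read back from py_server_port. Typical usage, with illustrative rank/world_size values:

    import megengine.distributed as dist

    world_size, rank = 2, 0                  # illustrative values
    server = dist.Server()                   # binds the python server on an OS-chosen free port
    port = server.py_server_port             # the port that was actually bound
    dist.init_process_group("localhost", port, world_size, rank, rank)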
@@ -92,11 +92,10 @@ def test_init_process_group():
 def test_new_group():
     world_size = 3
     ranks = [2, 0]
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
         if rank in ranks:
             group = dist.new_group(ranks)
             assert group.size == 2
@@ -104,15 +103,7 @@ def test_new_group():
             assert group.rank == ranks.index(rank)
             assert group.comp_node == "gpu{}:2".format(rank)

-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()


 @pytest.mark.skipif(
@@ -125,8 +116,8 @@ def test_new_group():
 @pytest.mark.isolated_distributed
 def test_group_barrier():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, q):
         dist.init_process_group("localhost", port, world_size, rank, rank)
@@ -161,8 +152,8 @@ def test_group_barrier():
 @pytest.mark.isolated_distributed
 def test_synchronized():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     @dist.synchronized
     def func(rank, q):
@@ -205,26 +196,16 @@ def test_synchronized():
 @pytest.mark.isolated_distributed
 def test_user_set_get():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
         # set in race condition
         dist.get_client().user_set("foo", 1)
         # get in race condition
         ret = dist.get_client().user_get("foo")
         assert ret == 1

-    procs = []
-    for rank in range(world_size):
-        p = mp.Process(target=worker, args=(rank,))
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()


 def test_oprmm_hashable():
@@ -41,8 +41,8 @@ from megengine.functional.distributed import (
 @pytest.mark.isolated_distributed
 def test_reduce_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -83,8 +83,8 @@ def test_reduce_sum():
 @pytest.mark.isolated_distributed
 def test_broadcast():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -121,8 +121,8 @@ def test_broadcast():
 @pytest.mark.isolated_distributed
 def test_all_gather():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -160,8 +160,8 @@ def test_all_gather():
 @pytest.mark.isolated_distributed
 def test_reduce_scatter_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -199,8 +199,8 @@ def test_reduce_scatter_sum():
 @pytest.mark.isolated_distributed
 def test_all_reduce_sum():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -238,8 +238,8 @@ def test_all_reduce_sum():
 @pytest.mark.isolated_distributed
 def test_all_reduce_max():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -277,8 +277,8 @@ def test_all_reduce_max():
 @pytest.mark.isolated_distributed
 def test_all_reduce_min():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -316,8 +316,8 @@ def test_all_reduce_min():
 @pytest.mark.isolated_distributed
 def test_gather():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -358,8 +358,8 @@ def test_gather():
 @pytest.mark.isolated_distributed
 def test_scatter():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -396,8 +396,8 @@ def test_scatter():
 @pytest.mark.isolated_distributed
 def test_all_to_all():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port

     def worker(rank, data, expect, port):
         if mge.get_device_count("gpu") < world_size:
@@ -436,8 +436,8 @@ def test_all_to_all():
 @pytest.mark.isolated_distributed
 def test_io_remote():
     world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)
+    server = dist.Server()
+    port = server.py_server_port
     val = np.random.rand(4, 5).astype(np.float32)

     def worker(rank):
@@ -38,7 +38,7 @@ def test_syncbn():
     running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
     steps = 4
     nr_ranks = 2
-    server = dist.Server(0)
+    server = dist.Server()
     port = server.py_server_port

     def worker(rank, data, yv_expect, running_mean, running_var):
@@ -28,25 +28,16 @@ def test_min_max_observer():
 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_sync_min_max_observer():
-    x = np.random.rand(6, 3, 3, 3).astype("float32")
+    word_size = get_device_count_by_fork("gpu")
+    x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
     np_min, np_max = x.min(), x.max()
-    world_size = 2
-    port = dist.get_free_ports(1)[0]
-    server = dist.Server(port)

-    def worker(rank, slc):
-        dist.init_process_group("localhost", port, world_size, rank, rank)
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
         m = ob.SyncMinMaxObserver()
-        y = mge.tensor(x[slc])
+        y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
         m(y)
         assert m.min_val == np_min and m.max_val == np_max

-    procs = []
-    for rank in range(world_size):
-        slc = slice(rank * 3, (rank + 1) * 3)
-        p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
-        p.start()
-        procs.append(p)
-
-    for p in procs:
-        p.join(20)
-        assert p.exitcode == 0
+    worker()
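Note: the observer test now sizes its input by the number of visible GPUs and lets each rank observe its own 3-sample slice, so the union of what the ranks see is exactly x and the synchronized min/max must equal x.min()/x.max(). The slicing logic in isolation (plain numpy; word_size stands in for get_device_count_by_fork("gpu")):

    import numpy as np

    word_size = 4                                    # stand-in for get_device_count_by_fork("gpu")
    x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")

    def shard(rank):
        # each rank takes a disjoint 3-sample chunk; together the chunks cover x
        return x[rank * 3 : (rank + 1) * 3]

    total = sum(shard(r).shape[0] for r in range(word_size))
    assert total == x.shape[0]                       # nothing is dropped or duplicated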