From 3ecded74ea3b500d0bc95e669ef0a98e6b1193e4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 25 Nov 2020 19:28:11 +0800
Subject: [PATCH] refactor(distributed/server): use port 0 to get available port

GitOrigin-RevId: e367846b9216ea6d5ef7ada698ffba790ba8e1e6
---
 imperative/python/megengine/distributed/launcher.py | 21 +++++++++++++++------
 imperative/python/megengine/distributed/server.py   | 14 ++++++++------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py
index a6c7c05a..6cf39fa0 100644
--- a/imperative/python/megengine/distributed/launcher.py
+++ b/imperative/python/megengine/distributed/launcher.py
@@ -29,8 +29,8 @@ def launcher(func):
 
     def wrapper(*args, **kwargs):
         master_ip = "localhost"
-        port = get_free_ports(1)[0]
-        server = Server(port)
+        server = Server()
+        port = server.py_server_port
 
         procs = []
         for rank in range(n_gpus):
@@ -41,9 +41,18 @@ def launcher(func):
             p.start()
             procs.append(p)
 
-        for rank in range(n_gpus):
-            procs[rank].join()
-            code = procs[rank].exitcode
-            assert code == 0, "subprocess {} exit with code {}".format(rank, code)
+        ranks = [rank for rank in range(n_gpus)]
+
+        while len(ranks) > 0:
+            left = []
+            for rank in ranks:
+                procs[rank].join(1)
+                code = procs[rank].exitcode
+                assert (
+                    code == 0 or code == None
+                ), "subprocess {} exit with code {}".format(rank, code)
+                if code == None:
+                    left.append(rank)
+            ranks = left
 
     return wrapper
diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py
index c9ab3177..6ab00a8b 100644
--- a/imperative/python/megengine/distributed/server.py
+++ b/imperative/python/megengine/distributed/server.py
@@ -10,6 +10,7 @@ import threading
 import time
 from collections import defaultdict
 from functools import partial
+from queue import Queue
 from socketserver import ThreadingMixIn
 from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer
@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass
 
 
-def start_server(py_server_port, mm_server_port):
+def start_server(py_server_port, mm_server_port, queue):
     """
     Start python distributed server and multiple machine server.
 
@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port):
     """
     server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
     server.register_instance(Methods(mm_server_port))
+    _, port = server.server_address
+    queue.put(port)
     server.serve_forever()
 
 
@@ -152,15 +155,14 @@ class Server:
     :param port: python server port.
     """
 
-    def __init__(self, port):
-        self.py_server_port = get_free_ports(1)[0] if port == 0 else port
+    def __init__(self, port=0):
         self.mm_server_port = create_mm_server("0.0.0.0", 0)
+        q = Queue()
         self.proc = threading.Thread(
-            target=start_server,
-            args=(self.py_server_port, self.mm_server_port),
-            daemon=True,
+            target=start_server, args=(port, self.mm_server_port, q), daemon=True,
         )
         self.proc.start()
+        self.py_server_port = q.get()
 
 
 class Client: