Browse Source

refactor(distributed/server): use port 0 to get available port

GitOrigin-RevId: e367846b92
release-1.2
Megvii Engine Team 4 years ago
parent
commit
3ecded74ea
2 changed files with 23 additions and 12 deletions
  1. imperative/python/megengine/distributed/launcher.py (+15, -6)
  2. imperative/python/megengine/distributed/server.py (+8, -6)

+ 15
- 6
imperative/python/megengine/distributed/launcher.py View File

@@ -29,8 +29,8 @@ def launcher(func):


     def wrapper(*args, **kwargs):
         master_ip = "localhost"
-        port = get_free_ports(1)[0]
-        server = Server(port)
+        server = Server()
+        port = server.py_server_port
 
         procs = []
         for rank in range(n_gpus):
@@ -41,9 +41,18 @@ def launcher(func):
             p.start()
             procs.append(p)
 
-        for rank in range(n_gpus):
-            procs[rank].join()
-            code = procs[rank].exitcode
-            assert code == 0, "subprocess {} exit with code {}".format(rank, code)
+        ranks = [rank for rank in range(n_gpus)]
+
+        while len(ranks) > 0:
+            left = []
+            for rank in ranks:
+                procs[rank].join(1)
+                code = procs[rank].exitcode
+                assert (
+                    code == 0 or code == None
+                ), "subprocess {} exit with code {}".format(rank, code)
+                if code == None:
+                    left.append(rank)
+            ranks = left
 
     return wrapper

+ 8
- 6
imperative/python/megengine/distributed/server.py View File

@@ -10,6 +10,7 @@ import threading
 import time
 from collections import defaultdict
 from functools import partial
+from queue import Queue
 from socketserver import ThreadingMixIn
 from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer
@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
     pass
 
 
-def start_server(py_server_port, mm_server_port):
+def start_server(py_server_port, mm_server_port, queue):
     """
     Start python distributed server and multiple machine server.
@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port):
     """
     server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
     server.register_instance(Methods(mm_server_port))
+    _, port = server.server_address
+    queue.put(port)
     server.serve_forever()




@@ -152,15 +155,14 @@ class Server:
     :param port: python server port.
     """
 
-    def __init__(self, port):
-        self.py_server_port = get_free_ports(1)[0] if port == 0 else port
+    def __init__(self, port=0):
         self.mm_server_port = create_mm_server("0.0.0.0", 0)
+        q = Queue()
         self.proc = threading.Thread(
-            target=start_server,
-            args=(self.py_server_port, self.mm_server_port),
-            daemon=True,
+            target=start_server, args=(port, self.mm_server_port, q), daemon=True,
         )
         self.proc.start()
+        self.py_server_port = q.get()
 
 
 class Client:


Loading…
Cancel
Save