Browse Source

refactor(distributed/server): use port 0 to get available port

GitOrigin-RevId: e367846b92
release-1.2
Megvii Engine Team 4 years ago
parent
commit
3ecded74ea
2 changed files with 23 additions and 12 deletions
  1. +15
    -6
      imperative/python/megengine/distributed/launcher.py
  2. +8
    -6
      imperative/python/megengine/distributed/server.py

+ 15
- 6
imperative/python/megengine/distributed/launcher.py View File

@@ -29,8 +29,8 @@ def launcher(func):

def wrapper(*args, **kwargs):
master_ip = "localhost"
port = get_free_ports(1)[0]
server = Server(port)
server = Server()
port = server.py_server_port

procs = []
for rank in range(n_gpus):
@@ -41,9 +41,18 @@ def launcher(func):
p.start()
procs.append(p)

for rank in range(n_gpus):
procs[rank].join()
code = procs[rank].exitcode
assert code == 0, "subprocess {} exit with code {}".format(rank, code)
ranks = [rank for rank in range(n_gpus)]

while len(ranks) > 0:
left = []
for rank in ranks:
procs[rank].join(1)
code = procs[rank].exitcode
assert (
code == 0 or code == None
), "subprocess {} exit with code {}".format(rank, code)
if code == None:
left.append(rank)
ranks = left

return wrapper

+ 8
- 6
imperative/python/megengine/distributed/server.py View File

@@ -10,6 +10,7 @@ import threading
import time
from collections import defaultdict
from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
@@ -132,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass


def start_server(py_server_port, mm_server_port):
def start_server(py_server_port, mm_server_port, queue):
"""
Start python distributed server and multiple machine server.
@@ -141,6 +142,8 @@ def start_server(py_server_port, mm_server_port):
"""
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port))
_, port = server.server_address
queue.put(port)
server.serve_forever()


@@ -152,15 +155,14 @@ class Server:
:param port: python server port.
"""

def __init__(self, port):
self.py_server_port = get_free_ports(1)[0] if port == 0 else port
def __init__(self, port=0):
self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue()
self.proc = threading.Thread(
target=start_server,
args=(self.py_server_port, self.mm_server_port),
daemon=True,
target=start_server, args=(port, self.mm_server_port, q), daemon=True,
)
self.proc.start()
self.py_server_port = q.get()


class Client:


Loading…
Cancel
Save