diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index ad1dde40..31fceebc 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -12,7 +12,7 @@ import multiprocessing as mp from ..core._imperative_rt.core2 import sync from .group import group_barrier, init_process_group from .helper import get_device_count_by_fork -from .server import Server +from .server import Client, Server from .util import get_free_ports diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index 955d28df..433c9017 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -6,11 +6,11 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import multiprocessing as mp import threading import time from collections import defaultdict from functools import partial -from queue import Queue from socketserver import ThreadingMixIn from xmlrpc.client import ServerProxy from xmlrpc.server import SimpleXMLRPCServer @@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass -def _start_server(py_server_port, mm_server_port, queue): +def _start_server(py_server_port, queue): """ Start python distributed server and multiple machine server. @@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue): :param queue: server port will put in this queue, puts exception when process fails. """ try: + mm_server_port = create_mm_server("0.0.0.0", 0) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server.register_instance(Methods(mm_server_port)) - _, port = server.server_address - queue.put(port) + _, py_server_port = server.server_address + queue.put((py_server_port, mm_server_port)) server.serve_forever() except Exception as e: queue.put(e) @@ -160,17 +161,17 @@ class Server: """ def __init__(self, port=0): - self.mm_server_port = create_mm_server("0.0.0.0", 0) - q = Queue() - self.proc = threading.Thread( - target=_start_server, args=(port, self.mm_server_port, q), daemon=True, - ) + q = mp.Queue() + self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True) self.proc.start() ret = q.get() if isinstance(ret, Exception): raise ret else: - self.py_server_port = ret + self.py_server_port, self.mm_server_port = ret + + def __del__(self): + self.proc.terminate() class Client: