From 0adf49b137079bcfcf121486374a2626155990c0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Jan 2021 12:36:04 +0800 Subject: [PATCH] fix(mge/distributed): fix deadlock by mixing thread and fork GitOrigin-RevId: c138cb9c280aeb37d83d2ca31549df0661ebbbc9 --- imperative/python/megengine/distributed/launcher.py | 2 +- imperative/python/megengine/distributed/server.py | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index ad1dde40..31fceebc 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -12,7 +12,7 @@ import multiprocessing as mp from ..core._imperative_rt.core2 import sync from .group import group_barrier, init_process_group from .helper import get_device_count_by_fork -from .server import Server +from .server import Client, Server from .util import get_free_ports diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py index 955d28df..433c9017 100644 --- a/imperative/python/megengine/distributed/server.py +++ b/imperative/python/megengine/distributed/server.py @@ -6,11 +6,11 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import multiprocessing as mp import threading import time from collections import defaultdict from functools import partial -from queue import Queue from socketserver import ThreadingMixIn from xmlrpc.client import ServerProxy from xmlrpc.server import SimpleXMLRPCServer @@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): pass -def _start_server(py_server_port, mm_server_port, queue): +def _start_server(py_server_port, queue): """ Start python distributed server and multiple machine server. @@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue): :param queue: server port will put in this queue, puts exception when process fails. """ try: + mm_server_port = create_mm_server("0.0.0.0", 0) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server.register_instance(Methods(mm_server_port)) - _, port = server.server_address - queue.put(port) + _, py_server_port = server.server_address + queue.put((py_server_port, mm_server_port)) server.serve_forever() except Exception as e: queue.put(e) @@ -160,17 +161,17 @@ class Server: """ def __init__(self, port=0): - self.mm_server_port = create_mm_server("0.0.0.0", 0) - q = Queue() - self.proc = threading.Thread( - target=_start_server, args=(port, self.mm_server_port, q), daemon=True, - ) + q = mp.Queue() + self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True) self.proc.start() ret = q.get() if isinstance(ret, Exception): raise ret else: - self.py_server_port = ret + self.py_server_port, self.mm_server_port = ret + + def __del__(self): + self.proc.terminate() class Client: