Browse Source

fix(mge/distributed): fix deadlock caused by mixing thread and fork

GitOrigin-RevId: c138cb9c28
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
0adf49b137
2 changed files with 12 additions and 11 deletions
  1. +1
    -1
      imperative/python/megengine/distributed/launcher.py
  2. +11
    -10
      imperative/python/megengine/distributed/server.py

+ 1
- 1
imperative/python/megengine/distributed/launcher.py View File

@@ -12,7 +12,7 @@ import multiprocessing as mp
from ..core._imperative_rt.core2 import sync from ..core._imperative_rt.core2 import sync
from .group import group_barrier, init_process_group from .group import group_barrier, init_process_group
from .helper import get_device_count_by_fork from .helper import get_device_count_by_fork
from .server import Server
from .server import Client, Server
from .util import get_free_ports from .util import get_free_ports






+ 11
- 10
imperative/python/megengine/distributed/server.py View File

@@ -6,11 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from queue import Queue
from socketserver import ThreadingMixIn from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer from xmlrpc.server import SimpleXMLRPCServer
@@ -133,7 +133,7 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass




def _start_server(py_server_port, mm_server_port, queue):
def _start_server(py_server_port, queue):
""" """
Start python distributed server and multiple machine server. Start python distributed server and multiple machine server.


@@ -142,10 +142,11 @@ def _start_server(py_server_port, mm_server_port, queue):
:param queue: server port will be put in this queue; an exception is put instead when the process fails. :param queue: server port will be put in this queue; an exception is put instead when the process fails.
""" """
try: try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
server.register_instance(Methods(mm_server_port)) server.register_instance(Methods(mm_server_port))
_, port = server.server_address
queue.put(port)
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
server.serve_forever() server.serve_forever()
except Exception as e: except Exception as e:
queue.put(e) queue.put(e)
@@ -160,17 +161,17 @@ class Server:
""" """


def __init__(self, port=0): def __init__(self, port=0):
self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue()
self.proc = threading.Thread(
target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
)
q = mp.Queue()
self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
self.proc.start() self.proc.start()
ret = q.get() ret = q.get()
if isinstance(ret, Exception): if isinstance(ret, Exception):
raise ret raise ret
else: else:
self.py_server_port = ret
self.py_server_port, self.mm_server_port = ret

def __del__(self):
self.proc.terminate()




class Client: class Client:


Loading…
Cancel
Save