From eee3e5594cfd5dcc08a37eee4c9b1ddd12d483bc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 8 May 2020 18:21:18 +0800 Subject: [PATCH] feat(mge/distributed): add multiprocess launcher GitOrigin-RevId: 7d831d125fcfa22f203ab92fb0490536abf156f7 --- python_module/megengine/distributed/__init__.py | 1 + python_module/megengine/distributed/util.py | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python_module/megengine/distributed/__init__.py b/python_module/megengine/distributed/__init__.py index fb6e8033..63974cd3 100644 --- a/python_module/megengine/distributed/__init__.py +++ b/python_module/megengine/distributed/__init__.py @@ -18,6 +18,7 @@ from .functional import ( ) from .util import ( get_backend, + get_free_ports, get_master_ip, get_master_port, get_rank, diff --git a/python_module/megengine/distributed/util.py b/python_module/megengine/distributed/util.py index 52248d30..115ae326 100644 --- a/python_module/megengine/distributed/util.py +++ b/python_module/megengine/distributed/util.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import functools +import socket from typing import Callable, Optional import megengine._internal as mgb @@ -110,7 +111,7 @@ def synchronized(func: Callable): Specifically, we use this to prevent data race during hub.load""" @functools.wraps(func) - def _(*args, **kwargs): + def wrapper(*args, **kwargs): if not is_distributed(): return func(*args, **kwargs) @@ -118,4 +119,19 @@ def synchronized(func: Callable): group_barrier() return ret - return _ + return wrapper + + +def get_free_ports(num: Optional[int] = 1) -> int: + """Get one or more free ports. + Return an integer if num is 1, otherwise return a list of integers + """ + socks, ports = [], [] + for i in range(num): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", 0)) + socks.append(sock) + ports.append(sock.getsockname()[1]) + for sock in socks: + sock.close() + return ports[0] if num == 1 else ports