Browse Source

feat(mge/distributed): add multiprocess launcher

GitOrigin-RevId: 7d831d125f
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
eee3e5594c
2 changed files with 19 additions and 2 deletions
  1. +1
    -0
      python_module/megengine/distributed/__init__.py
  2. +18
    -2
      python_module/megengine/distributed/util.py

+ 1
- 0
python_module/megengine/distributed/__init__.py View File

@@ -18,6 +18,7 @@ from .functional import (
)
from .util import (
get_backend,
get_free_ports,
get_master_ip,
get_master_port,
get_rank,


+ 18
- 2
python_module/megengine/distributed/util.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import socket
from typing import Callable, Optional

import megengine._internal as mgb
@@ -110,7 +111,7 @@ def synchronized(func: Callable):
Specifically, we use this to prevent data race during hub.load"""

@functools.wraps(func)
def _(*args, **kwargs):
def wrapper(*args, **kwargs):
if not is_distributed():
return func(*args, **kwargs)

@@ -118,4 +119,19 @@ def synchronized(func: Callable):
group_barrier()
return ret

return _
return wrapper


def get_free_ports(num: Optional[int] = 1) -> int:
"""Get one or more free ports.
Return an integer if num is 1, otherwise return a list of integers
"""
socks, ports = [], []
for i in range(num):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
socks.append(sock)
ports.append(sock.getsockname()[1])
for sock in socks:
sock.close()
return ports[0] if num == 1 else ports

Loading…
Cancel
Save