|
|
@@ -7,6 +7,7 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import functools |
|
|
|
import socket |
|
|
|
from typing import Callable, Optional |
|
|
|
|
|
|
|
import megengine._internal as mgb |
|
|
@@ -110,7 +111,7 @@ def synchronized(func: Callable): |
|
|
|
Specifically, we use this to prevent data race during hub.load""" |
|
|
|
|
|
|
|
@functools.wraps(func) |
|
|
|
def _(*args, **kwargs): |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
if not is_distributed(): |
|
|
|
return func(*args, **kwargs) |
|
|
|
|
|
|
@@ -118,4 +119,19 @@ def synchronized(func: Callable): |
|
|
|
group_barrier() |
|
|
|
return ret |
|
|
|
|
|
|
|
return _ |
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
def get_free_ports(num: Optional[int] = 1) -> int: |
|
|
|
"""Get one or more free ports. |
|
|
|
Return an integer if num is 1, otherwise return a list of integers |
|
|
|
""" |
|
|
|
socks, ports = [], [] |
|
|
|
for i in range(num): |
|
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
sock.bind(("", 0)) |
|
|
|
socks.append(sock) |
|
|
|
ports.append(sock.getsockname()[1]) |
|
|
|
for sock in socks: |
|
|
|
sock.close() |
|
|
|
return ports[0] if num == 1 else ports |