|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import functools
- import socket
- from typing import Callable, List, Optional
-
- import megengine._internal as mgb
-
- from ..core import set_default_device
-
# Module-level state describing the process group of the current process.
# Every value below is overwritten by init_process_group(); until then the
# defaults are kept.

# Address of the master node: IP string (None until set) and server port.
_master_ip, _master_port = None, 0

# Number of participating processes and the rank of this process.
_world_size, _rank = 0, 0

# Name of the communicator backend (None until set).
_backend = None

# Counter incremented by get_group_id() each time a new id is requested.
_group_id = 0
-
-
def init_process_group(
    master_ip: str,
    master_port: int,
    world_size: int,
    rank: int,
    dev: int,
    backend: Optional[str] = "nccl",
) -> None:
    """Initialize the distributed process group, and also specify the device used in the current process.

    All arguments are validated before anything else happens, and the
    module-level process-group variables are only updated once every step
    has succeeded. A call that raises therefore leaves those variables as
    they were (the default device may already have been changed, though).

    :param master_ip: IP address of the master node.
    :param master_port: Port available for all processes to communicate.
        On rank 0 it is handed to ``create_mm_server`` and the port returned
        by that call is the one recorded; other ranks must pass a value > 0.
    :param world_size: Total number of processes participating in the job.
    :param rank: Rank of the current process.
    :param dev: The GPU device id to bind this process to; the device used
        is ``"gpu" + str(dev)``.
    :param backend: Communicator backend, currently support 'nccl' and 'ucx'.
        Must be a ``str``; ``None`` is rejected.
    :raises TypeError: if ``master_ip`` or ``backend`` is not a ``str``, or
        ``master_port``, ``world_size`` or ``rank`` is not an ``int``.
    :raises AssertionError: if ``rank`` is non-zero and ``master_port`` is
        not positive.
    :raises RuntimeError: if rank 0 fails to start the server
        (``create_mm_server`` returned -1).
    """
    global _master_ip  # pylint: disable=global-statement
    global _master_port  # pylint: disable=global-statement
    global _world_size  # pylint: disable=global-statement
    global _rank  # pylint: disable=global-statement
    global _backend  # pylint: disable=global-statement
    global _group_id  # pylint: disable=global-statement

    # Checked in the same order, and with the same messages, as before.
    expected_types = (
        (master_ip, str),
        (master_port, int),
        (world_size, int),
        (rank, int),
        (backend, str),
    )
    for value, expected in expected_types:
        if not isinstance(value, expected):
            raise TypeError(
                "Expect type {} but got {}".format(expected.__name__, type(value))
            )

    if rank != 0 and master_port <= 0:
        # Raised explicitly rather than with `assert` so that the check is
        # not stripped under `python -O`; the exception type is unchanged.
        raise AssertionError("master_port must be specified for non-zero rank")

    set_default_device(mgb.comp_node("gpu" + str(dev)))

    port = master_port
    if rank == 0:
        # Rank 0 hosts the server; the port it reports is the one to record.
        port = mgb.config.create_mm_server("0.0.0.0", master_port)
        if port == -1:
            # RuntimeError is an Exception subclass, so callers catching
            # Exception keep working.
            raise RuntimeError(
                "Failed to start server on port {}".format(master_port)
            )

    # Commit the state only now, so a failure above cannot leave the module
    # half-initialized (e.g. with a master port of -1).
    _master_ip = master_ip
    _master_port = port
    _world_size = world_size
    _rank = rank
    _backend = backend
    _group_id = 0
-
-
def is_distributed() -> bool:
    """Tell whether the current process is part of a multi-process group.

    The answer is derived from the recorded world size only: it is ``True``
    when a world size is set and it is greater than one.
    """
    if _world_size is None:
        return False
    return _world_size > 1
-
-
def get_master_ip() -> str:
    """Return the recorded IP address of the master node as a string.

    The stored value is converted with ``str``, so an unset (``None``)
    address is reported as ``"None"``.
    """
    address = _master_ip
    return str(address)
-
-
def get_master_port() -> int:
    """Return the recorded port of the rpc server on the master node."""
    port = _master_port
    return port
-
-
def get_world_size() -> int:
    """Return the recorded total number of processes taking part in the job."""
    size = _world_size
    return size
-
-
def get_rank() -> int:
    """Return the recorded rank of the current process."""
    current_rank = _rank
    return current_rank
-
-
def get_backend() -> str:
    """Return the recorded communicator backend as a string.

    The stored value is converted with ``str``, so an unset (``None``)
    backend is reported as ``"None"``.
    """
    backend_name = _backend
    return str(backend_name)
-
-
def get_group_id() -> int:
    """Return a new group id for collective communication.

    Each call increments the module-level counter by one and returns the
    incremented value.
    """
    global _group_id
    next_id = _group_id + 1
    _group_id = next_id
    return next_id
-
-
def group_barrier() -> None:
    """Block until all ranks in the group reach this barrier.

    The recorded master address, world size and rank of this process are
    forwarded to ``mgb.config.group_barrier``.
    """
    barrier = mgb.config.group_barrier
    barrier(_master_ip, _master_port, _world_size, _rank)
-
-
def synchronized(func: Callable):
    """Decorator that makes the wrapped function end with a group barrier.

    When the process is not distributed the function is simply called.
    Otherwise ``group_barrier`` is invoked after the function returns and
    before its result is handed back. Specifically, we use this to prevent
    data race during hub.load.

    :param func: The callable to wrap.
    :return: The wrapping function, carrying the metadata of ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The distributed state is sampled before func runs, so the decision
        # cannot be affected by whatever func itself does.
        needs_barrier = is_distributed()
        result = func(*args, **kwargs)
        if needs_barrier:
            group_barrier()
        return result

    return wrapper
-
-
def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports.

    ``num`` TCP sockets are bound to port 0 and the port numbers assigned
    to them are collected. All sockets are kept open until every port has
    been collected, and are closed before returning, including when
    creating or binding a socket fails part-way through.

    Because the sockets are closed before the ports are returned, a port
    may be taken by another process before the caller uses it.

    :param num: Number of ports to obtain.
    :return: List with one port number per requested port.
    """
    socks = []
    try:
        for _ in range(num):
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Register before binding so the socket is closed even if
            # bind() raises.
            socks.append(sock)
            sock.bind(("", 0))
        return [sock.getsockname()[1] for sock in socks]
    finally:
        for sock in socks:
            sock.close()
|