You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, may include dashes ('-'), and can be up to 35 characters long.

group.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # -*- coding: utf-8 -*-
  2. import time
  3. from contextlib import contextmanager
  4. from typing import List, Optional, Tuple
  5. from mprop import mproperty
  6. from ..device import _sh, set_default_device, what_is_xpu
  7. from ..random import seed
  8. from .server import Client, Server
class StaticData:
    # Process-wide distributed state, populated once by init_process_group.
    server = None          # Server instance; not assigned in this chunk — presumably set elsewhere
    client = None          # Client connected to the master's python XML RPC server
    master_ip = None       # ip address of the master node
    py_server_port = None  # port of the python XML RPC server
    mm_server_port = None  # port of the C++ mm_server (queried from the client)
    world_size = None      # total number of processes in the job
    proc_rank = None       # rank of the current process
    device = None          # device id bound to this process
    backend = None         # communicator backend string
    device_type = None     # logical device type passed by the user (e.g. "xpu")
    machine_ranks = None   # ranks located on the local machine (see _set_machine_ranks)


# Module-level singleton holding the distributed state; stays None until
# init_process_group is called.
_sd = None
  22. class Group:
  23. r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).
  24. By default collectives operate on the default group (also called ``WORLD``)
  25. and require all processes to enter the distributed function call.
  26. Args:
  27. proc_ranks: rank list of the group, the first one is root rank.
  28. """
  29. def __init__(self, proc_ranks):
  30. if len(proc_ranks) == 0: # empty group
  31. self.proc_ranks = None
  32. self.stream = None
  33. else:
  34. self.reset(proc_ranks)
  35. def reset(self, proc_ranks):
  36. self.check(proc_ranks)
  37. self.proc_ranks = proc_ranks
  38. self.is_single_machine_cache = None
  39. self.stream = _sh.get_next()
  40. def check(self, proc_ranks):
  41. assert _sd is not None, "please call init_process_group first"
  42. for rank in proc_ranks:
  43. assert isinstance(rank, int)
  44. assert rank >= 0 and rank < _sd.world_size
  45. assert _sd.proc_rank in proc_ranks
  46. @property
  47. def size(self):
  48. assert len(self.proc_ranks) > 0, "invalid group"
  49. return len(self.proc_ranks)
  50. @property
  51. def key(self):
  52. assert len(self.proc_ranks) > 0, "invalid group"
  53. return ",".join(map(str, self.proc_ranks))
  54. @property
  55. def rank(self):
  56. assert len(self.proc_ranks) > 0, "invalid group"
  57. return self.proc_ranks.index(_sd.proc_rank)
  58. @property
  59. def comp_node(self):
  60. assert len(self.proc_ranks) > 0, "invalid group"
  61. return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)
  62. @property
  63. def is_single_machine(self):
  64. if self.is_single_machine_cache is not None:
  65. return self.is_single_machine_cache
  66. assert _sd is not None, "please call init_process_group first"
  67. for rank in self.proc_ranks:
  68. if _sd.machine_ranks is None or rank not in _sd.machine_ranks:
  69. self.is_single_machine_cache = False
  70. return False
  71. self.is_single_machine_cache = True
  72. return True
# Default group containing all processes; its ranks are filled in by
# init_process_group via WORLD.reset(...).
WORLD = Group([])
# Physical device kinds accepted for collective communication.
_devices = {"gpu", "cuda", "rocm"}
# Supported communicator backends ("auto" defers the choice).
_backends = {"nccl", "rccl", "shm", "auto"}
  76. def init_process_group(
  77. master_ip: str,
  78. port: int,
  79. world_size: int,
  80. rank: int,
  81. device: int,
  82. backend: Optional[str] = "auto",
  83. device_type: str = "xpu",
  84. ) -> None:
  85. r"""Initialize the distributed process group and specify the device used in the current process
  86. Args:
  87. master_ip: ip address of the master node.
  88. port: port available for all processes to communicate.
  89. world_size: total number of processes participating in the job.
  90. rank: rank of the current process.
  91. device: the GPU device id to bind this process to.
  92. backend: communicator backend, currently support 'nccl' and 'shm'.
  93. """
  94. physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
  95. if not isinstance(master_ip, str):
  96. raise TypeError("Expect type str but got {}".format(type(master_ip)))
  97. if not isinstance(port, int):
  98. raise TypeError("Expect type int but got {}".format(type(port)))
  99. if not isinstance(world_size, int):
  100. raise TypeError("Expect type int but got {}".format(type(world_size)))
  101. if not isinstance(rank, int):
  102. raise TypeError("Expect type int but got {}".format(type(rank)))
  103. if not isinstance(device, int):
  104. raise TypeError("Expect type int but got {}".format(type(backend)))
  105. if backend not in _backends:
  106. raise ValueError(
  107. "backend should be one of {} but got {}".format(_backends, backend)
  108. )
  109. if physical_device_type not in _devices:
  110. raise ValueError(
  111. "{} is not a valid distributed device type".format(device_type)
  112. )
  113. global _sd
  114. assert _sd is None, "init_process_group should be called only once"
  115. _sd = StaticData()
  116. assert world_size > 1
  117. assert rank >= 0 and rank < world_size
  118. assert port > 0
  119. _sd.client = Client(master_ip, port)
  120. _sd.master_ip = master_ip
  121. _sd.py_server_port = port
  122. _sd.mm_server_port = _sd.client.get_mm_server_port()
  123. _sd.world_size = world_size
  124. _sd.proc_rank = rank
  125. _sd.device = device
  126. _sd.backend = backend
  127. _sd.device_type = device_type
  128. WORLD.reset(list(range(world_size)))
  129. set_default_device("{}{}".format(device_type, device))
  130. seed(int(time.time()) + rank)
  131. if backend == "nccl":
  132. # init nccl env
  133. from ..core._imperative_rt.common import init_nccl_env
  134. group_barrier()
  135. init_nccl_env(master_ip, _sd.mm_server_port, world_size, rank, 0)
def _set_machine_ranks(ranks) -> None:
    # Record the ranks that live on the local machine; consumed by
    # Group.is_single_machine to decide whether cross-machine communication
    # is required. Internal helper: init_process_group must have run first.
    global _sd
    assert _sd is not None

    _sd.machine_ranks = ranks
  140. @contextmanager
  141. def override_backend(new_backend: str):
  142. r"""Override distributed backend
  143. Args:
  144. new_backend: communicator backend set in this context.
  145. """
  146. global _sd
  147. assert _sd, "please call init_process_group first"
  148. old_backend = _sd.backend
  149. _sd.backend = new_backend
  150. try:
  151. yield
  152. finally:
  153. _sd.backend = old_backend
  154. def is_distributed() -> bool:
  155. r"""Return True if the distributed process group has been initialized."""
  156. return _sd is not None
  157. def get_rank() -> int:
  158. r"""Get the rank of the current process."""
  159. return _sd.proc_rank if _sd is not None else 0
  160. def get_world_size() -> int:
  161. r"""Get the total number of processes participating in the job."""
  162. return _sd.world_size if _sd is not None else 1
  163. def get_backend() -> str:
  164. r"""Get the backend str."""
  165. assert _sd is not None, "please call init_process_group first"
  166. return _sd.backend if _sd is not None else None
  167. def get_py_server_addr() -> Tuple[str, int]:
  168. r"""Get master_ip and port of python XML RPC server."""
  169. assert _sd is not None, "please call init_process_group first"
  170. return _sd.master_ip, _sd.py_server_port
  171. def get_mm_server_addr() -> Tuple[str, int]:
  172. r"""Get master_ip and port of C++ mm_server."""
  173. assert _sd is not None, "please call init_process_group first"
  174. return _sd.master_ip, _sd.mm_server_port
  175. def get_client() -> Client:
  176. r"""Get client of python XML RPC server."""
  177. assert _sd is not None, "please call init_process_group first"
  178. return _sd.client
  179. def new_group(proc_ranks: List[int]) -> Group:
  180. r"""Build a subgroup containing certain ranks."""
  181. return Group(proc_ranks)
  182. def group_barrier(group: Group = WORLD) -> None:
  183. r"""Block until all ranks in the group reach this barrier."""
  184. # if running with single node, skip it
  185. if _sd is None:
  186. return
  187. assert isinstance(group, Group)
  188. _sd.client.group_barrier(group.key, group.size)