
group.py 7.4 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server

class StaticData:
    server = None
    client = None
    master_ip = None
    py_server_port = None
    mm_server_port = None
    world_size = None
    proc_rank = None
    device = None
    backend = None
    device_type = None
    machine_ranks = None


_sd = None

class Group:
    r"""Include ranked nodes running collective communication (See :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    Args:
        proc_ranks: rank list of the group, the first one is root rank.
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty group
            self.proc_ranks = None
            self.stream = None
        else:
            self.reset(proc_ranks)
        # cache filled lazily by the is_single_machine property
        self.is_single_machine_cache = None

    def reset(self, proc_ranks):
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.stream = _sh.get_next()
        # invalidate the cache when the rank list changes
        self.is_single_machine_cache = None

    def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    @property
    def size(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return len(self.proc_ranks)

    @property
    def key(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

    @property
    def is_single_machine(self):
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        assert _sd is not None, "please call init_process_group first"
        for rank in self.proc_ranks:
            if _sd.machine_ranks is None or rank not in _sd.machine_ranks:
                self.is_single_machine_cache = False
                return False
        self.is_single_machine_cache = True
        return True

WORLD = Group([])

_devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "shm", "auto"}

def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    r"""Initialize the distributed process group and specify the device used in the current process.

    Args:
        master_ip: ip address of the master node.
        port: port available for all processes to communicate.
        world_size: total number of processes participating in the job.
        rank: rank of the current process.
        device: the GPU device id to bind this process to.
        backend: communicator backend, one of 'nccl', 'rccl', 'shm' or 'auto'.
        device_type: device type; the default 'xpu' resolves to the detected physical device.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
        raise TypeError("Expect type int but got {}".format(type(port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(device, int):
        raise TypeError("Expect type int but got {}".format(type(device)))
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    global _sd
    assert _sd is None, "init_process_group should be called only once"
    _sd = StaticData()

    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    _sd.client = Client(master_ip, port)
    _sd.master_ip = master_ip
    _sd.py_server_port = port
    _sd.mm_server_port = _sd.client.get_mm_server_port()
    _sd.world_size = world_size
    _sd.proc_rank = rank
    _sd.device = device
    _sd.backend = backend
    _sd.device_type = device_type

    WORLD.reset(list(range(world_size)))

    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)

def _set_machine_ranks(ranks) -> None:
    global _sd
    assert _sd is not None
    _sd.machine_ranks = ranks


@contextmanager
def override_backend(new_backend: str):
    r"""Override distributed backend.

    Args:
        new_backend: communicator backend set in this context.
    """
    global _sd
    assert _sd, "please call init_process_group first"
    old_backend = _sd.backend
    _sd.backend = new_backend
    try:
        yield
    finally:
        _sd.backend = old_backend


def is_distributed() -> bool:
    r"""Return True if the distributed process group has been initialized."""
    return _sd is not None


def get_rank() -> int:
    r"""Get the rank of the current process."""
    return _sd.proc_rank if _sd is not None else 0


def get_world_size() -> int:
    r"""Get the total number of processes participating in the job."""
    return _sd.world_size if _sd is not None else 1


def get_backend() -> str:
    r"""Get the backend str."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.backend if _sd is not None else None


def get_py_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.py_server_port


def get_mm_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of C++ mm_server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.mm_server_port


def get_client() -> Client:
    r"""Get client of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.client


def new_group(proc_ranks: List[int]) -> Group:
    r"""Build a subgroup containing certain ranks."""
    return Group(proc_ranks)


def group_barrier(group: Group = WORLD) -> None:
    r"""Block until all ranks in the group reach this barrier."""
    # if running with single node, skip it
    if _sd is None:
        return
    assert isinstance(group, Group)
    _sd.client.group_barrier(group.key, group.size)
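
For orientation, here is a minimal worker-side sketch (not part of group.py) of how these helpers fit together. It assumes a rendezvous Server is already listening at MASTER_IP:PORT (for example, started by the launch tooling) and that WORLD_SIZE processes are started with distinct ranks; MASTER_IP, PORT, WORLD_SIZE and RANK are placeholder values supplied per process.

# Hypothetical worker-side usage; MASTER_IP, PORT, WORLD_SIZE and RANK are
# placeholders that the real launch mechanism would supply to each process.
import megengine.distributed as dist

MASTER_IP, PORT, WORLD_SIZE, RANK = "127.0.0.1", 23456, 2, 0  # example values

# Join the default WORLD group and bind this process to device `RANK`.
dist.init_process_group(MASTER_IP, PORT, WORLD_SIZE, RANK, device=RANK)
assert dist.is_distributed() and dist.get_rank() == RANK

# Build a subgroup of ranks 0 and 1; check() requires the caller to be a member.
sub = dist.new_group([0, 1])

# Collectives issued inside this context use the "nccl" backend.
with dist.override_backend("nccl"):
    pass  # e.g. run functional.distributed collectives on `sub`

# Block until every rank in WORLD reaches this point.
dist.group_barrier()

Each worker binds to its own device via device=RANK, and group_barrier() keeps the processes in step before they proceed or exit.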

The MegEngine package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose between. To run GPU programs, make sure the machine actually has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
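
As a quick sanity check before launching GPU jobs, something along these lines can confirm that a CUDA device and driver are visible; a small sketch assuming the megengine.is_cuda_available helper available in recent MegEngine releases.

# Sketch: verify a CUDA-capable GPU is visible to MegEngine before running GPU code.
# Assumes megengine.is_cuda_available(), available in recent MegEngine releases.
import megengine as mge

if mge.is_cuda_available():
    print("GPU and driver detected; the bundled CUDA runtime will be used.")
else:
    print("No usable GPU found; code will run on CPU instead.")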