
group.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import _sh, set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server


class StaticData:
    server = None
    client = None
    master_ip = None
    py_server_port = None
    mm_server_port = None
    world_size = None
    proc_rank = None
    device = None
    backend = None
    device_type = None
    machine_ranks = None


_sd = None


class Group:
    r"""Includes ranked nodes running collective communication (see :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    Args:
        proc_ranks: rank list of the group; the first one is the root rank.
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty group
            self.proc_ranks = None
            self.stream = None
        else:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.is_single_machine_cache = None
        self.stream = _sh.get_next()

    def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    @property
    def size(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return len(self.proc_ranks)

    @property
    def key(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

    @property
    def is_single_machine(self):
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        assert _sd is not None, "please call init_process_group first"
        for rank in self.proc_ranks:
            if _sd.machine_ranks is None or rank not in _sd.machine_ranks:
                self.is_single_machine_cache = False
                return False
        self.is_single_machine_cache = True
        return True


WORLD = Group([])

_devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "shm", "auto"}


def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    r"""Initialize the distributed process group and specify the device used in the current process.

    Args:
        master_ip: ip address of the master node.
        port: port available for all processes to communicate.
        world_size: total number of processes participating in the job.
        rank: rank of the current process.
        device: the GPU device id to bind this process to.
        backend: communicator backend, currently supports 'nccl' and 'shm'.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
        raise TypeError("Expect type int but got {}".format(type(port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(device, int):
        raise TypeError("Expect type int but got {}".format(type(device)))
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    global _sd
    assert _sd is None, "init_process_group should be called only once"
    _sd = StaticData()

    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    _sd.client = Client(master_ip, port)
    _sd.master_ip = master_ip
    _sd.py_server_port = port
    _sd.mm_server_port = _sd.client.get_mm_server_port()
    _sd.world_size = world_size
    _sd.proc_rank = rank
    _sd.device = device
    _sd.backend = backend
    _sd.device_type = device_type

    WORLD.reset(list(range(world_size)))

    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)


def _set_machine_ranks(ranks) -> None:
    global _sd
    assert _sd is not None

    _sd.machine_ranks = ranks


@contextmanager
def override_backend(new_backend: str):
    r"""Override the distributed backend.

    Args:
        new_backend: communicator backend set in this context.
    """
    global _sd
    assert _sd, "please call init_process_group first"
    old_backend = _sd.backend
    _sd.backend = new_backend
    try:
        yield
    finally:
        _sd.backend = old_backend


def is_distributed() -> bool:
    r"""Return True if the distributed process group has been initialized."""
    return _sd is not None


def get_rank() -> int:
    r"""Get the rank of the current process."""
    return _sd.proc_rank if _sd is not None else 0


def get_world_size() -> int:
    r"""Get the total number of processes participating in the job."""
    return _sd.world_size if _sd is not None else 1


def get_backend() -> str:
    r"""Get the backend str."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.backend if _sd is not None else None


def get_py_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of the python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.py_server_port


def get_mm_server_addr() -> Tuple[str, int]:
    r"""Get master_ip and port of the C++ mm_server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.mm_server_port


def get_client() -> Client:
    r"""Get the client of the python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.client


def new_group(proc_ranks: List[int]) -> Group:
    r"""Build a subgroup containing certain ranks."""
    return Group(proc_ranks)


def group_barrier(group: Group = WORLD) -> None:
    r"""Block until all ranks in the group reach this barrier."""
    # if running with a single node, skip it
    if _sd is None:
        return
    assert isinstance(group, Group)
    _sd.client.group_barrier(group.key, group.size)
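
For reference, here is a minimal usage sketch of the API defined above; it is not part of the repository file. It assumes two processes on one machine with one GPU each, launched with Python's multiprocessing "spawn" context, and that `dist.Server` from the companion `.server` module (imported at the top of the file) exposes the python server port that `init_process_group` connects to; "localhost" and the `worker` helper are placeholders.

    import multiprocessing as mp

    import megengine.distributed as dist


    def worker(rank: int, world_size: int, port: int) -> None:
        # Bind this process to GPU `rank` and join the default WORLD group.
        dist.init_process_group(
            master_ip="localhost",  # placeholder address
            port=port,
            world_size=world_size,
            rank=rank,
            device=rank,
            backend="nccl",  # assumes NVIDIA GPUs; "shm" / "auto" are the other options above
        )
        assert dist.get_rank() == rank
        assert dist.get_world_size() == world_size

        # A subgroup built from specific ranks; here it happens to equal WORLD.
        sub = dist.new_group(list(range(world_size)))
        dist.group_barrier(sub)

        # Wait for every rank in the default WORLD group to reach this point.
        dist.group_barrier()


    if __name__ == "__main__":
        world_size = 2  # placeholder: one process per GPU
        # Assumption: dist.Server wraps the python XML RPC server that the
        # Client in init_process_group talks to, and exposes its port.
        server = dist.Server()
        ctx = mp.get_context("spawn")
        procs = [
            ctx.Process(target=worker, args=(rank, world_size, server.py_server_port))
            for rank in range(world_size)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()

In practice the higher-level launcher utilities in `megengine.distributed` handle the server setup and process spawning, so calling `init_process_group` directly as above is mainly useful for understanding what those utilities do under the hood.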

The MegEngine package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
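
A quick sanity check before launching a GPU program might look like the following sketch; it assumes `megengine.is_cuda_available` and `megengine.get_device_count` are present in the installed MegEngine version.

    import megengine as mge

    # Report whether CUDA devices are usable from this MegEngine installation.
    if mge.is_cuda_available():
        print("CUDA devices visible:", mge.get_device_count("gpu"))
    else:
        print("No usable GPU found; check the GPU hardware and driver installation.")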