
group.py 6.5 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import List, Optional, Tuple

from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server


class StaticData:
    server = None
    client = None
    master_ip = None
    py_server_port = None
    mm_server_port = None
    world_size = None
    proc_rank = None
    device = None
    backend = None
    next_stream = None
    device_type = None


_sd = None


class Group:
    r"""
    Includes the ranked nodes running collective communication (see :mod:`~.functional.distributed`).

    By default collectives operate on the default group (also called ``WORLD``)
    and require all processes to enter the distributed function call.

    :param proc_ranks: rank list of the group; the first rank is the root rank.
    """

    def __init__(self, proc_ranks):
        if len(proc_ranks) == 0:  # empty group
            self.proc_ranks = None
            self.stream = None
        else:
            self.reset(proc_ranks)

    def reset(self, proc_ranks):
        self.check(proc_ranks)
        self.proc_ranks = proc_ranks
        self.stream = _sd.next_stream
        _sd.next_stream += 1

    def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
        for rank in proc_ranks:
            assert isinstance(rank, int)
            assert rank >= 0 and rank < _sd.world_size
        assert _sd.proc_rank in proc_ranks

    @property
    def size(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return len(self.proc_ranks)

    @property
    def key(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return ",".join(map(str, self.proc_ranks))

    @property
    def rank(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return self.proc_ranks.index(_sd.proc_rank)

    @property
    def comp_node(self):
        assert len(self.proc_ranks) > 0, "invalid group"
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)


WORLD = Group([])

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}

_backends = {"nccl", "rccl", "ucx"}


def init_process_group(
    master_ip: str,
    port: int,
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = None,
    device_type: str = "xpu",
) -> None:
    """
    Initialize the distributed process group and specify the device used in the current process.

    :param master_ip: ip address of the master node.
    :param port: port available for all processes to communicate.
    :param world_size: total number of processes participating in the job.
    :param rank: rank of the current process.
    :param device: the GPU device id to bind this process to.
    :param backend: communicator backend; currently supports 'nccl', 'rccl' and 'ucx'.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    backend = _device2backend[physical_device_type] if backend is None else backend
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
        raise TypeError("Expect type int but got {}".format(type(port)))
    if not isinstance(world_size, int):
        raise TypeError("Expect type int but got {}".format(type(world_size)))
    if not isinstance(rank, int):
        raise TypeError("Expect type int but got {}".format(type(rank)))
    if not isinstance(device, int):
        raise TypeError("Expect type int but got {}".format(type(device)))
    if backend not in _backends:
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _device2backend:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )

    global _sd
    assert _sd is None, "init_process_group should be called only once"
    _sd = StaticData()

    assert world_size > 1
    assert rank >= 0 and rank < world_size
    assert port > 0

    _sd.client = Client(master_ip, port)
    _sd.master_ip = master_ip
    _sd.py_server_port = port
    _sd.mm_server_port = _sd.client.get_mm_server_port()
    _sd.world_size = world_size
    _sd.proc_rank = rank
    _sd.device = device
    _sd.backend = backend
    _sd.next_stream = 1
    _sd.device_type = device_type

    WORLD.reset(list(range(world_size)))

    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)


def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized."""
    return _sd is not None


def get_rank() -> int:
    """Get the rank of the current process."""
    return _sd.proc_rank if _sd is not None else 0


def get_world_size() -> int:
    """Get the total number of processes participating in the job."""
    return _sd.world_size if _sd is not None else 1


def get_backend() -> str:
    """Get the backend str."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.backend if _sd is not None else None


def get_py_server_addr() -> Tuple[str, int]:
    """Get master_ip and port of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.py_server_port


def get_mm_server_addr() -> Tuple[str, int]:
    """Get master_ip and port of C++ mm_server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.master_ip, _sd.mm_server_port


def get_client() -> Client:
    """Get client of python XML RPC server."""
    assert _sd is not None, "please call init_process_group first"
    return _sd.client


def new_group(proc_ranks: List[int]) -> Group:
    """Build a subgroup containing certain ranks."""
    return Group(proc_ranks)


def group_barrier(group: Group = WORLD) -> None:
    """Block until all ranks in the group reach this barrier."""
    # if running with a single node, skip it
    if _sd is None:
        return
    assert isinstance(group, Group)
    _sd.client.group_barrier(group.key, group.size)
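
For context, the sketch below shows how these pieces are typically used together in a two-process job: each worker calls init_process_group to join WORLD, then builds a subgroup with new_group and synchronizes with group_barrier. This is a minimal sketch, not MegEngine's official launcher; the worker function, the port number, and the assumption that dist.Server(port) starts the rendezvous server that Client connects to are illustrative and may differ across MegEngine versions.

import multiprocessing as mp

import megengine.distributed as dist


def worker(rank, world_size, port):
    # Join the default WORLD group and bind this process to GPU `rank`.
    dist.init_process_group(
        master_ip="localhost",
        port=port,
        world_size=world_size,
        rank=rank,
        device=rank,
        backend="nccl",
    )
    # Build a subgroup of ranks 0 and 1, then block until both arrive.
    sub = dist.new_group([0, 1])
    dist.group_barrier(sub)


if __name__ == "__main__":
    world_size, port = 2, 23456  # hypothetical free port
    # Assumption: dist.Server(port) starts the python XML RPC / mm servers
    # that Client(master_ip, port) connects to; verify against your version.
    server = dist.Server(port)
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(i, world_size, port)) for i in range(world_size)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

In production code the higher-level launcher utilities in megengine.distributed handle this server/worker bookkeeping for you.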

The MegEngine installation package already bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU or GPU build to choose between. If you want to run GPU programs, make sure the machine has GPU hardware and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
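
Before launching a GPU job you can let MegEngine report whether a usable CUDA device is visible. A minimal sketch, assuming the is_cuda_available and get_device_count helpers exposed by recent MegEngine releases:

import megengine as mge

if mge.is_cuda_available():
    # get_device_count("gpu") reports how many GPUs MegEngine can see.
    print("CUDA is available, GPUs visible:", mge.get_device_count("gpu"))
else:
    print("No usable GPU found; code will run on CPU.")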