
launcher.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
import os
import queue

from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server

WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
    "subprocess exited with code 0 but did not return a value"
)


def _run_wrapped(
    func,
    is_multimachine,
    master_ip,
    port,
    world_size,
    rank,
    dev,
    device_type,
    args,
    kwargs,
    backend,
    queue: mp.Queue,
    machine_ranks: list,
):
    """Init distributed process group and run wrapped function."""
    _check_device_initialized(device_type)
    init_process_group(
        master_ip=master_ip,
        port=port,
        world_size=world_size,
        rank=rank,
        device=dev,
        backend=backend,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    _set_machine_ranks(machine_ranks)
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
    # hand the return value back to the parent process, keyed by local device index
    queue.put((dev, ret))
    full_sync()
    if is_multimachine:
        group_barrier()
    _exit(0)


class launcher:
    """Decorator for launching multiple processes in single-machine multi-gpu training.

    :param func: the function you want to launch in distributed mode.
    :param n_gpus: how many devices each node has.
    :param world_size: how many devices there are in total.
    :param rank_start: start number for rank.
    :param master_ip: ip address of the master node (where rank 0 is).
    :param port: server port for the distributed server.
    :param backend: set the default collective communication backend.
    """

    def __new__(cls, *args, **kwargs):
        # allow usage both as @launcher and as @launcher(n_gpus=...)
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        func,
        n_gpus=None,
        world_size=None,
        rank_start=0,
        master_ip="localhost",
        port=0,
        device_type="xpu",
        backend="auto",
    ):
        self.func = func
        self.n_gpus = (
            n_gpus if n_gpus is not None else get_device_count_by_fork(device_type)
        )
        self.world_size = world_size if world_size is not None else self.n_gpus
        self.rank_start = rank_start
        self.master_ip = master_ip
        self.port = port
        self.device_type = device_type
        self.backend = backend
        # the master node creates the distributed server
        if self.rank_start == 0:
            self.server = Server(self.port)
            self.port = self.server.py_server_port
        else:
            assert self.port != 0, "you have to assign a port for distributed server"

    def __call__(self, *args, **kwargs):
        procs = []
        queue = mp.Queue(self.n_gpus)
        results = [None] * self.n_gpus
        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
        for dev in range(self.n_gpus):
            p = mp.Process(
                target=_run_wrapped,
                args=(
                    self.func,
                    self.world_size > self.n_gpus,
                    self.master_ip,
                    self.port,
                    self.world_size,
                    dev + self.rank_start,
                    dev,
                    self.device_type,
                    args,
                    kwargs,
                    self.backend,
                    queue,
                    machine_ranks,
                ),
            )
            p.start()
            procs.append(p)

        devs = list(range(self.n_gpus))

        def terminate():
            for dev in devs:
                procs[dev].terminate()
            devs.clear()

        result_count = 0
        while len(devs) > 0:
            left = []
            # check all processes within one second
            time_to_wait = 1.0 / len(devs)
            for dev in devs:
                procs[dev].join(time_to_wait)
                code = procs[dev].exitcode
                # terminate the remaining processes if one of them has failed
                if code != 0 and code is not None:
                    terminate()
                assert (
                    code == 0 or code is None
                ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
                if code is None:
                    left.append(dev)
                # DO NOT delete this: multiprocessing.Queue has a small buffer,
                # so fetch data early to avoid deadlock
                if not queue.empty():
                    result_count += 1
                    dev, ret = queue.get_nowait()
                    results[dev] = ret
            devs = left

        # drain any results that arrived after the last poll
        while not queue.empty():
            result_count += 1
            dev, ret = queue.get_nowait()
            results[dev] = ret
        if result_count < self.n_gpus:
            get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
        return results
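As a rough usage sketch of the decorator above: assuming it is exposed as megengine.distributed.launcher (as in recent MegEngine releases), and with a made-up function name and GPU count for illustration, a single-machine multi-GPU run could look like this:

# Minimal usage sketch; `main`, n_gpus=2, and the printed values are
# illustrative assumptions, not part of the file above.
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def main():
    # each spawned process gets its own rank via init_process_group
    rank = dist.get_rank()
    return rank * 10  # the return value is collected through the mp.Queue

if __name__ == "__main__":
    results = main()   # a list with one entry per spawned process
    print(results)     # e.g. [0, 10]

Calling launcher with only keyword arguments goes through __new__, which returns a functools.partial, so the decorator works both bare (@dist.launcher) and parameterized (@dist.launcher(n_gpus=2)).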

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and that the drivers are properly installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
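As a small sketch of checking whether a usable GPU is present before launching distributed training (assuming megengine.is_cuda_available() exists in the installed version, as it does in recent releases):

# Hedged availability check; prints a hint rather than failing hard.
import megengine as mge

if mge.is_cuda_available():
    print("CUDA devices detected; GPU code can run.")
else:
    print("No CUDA device found; check the GPU hardware and driver installation.")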