
launcher.py

# -*- coding: utf-8 -*-
import functools
import multiprocessing as mp
import os
import queue

from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..device import get_device_count
from ..logger import get_logger
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, _check_interpreter_status
from .server import Client, Server

WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
    "subprocess exited with code 0 but did not return a value"
)
def _run_wrapped(
    func,
    is_multimachine,
    master_ip,
    port,
    world_size,
    rank,
    dev,
    device_type,
    args,
    kwargs,
    backend,
    queue: mp.Queue,
    machine_ranks: list,
):
    r"""Initialize the distributed process group and run the wrapped function."""
    _check_interpreter_status()
    _check_device_initialized(device_type, dev)
    init_process_group(
        master_ip=master_ip,
        port=port,
        world_size=world_size,
        rank=rank,
        device=dev,
        backend=backend,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    _set_machine_ranks(machine_ranks)
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
    # hand the return value back to the parent process, keyed by local device index
    queue.put((dev, ret))
    full_sync()
    if is_multimachine:
        group_barrier()
    _exit(0)
class launcher:
    r"""Decorator for launching multiple processes in single-machine multi-gpu training.

    Args:
        func: the function you want to launch in distributed mode.
        n_gpus: number of devices on each node.
        world_size: total number of devices across all nodes.
        rank_start: starting rank for the processes on this node.
        master_ip: IP address of the master node (where rank 0 runs).
        port: server port for the distributed server.
        device_type: device type to run on (default: "xpu").
        backend: default collective communication backend.
    """

    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        return super().__new__(cls)
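
    # Illustrative note (not in the original source): __new__ lets the class be
    # used as a decorator both with and without arguments --
    #
    #     @launcher                    # called as launcher(func): args is non-empty
    #     def worker(): ...
    #
    #     @launcher(n_gpus=2)          # called with kwargs only: returns a
    #     def worker(): ...            # functools.partial that waits for func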
    def __init__(
        self,
        func,
        n_gpus=None,
        world_size=None,
        rank_start=0,
        master_ip="localhost",
        port=0,
        device_type="xpu",
        backend="nccl",
    ):
        self.func = func
        self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type)
        self.world_size = world_size if world_size is not None else self.n_gpus
        self.rank_start = rank_start
        self.master_ip = master_ip
        self.port = port
        self.device_type = device_type
        self.backend = backend
        # the master node creates the distributed server
        if self.rank_start == 0:
            self.server = Server(self.port)
            self.port = self.server.py_server_port
        else:
            assert self.port != 0, "you have to assign a port for the distributed server"
    def __call__(self, *args, **kwargs):
        procs = []
        queue = mp.Queue(self.n_gpus)
        results = [None] * self.n_gpus
        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
        # spawn one worker process per local device
        for dev in range(self.n_gpus):
            p = mp.Process(
                target=_run_wrapped,
                args=(
                    self.func,
                    self.world_size > self.n_gpus,
                    self.master_ip,
                    self.port,
                    self.world_size,
                    dev + self.rank_start,
                    dev,
                    self.device_type,
                    args,
                    kwargs,
                    self.backend,
                    queue,
                    machine_ranks,
                ),
            )
            p.start()
            procs.append(p)

        devs = list(range(self.n_gpus))

        def terminate():
            for dev in devs:
                procs[dev].terminate()
            devs.clear()

        result_count = 0
        while len(devs) > 0:
            left = []
            # poll every remaining process within roughly one second in total
            time_to_wait = 1.0 / len(devs)
            for dev in devs:
                procs[dev].join(time_to_wait)
                code = procs[dev].exitcode
                # terminate all processes if one of them has failed
                if code is not None and code != 0:
                    terminate()
                assert (
                    code is None or code == 0
                ), "subprocess {} exited with code {}".format(dev + self.rank_start, code)
                if code is None:
                    left.append(dev)
                # DO NOT delete this: multiprocessing.Queue has a small buffer, so
                # fetch results early; otherwise a child blocks in put() and the
                # join above deadlocks
                if not queue.empty():
                    result_count += 1
                    dev, ret = queue.get_nowait()
                    results[dev] = ret
            devs = left

        # drain any results that arrived after the last poll
        while not queue.empty():
            result_count += 1
            dev, ret = queue.get_nowait()
            results[dev] = ret

        if result_count < self.n_gpus:
            get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)

        return results
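
For reference, a minimal usage sketch (not part of launcher.py): it assumes the module is importable as megengine.distributed and that two devices are visible; the worker body and its return value are illustrative.

import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def worker():
    # each process runs with its own rank after init_process_group has completed
    return dist.get_rank()

if __name__ == "__main__":
    results = worker()  # list of per-device return values, indexed by local device
    print(results)      # e.g. [0, 1]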