You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

launcher.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import multiprocessing as mp
  11. import queue
  12. from ..core._imperative_rt.core2 import sync
  13. from ..logger import get_logger
  14. from .group import group_barrier, init_process_group
  15. from .helper import get_device_count_by_fork
  16. from .server import Client, Server
# Warning logged when a worker subprocess exits successfully (exit code 0)
# but never put a (device, result) pair on the result queue.
WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
    "subprocess exited with code 0 but did not return a value"
)
  20. def _run_wrapped(
  21. func,
  22. is_multimachine,
  23. master_ip,
  24. port,
  25. world_size,
  26. rank,
  27. dev,
  28. args,
  29. kwargs,
  30. queue: mp.Queue,
  31. ):
  32. """Init distributed process group and run wrapped function."""
  33. init_process_group(
  34. master_ip=master_ip, port=port, world_size=world_size, rank=rank, device=dev
  35. )
  36. if is_multimachine:
  37. group_barrier()
  38. ret = func(*args, **kwargs)
  39. queue.put((dev, ret))
  40. sync()
  41. if is_multimachine:
  42. group_barrier()
  43. class launcher:
  44. """Decorator for launching multiple processes in single-machine multi-gpu training.
  45. :param func: the function you want to launch in distributed mode.
  46. :param n_gpus: how many devices each node.
  47. :param world_size: how many devices totally.
  48. :param rank_start: start number for rank.
  49. :param master_ip: ip address for master node (where the rank 0 is).
  50. :param port: server port for distributed server.
  51. """
  52. def __new__(cls, *args, **kwargs):
  53. if not args:
  54. return functools.partial(cls, **kwargs)
  55. return super().__new__(cls)
  56. def __init__(
  57. self,
  58. func,
  59. n_gpus=None,
  60. world_size=None,
  61. rank_start=0,
  62. master_ip="localhost",
  63. port=0,
  64. ):
  65. self.func = func
  66. self.n_gpus = n_gpus if n_gpus is not None else get_device_count_by_fork("gpu")
  67. self.world_size = world_size if world_size is not None else self.n_gpus
  68. self.rank_start = rank_start
  69. self.master_ip = master_ip
  70. self.port = port
  71. # master node create server
  72. if self.rank_start == 0:
  73. self.server = Server(self.port)
  74. self.port = self.server.py_server_port
  75. else:
  76. assert self.port != 0, "you have to assign a port for distributed server"
  77. def __call__(self, *args, **kwargs):
  78. procs = []
  79. queue = mp.Queue(self.n_gpus)
  80. results = [None] * self.n_gpus
  81. for dev in range(self.n_gpus):
  82. p = mp.Process(
  83. target=_run_wrapped,
  84. args=(
  85. self.func,
  86. self.world_size > self.n_gpus,
  87. self.master_ip,
  88. self.port,
  89. self.world_size,
  90. dev + self.rank_start,
  91. dev,
  92. args,
  93. kwargs,
  94. queue,
  95. ),
  96. )
  97. p.start()
  98. procs.append(p)
  99. devs = list(range(self.n_gpus))
  100. def terminate():
  101. for dev in devs:
  102. procs[dev].terminate()
  103. devs.clear()
  104. while len(devs) > 0:
  105. left = []
  106. # check all processes in one second
  107. time_to_wait = 1.0 / len(devs)
  108. for dev in devs:
  109. procs[dev].join(time_to_wait)
  110. code = procs[dev].exitcode
  111. # terminate processes if one of them has failed
  112. if code != 0 and code != None:
  113. terminate()
  114. assert (
  115. code == 0 or code == None
  116. ), "subprocess {} exit with code {}".format(dev + self.rank_start, code)
  117. if code == None:
  118. left.append(dev)
  119. elif queue.empty():
  120. get_logger().warning(WARN_SUBPROCESS_EXIT_WITHOUT_RETURN)
  121. else:
  122. dev, ret = queue.get_nowait()
  123. results[dev] = ret
  124. devs = left
  125. return results

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台