
functional.py 9.0 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    if group is None:
        return inp
    assert isinstance(group, Group)
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)
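

# --- Illustrative usage sketch (not part of the original module) ----------------
# Minimal example of how reduce_sum might be called, assuming the worker
# processes have already been launched and the default WORLD group initialized
# (e.g. via MegEngine's distributed launcher utilities). Only the root process
# (rank 0) is expected to receive the summed result; the value seen on other
# ranks should not be relied on.
def _example_reduce_sum():
    # each rank contributes a tensor holding its own rank id
    inp = Tensor([get_rank()])
    out = reduce_sum(inp)  # sum of all ranks' tensors, delivered to rank 0
    if get_rank() == 0:
        print(out.numpy())  # expected: 0 + 1 + ... + (world_size - 1)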


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)
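

# --- Illustrative usage sketch (not part of the original module) ----------------
# Assumes every rank passes a tensor of identical shape; the gathered output is
# expected to be the concatenation of the per-rank tensors (conventionally along
# the first dimension), with every rank receiving the full result.
def _example_all_gather():
    inp = Tensor([get_rank()])  # shape (1,) on every rank
    out = all_gather(inp)       # expected shape: (world_size,) on every rank
    print(get_rank(), out.numpy())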


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
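

# --- Illustrative usage sketch (not part of the original module) ----------------
# Unlike reduce_sum, all_reduce_sum is expected to deliver the same summed tensor
# to every rank, which is the usual primitive for averaging gradients in
# data-parallel training. `world_size` is a hypothetical parameter standing in
# for the actual number of workers.
def _example_gradient_average(grad: Tensor, world_size: int) -> Tensor:
    total = all_reduce_sum(grad)  # identical summed value on every rank
    return total / world_size     # average across workers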


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.SCATTER
    return collective_comm(inp, mode, group, device)
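

# --- Illustrative usage sketch (not part of the original module) ----------------
# gather is expected to collect the per-rank tensors onto the root process, and
# scatter to hand out slices of the root's tensor, one per rank. The split /
# concatenation axis (assumed to be the first dimension) and the treatment of
# non-root inputs are assumptions here, not documented above; `world_size` is a
# hypothetical parameter.
def _example_gather():
    piece = Tensor([get_rank()])  # shape (1,) on every rank
    return gather(piece)          # root rank expects shape (world_size,)


def _example_scatter(world_size: int = 2):
    full = Tensor(list(range(world_size)))  # the root rank's tensor is the one that counts
    return scatter(full)                    # each rank expects one slice of the root's tensor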


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)


class _RemoteSend(PyOpBase):
    def __init__(self, op: RemoteSend):
        self.op = op

    def _default_rule(self, data):
        return apply(self.op, data)

    def _grad_rule(self, data):
        self.dtype = data.dtype
        self.shape = data.shape
        self.device = data.device
        (self.dummy,) = self._default_rule(data)
        return self.dummy, self.backward

    def backward(self, grad):
        assert grad is None
        if get_client().check_is_grad(self.op.key):
            return remote_recv(
                self.op.rank_to,
                self.shape,
                self.dtype,
                device=str(self.device),
                inp=self.dummy,
            )


class _RemoteRecv(PyOpBase):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def _default_rule(self, dummy):
        return apply(self.op, dummy)

    def _grad_rule(self, dummy):
        return self._default_rule(dummy), self.backward

    def backward(self, grad):
        get_client().set_is_grad(self.op.key, grad is not None)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """
    Send a Tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    key = "{}->{}".format(get_rank(), dest_rank)
    grad_keys = {}
    for n, g in _grad_manager_dict.items():
        if g._is_attached_to(inp):
            grad_keys[n] = g
    get_client().set_remote_tracer(key, grad_keys)
    op = RemoteSend()
    op.key = key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    (dummy,) = apply(_RemoteSend(op), inp)
    for g in grad_keys.values():
        g._refkeeper.append(dummy)


def remote_recv(
    src_rank: int,
    shape: Tuple[int],
    dtype: type,
    device: Optional[str] = None,
    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor.
    :param inp: dummy input used to determine the type of the received tensor.
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor([0], device=device)
    tracer_set = get_client().check_remote_tracer(key)
    for n in tracer_set:
        g = _grad_manager_dict.get(n)
        if g is not None:
            g.wrt(inp)
            g._refkeeper.append(inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = get_backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
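

# --- Illustrative usage sketch (not part of the original module) ----------------
# remote_send / remote_recv form a point-to-point pair: the receiver must know
# the shape and dtype of the incoming tensor up front. This sketch assumes
# exactly two processes (ranks 0 and 1) launched with an initialized distributed
# context; the function name is hypothetical.
def _example_point_to_point():
    import numpy as np

    if get_rank() == 0:
        payload = Tensor(np.arange(3, dtype="float32"))
        remote_send(payload, dest_rank=1)  # ship the tensor to rank 1
    else:
        received = remote_recv(src_rank=0, shape=(3,), dtype=np.float32)
        return received  # expected value: [0., 1., 2.]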
