
functional.py 8.7 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device
from ..tensor import tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

# dispatcher overload: applying RemoteSend also reports the names of any
# gradient tracers attached to the inputs, so the receiving side can hook
# them up in remote_recv
@apply.register()
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)

    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


# the gradient of a send is received back from the destination rank
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to,
                inputs[0].shape,
                inputs[0].dtype,
                device=str(inputs[0].device),
                inp=inputs[0],
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        # ask the peer process whether it will send a gradient back for this key
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


# the gradient of a receive is sent back to the source rank
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        # tell the sending process whether a gradient will flow back for this key
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad

def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    # a None group means no communication is needed; check this before the
    # type assertion so the early return is actually reachable
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]

def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.SCATTER
    return collective_comm(inp, mode, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveCommMode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)

def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """
    Send a Tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]

def remote_recv(
    src_rank: int,
    shape: Tuple[int],
    dtype: type,
    device: Optional[str] = None,
    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor.
    :param inp: dummy input to determine the received tensor type.
    """
    key = "{}->{}".format(src_rank, get_rank())

    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank

    return apply(op, inp)[0]
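
Below is a minimal usage sketch of the collective and point-to-point helpers defined above. It assumes two worker processes whose process group has already been set up (for example with megengine.distributed.init_process_group; the exact setup call and its signature depend on the MegEngine version), and it assumes "float32" is accepted as a dtype argument, so take it as an illustration of the API shape rather than a ready-to-run script.

import megengine.distributed as dist
from megengine import tensor

def worker(rank: int):
    # each rank contributes its own value; after all_reduce_sum every rank
    # holds the element-wise sum over the WORLD group (the default)
    x = tensor([float(rank + 1)])
    summed = dist.functional.all_reduce_sum(x)
    print("rank", rank, "sum =", summed.numpy())

    # point-to-point transfer: rank 0 sends x, rank 1 receives it; the shape
    # and dtype passed to remote_recv must match the tensor being sent
    if rank == 0:
        dist.functional.remote_send(x, 1)
    else:
        y = dist.functional.remote_recv(0, shape=(1,), dtype="float32")
        print("rank", rank, "received", y.numpy())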

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
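
As a small hedged check before selecting a device, the following sketch picks a GPU only when CUDA is usable; it assumes megengine.is_cuda_available() and set_default_device() are exposed at the top level of your installed MegEngine version.

import megengine as mge

# pick the first GPU when CUDA is usable, otherwise fall back to CPU
# (is_cuda_available / set_default_device are assumed to exist in this version)
if mge.is_cuda_available():
    mge.set_default_device("gpu0")
else:
    mge.set_default_device("cpu0")

print("default device:", mge.get_default_device())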