
functional.py 8.7 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..device import get_default_device
from ..tensor import tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

# RemoteSend needs extra bookkeeping: record which gradient tracers are
# attached to the inputs so that the matching remote_recv can pick them up.
@apply.register()
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)

    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


# the gradient of a remote send is a remote receive from the destination rank
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to,
                inputs[0].shape,
                inputs[0].dtype,
                device=str(inputs[0].device),
                inp=inputs[0],
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


# the gradient of a remote receive is a remote send back to the source rank
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    if group is None:
        return inp
    assert isinstance(group, Group)
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        comp_node=device,
    )
    return apply(op, inp)[0]

def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.SCATTER
    return collective_comm(inp, mode, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)

def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """
    Send a Tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]


def remote_recv(
    src_rank: int,
    shape: Tuple[int],
    dtype: type,
    device: Optional[str] = None,
    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor on.
    :param inp: dummy input used to determine the type of the received tensor.
    """
    key = "{}->{}".format(src_rank, get_rank())

    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = tensor([0], device=device)
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank

    return apply(op, inp)[0]
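
For context, here is a minimal, hypothetical sketch of how these primitives might be driven from a worker process. It assumes the process has already joined the default communication group (the group-setup API in megengine.distributed varies between versions and is not shown), and the module alias dist_F, the two-rank layout, and the "float32" dtype string are illustrative assumptions rather than part of this file.

import megengine as mge
import megengine.distributed.functional as dist_F  # assumed import path for this module

def worker(rank):
    # assumes the distributed group (WORLD) was already initialised for this process
    x = mge.tensor([1.0, 2.0, 3.0])

    # every rank contributes x; every rank receives the element-wise sum
    summed = dist_F.all_reduce_sum(x)

    # point-to-point transfer from rank 0 to rank 1
    if rank == 0:
        dist_F.remote_send(summed, dest_rank=1)
    elif rank == 1:
        received = dist_F.remote_recv(src_rank=0, shape=(3,), dtype="float32")
        print(received.numpy())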

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
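
A quick way to verify that the bundled CUDA runtime can actually see a GPU on your machine (assuming megengine.is_cuda_available() exists in your MegEngine version) is shown below.

import megengine as mge

# True when MegEngine's bundled CUDA runtime finds a usable GPU;
# otherwise computation runs on the CPU.
print(mge.is_cuda_available())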