You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional, Tuple
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core.autodiff.grad import Function, _grad_manager_dict
  13. from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
  14. from ..core.tensor.utils import isscalar, setscalar
  15. from ..device import get_default_device, what_is_xpu
  16. from ..tensor import Tensor
  17. from . import group
  18. from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
  19. __all__ = [
  20. "reduce_sum",
  21. "broadcast",
  22. "all_gather",
  23. "reduce_scatter_sum",
  24. "all_reduce_sum",
  25. "all_reduce_max",
  26. "all_reduce_min",
  27. "gather",
  28. "scatter",
  29. "all_to_all",
  30. "remote_send",
  31. "remote_recv",
  32. ]
  33. _device2backend = {
  34. "gpu": "nccl",
  35. "cuda": "nccl",
  36. "rocm": "rccl",
  37. }
  38. def _backend():
  39. if group._sd.backend == "auto":
  40. return _device2backend[what_is_xpu()]
  41. else:
  42. return group._sd.backend
  43. def collective_comm(inp, mode, group, device):
  44. """Helper function for applying collective communication functions."""
  45. assert isinstance(group, Group)
  46. if group is None:
  47. return inp
  48. if device is None:
  49. device = ""
  50. addr, port = get_mm_server_addr()
  51. op = CollectiveComm(
  52. key=group.key + _backend(),
  53. nr_devices=group.size,
  54. rank=group.rank,
  55. is_root=(group.rank == 0),
  56. local_grad=False,
  57. addr=addr,
  58. port=port,
  59. mode=mode,
  60. dtype=inp.dtype,
  61. backend=_backend(),
  62. comp_node=device,
  63. )
  64. (result,) = apply(op, inp)
  65. # assume all workers have homogeneous shape
  66. if mode in (
  67. CollectiveComm.Mode.REDUCE_SUM,
  68. CollectiveComm.Mode.BROADCAST,
  69. CollectiveComm.Mode.ALL_REDUCE_SUM,
  70. CollectiveComm.Mode.ALL_REDUCE_MAX,
  71. CollectiveComm.Mode.ALL_REDUCE_MIN,
  72. ):
  73. if isscalar(inp):
  74. setscalar(result)
  75. return result
  76. def _save_output_for_autodiff(inp, out):
  77. for g in _grad_manager_dict.values():
  78. if g._is_attached_to(inp):
  79. g._refkeeper.append(out)
  80. def _bcast_has_grad(group, grad):
  81. if group.rank == 0:
  82. has_grad = grad is not None
  83. get_client().bcast_val(has_grad, group.key, group.size)
  84. else:
  85. has_grad = get_client().bcast_val(None, group.key, group.size)
  86. return has_grad
  87. def _bcast_shape_dtype(group, inp):
  88. if group.rank == 0:
  89. # FIXME in some cases, shape is not available(output of condtake)
  90. shape = inp._tuple_shape
  91. dtype = np.dtype(inp.dtype).name
  92. get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
  93. else:
  94. val = get_client().bcast_val(None, group.key, group.size)
  95. shape = val["shape"]
  96. dtype = val["dtype"]
  97. return shape, dtype
  98. def _bcast_tracer_state(group, inp):
  99. if group.rank == 0:
  100. tracer_keys = []
  101. for n, g in _grad_manager_dict.items():
  102. if g._is_attached_to(inp):
  103. tracer_keys.append(n)
  104. get_client().bcast_val(tracer_keys, group.key, group.size)
  105. else:
  106. tracer_keys = get_client().bcast_val(None, group.key, group.size)
  107. for n in tracer_keys:
  108. g = _grad_manager_dict.get(n)
  109. if g is not None:
  110. g.wrt(inp)
  111. g._refkeeper.append(inp)
  112. def _dummy_input(shape, dtype, device=None):
  113. if device is None:
  114. device = get_default_device()
  115. inp = Tensor(0, dtype=dtype, device=device)
  116. if len(shape) > 0:
  117. inp = inp._broadcast(shape)
  118. return inp
  119. class _ReduceSum(Function):
  120. def __init__(self, group=WORLD, device=None):
  121. self.group = group
  122. self.out_device = device
  123. def forward(self, data):
  124. self.in_device = str(data.device)
  125. return collective_comm(
  126. data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
  127. )
  128. def backward(self, grad):
  129. has_grad = _bcast_has_grad(self.group, grad)
  130. if has_grad:
  131. return broadcast(grad, self.group, self.in_device)
  132. def reduce_sum(
  133. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  134. ) -> Tensor:
  135. """
  136. Create reduce_sum operator for collective communication.
  137. :param inp: input tensor.
  138. :param group: communication group.
  139. :param device: execution device.
  140. """
  141. op = _ReduceSum(group, device)
  142. (out,) = apply(op, inp)
  143. if group.rank == 0:
  144. return out
  145. else:
  146. _save_output_for_autodiff(inp, out)
  147. class _Broadcast(Function):
  148. def __init__(self, group=WORLD, device=None):
  149. self.group = group
  150. self.out_device = device
  151. def forward(self, data):
  152. self.in_device = str(data.device)
  153. return collective_comm(
  154. data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
  155. )
  156. def backward(self, grad):
  157. # TODO backward with a part of grad
  158. if grad is not None:
  159. return reduce_sum(grad, self.group, self.in_device)
  160. def broadcast(
  161. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  162. ) -> Tensor:
  163. """
  164. Create broadcast operator for collective communication.
  165. :param inp: input tensor.
  166. :param group: communication group.
  167. :param device: execution device.
  168. """
  169. shape, dtype = _bcast_shape_dtype(group, inp)
  170. if group.rank != 0:
  171. # dummy input to infer shape
  172. inp = _dummy_input(shape, dtype, device)
  173. _bcast_tracer_state(group, inp)
  174. op = _Broadcast(group, device)
  175. (out,) = apply(op, inp)
  176. return out
  177. def _bcast_param(
  178. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
  179. ) -> Tensor:
  180. mode = CollectiveComm.Mode.BROADCAST
  181. return collective_comm(inp, mode, group, device)
  182. def all_gather(
  183. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  184. ) -> Tensor:
  185. """
  186. Create all_gather operator for collective communication.
  187. :param inp: input tensor.
  188. :param group: communication group.
  189. :param device: execution device.
  190. """
  191. mode = CollectiveComm.Mode.ALL_GATHER
  192. return collective_comm(inp, mode, group, device)
  193. def reduce_scatter_sum(
  194. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  195. ) -> Tensor:
  196. """
  197. Create reduce_scatter_sum operator for collective communication.
  198. :param inp: input tensor.
  199. :param group: communication group.
  200. :param device: execution device.
  201. """
  202. mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
  203. return collective_comm(inp, mode, group, device)
  204. def all_reduce_sum(
  205. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  206. ) -> Tensor:
  207. """
  208. Create all_reduce_sum operator for collective communication.
  209. :param inp: input tensor.
  210. :param group: communication group.
  211. :param device: execution device.
  212. """
  213. mode = CollectiveComm.Mode.ALL_REDUCE_SUM
  214. return collective_comm(inp, mode, group, device)
  215. def all_reduce_max(
  216. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  217. ) -> Tensor:
  218. """
  219. Create all_reduce_max operator for collective communication.
  220. :param inp: input tensor.
  221. :param group: communication group.
  222. :param device: execution device.
  223. """
  224. mode = CollectiveComm.Mode.ALL_REDUCE_MAX
  225. return collective_comm(inp, mode, group, device)
  226. def all_reduce_min(
  227. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  228. ) -> Tensor:
  229. """
  230. Create all_reduce_min operator for collective communication.
  231. :param inp: input tensor.
  232. :param group: communication group.
  233. :param device: execution device.
  234. """
  235. mode = CollectiveComm.Mode.ALL_REDUCE_MIN
  236. return collective_comm(inp, mode, group, device)
  237. class _Gather(Function):
  238. def __init__(self, group=WORLD, device=None):
  239. self.group = group
  240. self.out_device = device
  241. def forward(self, data):
  242. self.in_device = str(data.device)
  243. return collective_comm(
  244. data, CollectiveComm.Mode.GATHER, self.group, self.out_device
  245. )
  246. def backward(self, grad):
  247. has_grad = _bcast_has_grad(self.group, grad)
  248. if has_grad:
  249. return scatter(grad, self.group, self.in_device)
  250. def gather(
  251. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  252. ) -> Tensor:
  253. """
  254. Create gather operator for collective communication.
  255. :param inp: input tensor.
  256. :param group: communication group.
  257. :param device: execution device.
  258. """
  259. op = _Gather(group, device)
  260. (out,) = apply(op, inp)
  261. if group.rank == 0:
  262. return out
  263. else:
  264. _save_output_for_autodiff(inp, out)
  265. class _Scatter(Function):
  266. def __init__(self, group=WORLD, device=None):
  267. self.group = group
  268. self.out_device = device
  269. def forward(self, data):
  270. self.in_device = str(data.device)
  271. return collective_comm(
  272. data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
  273. )
  274. def backward(self, grad):
  275. # TODO backward with a part of grad
  276. if grad is not None:
  277. return gather(grad, self.group, self.in_device)
  278. def scatter(
  279. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  280. ) -> Tensor:
  281. """
  282. Create scatter operator for collective communication.
  283. :param inp: input tensor.
  284. :param group: communication group.
  285. :param device: execution device.
  286. """
  287. shape, dtype = _bcast_shape_dtype(group, inp)
  288. if group.rank != 0:
  289. # dummy input to infer shape
  290. inp = _dummy_input(shape, dtype, device)
  291. _bcast_tracer_state(group, inp)
  292. op = _Scatter(group, device)
  293. (out,) = apply(op, inp)
  294. return out
  295. def all_to_all(
  296. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  297. ) -> Tensor:
  298. """
  299. Create all_to_all operator for collective communication.
  300. :param inp: input tensor.
  301. :param group: communication group.
  302. :param device: execution device.
  303. """
  304. mode = CollectiveComm.Mode.ALL_TO_ALL
  305. return collective_comm(inp, mode, group, device)
  306. class _SendRecvGroup:
  307. def __init__(self, rank_from, rank_to):
  308. self.key = "{}->{}".format(rank_from, rank_to)
  309. self.rank_from = rank_from
  310. self.rank_to = rank_to
  311. self.size = 2
  312. @property
  313. def rank(self):
  314. if get_rank() == self.rank_from:
  315. return 0
  316. else:
  317. return 1
  318. class _RemoteSend(Function):
  319. def __init__(self, op: RemoteSend):
  320. self.op = op
  321. def forward(self, data):
  322. self.device = str(data.device)
  323. (self.dummy,) = apply(self.op, data)
  324. return self.dummy
  325. def backward(self, grad):
  326. assert grad is None
  327. has_grad = get_client().bcast_val(None, self.op.key, 2)
  328. if has_grad:
  329. return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)
  330. class _RemoteRecv(Function):
  331. def __init__(self, op: RemoteRecv):
  332. self.op = op
  333. def forward(self, dummy):
  334. return apply(self.op, dummy)
  335. def backward(self, grad):
  336. get_client().bcast_val(grad is not None, self.op.key, 2)
  337. if grad is not None:
  338. remote_send(grad, self.op.rank_from)
  339. def remote_send(inp: Tensor, dest_rank: int):
  340. """
  341. Send a Tensor to a remote process.
  342. :param inp: tensor to send.
  343. :param dest_rank: destination process rank.
  344. """
  345. group = _SendRecvGroup(get_rank(), dest_rank)
  346. _bcast_shape_dtype(group, inp)
  347. _bcast_tracer_state(group, inp)
  348. op = RemoteSend()
  349. op.key = group.key
  350. op.addr, op.port = get_mm_server_addr()
  351. op.rank_to = dest_rank
  352. op.backend = _backend()
  353. (out,) = apply(_RemoteSend(op), inp)
  354. _save_output_for_autodiff(inp, out)
  355. def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
  356. """
  357. Receive a Tensor from a remote process.
  358. :param src_rank: source process rank.
  359. :param device: the device to place the received tensor.
  360. :param inp: dummy input to determine recved tensor type
  361. """
  362. group = _SendRecvGroup(src_rank, get_rank())
  363. shape, dtype = _bcast_shape_dtype(group, None)
  364. if device is None:
  365. device = get_default_device()
  366. # dummy input
  367. if inp is None:
  368. inp = Tensor(0, device=device)
  369. _bcast_tracer_state(group, inp)
  370. _isscalar = False
  371. if len(shape) == 0:
  372. shape = (1,)
  373. _isscalar = True
  374. op = RemoteRecv()
  375. op.key = group.key
  376. op.cn = device
  377. op.shape = shape
  378. op.dtype = dtype
  379. op.addr, op.port = get_mm_server_addr()
  380. op.rank_from = src_rank
  381. op.backend = _backend()
  382. (ret,) = apply(_RemoteRecv(op), inp)
  383. if _isscalar:
  384. setscalar(ret)
  385. return ret

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。