
functional.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend

def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
    if group is None:
        return inp
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result

def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)

def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp

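# Each collective below is wrapped in a ``Function`` subclass so that autodiff can
# route gradients through the matching collective in ``backward``: reduce_sum's
# gradient is broadcast back from the root, broadcast's gradient is reduce-summed,
# and gather/scatter act as each other's gradients.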
class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)

class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)

def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)

def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Create all_reduce_sum operator for collective communication.

    This operator sums the tensor data element-wise across the specified group and
    returns a tensor with the shape of the input tensor.

    Args:
        inp: the tensor data to apply this operator on.
        group: the communication node list instance of :class:`Group` to apply this
            operator across. The default group is WORLD, which means all available
            processes. Specify a list of process ranks to apply this operator on
            specific processes, e.g. [1, 3, 5].
        device: the specific device type of :class:`str` to execute this operator.
            The default device is None, which means the device of inp will be used.
            Specify "cpu" or "gpu" to execute this operator on specific devices.

    Returns:
        opt: the reduce sum tensor of the input tensor data across the specified group.

    Examples:

    .. code-block::

        import megengine as mge
        import megengine.distributed as dist
        import numpy as np
        from warnings import warn

        def func(sum_value):
            # get the rank of this process; the ranks should be 0, 1, 2, 3 for a 4-gpu task
            rank = dist.get_rank()
            data = mge.tensor(rank)
            # the result should be n * (n - 1) / 2 for all processes
            result = mge.functional.distributed.all_reduce_sum(data).item()
            assert result == sum_value

        def main():
            p_num = mge.device.get_device_count("gpu")
            if p_num < 2:
                warn('This opr only works on group with more than one gpu')
                return
            method = dist.launcher(func)
            method(p_num * (p_num - 1) // 2)

        if __name__ == '__main__':
            main()
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)

def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :returns: reduced tensor.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)

class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)

class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out

def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)

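# The remaining helpers implement point-to-point send/recv. A pair of ranks is
# modelled as a two-member pseudo group keyed "{rank_from}->{rank_to}", which is
# reused for the shape/dtype and gradient-availability handshakes defined above.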
class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1

class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)

def remote_send(inp: Tensor, dest_rank: int):
    """
    Send a Tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)

def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param device: the device to place the received tensor.
    :param inp: dummy input to determine the received tensor type.
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
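For reference, here is a minimal usage sketch of the public functions exported above. It mirrors the ``all_reduce_sum`` docstring: the import path ``megengine.functional.distributed`` and the two-GPU requirement are taken from that docstring, while the ``remote_send``/``remote_recv`` pairing is an illustrative assumption built only on the signatures defined in this file.

import megengine as mge
import megengine.distributed as dist
import megengine.functional.distributed as dist_func  # path follows the docstring above


def worker():
    rank = dist.get_rank()
    # every rank contributes its own rank id; all ranks receive the same sum
    total = dist_func.all_reduce_sum(mge.tensor(rank)).item()
    assert total == sum(range(dist.get_world_size()))
    # point-to-point transfer: rank 0 sends a tensor, rank 1 receives it
    if rank == 0:
        dist_func.remote_send(mge.tensor([1.0, 2.0, 3.0]), dest_rank=1)
    elif rank == 1:
        received = dist_func.remote_recv(src_rank=0)
        print(received.numpy())


def main():
    if mge.device.get_device_count("gpu") < 2:
        return  # this sketch needs at least two GPUs
    dist.launcher(worker)()


if __name__ == "__main__":
    main()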

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.