
functional.py 30 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend
def collective_comm(inp, mode, group, device):
    r"""Helper function for applying collective communication functions."""
    if group is None:
        return inp
    assert isinstance(group, Group)
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    return result
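# NOTE: collective_comm is the shared low-level entry point; the public wrappers
# below differ mainly in the CollectiveComm.Mode they pass. A minimal sketch
# (assuming a process group has already been initialized, e.g. through
# megengine.distributed.launcher; `x` is an illustrative tensor):
#
#     # equivalent to all_reduce_sum(x) with the default group and device
#     y = collective_comm(x, CollectiveComm.Mode.ALL_REDUCE_SUM, WORLD, None)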
def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp
class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensor data across the specified group by sum.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Reduced tensor if in root process, None in other processes.

    Examples:
        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = reduce_sum(input)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = reduce_sum(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1])
    """
    out = _ReduceSum(group, device)(inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Broadcast tensor data from the root process to the others.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:
        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = broadcast(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([0])

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = broadcast(input, group)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: Tensor([1])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    out = _Broadcast(group, device)(inp)
    return out
def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group and concatenate them along the first
    dimension (or along ``axis`` if specified).

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:
        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = all_gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: Tensor([0 1])

            input = Tensor([rank])
            group = Group([1, 0])
            output = all_gather(input, group)
            # Rank 0 # output: Tensor([1 0])
            # Rank 1 # output: Tensor([1 0])
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    out = collective_comm(inp, mode, group, device)
    if axis == 0:
        return out
    else:
        group_size = group.size if group is not None else 1
        transformed_shape = list(inp._tuple_shape)
        transformed_shape[axis] *= group_size
        n, *shp = out._tuple_shape
        index = (
            [_ for _ in range(1, axis)]
            + [axis, 0]
            + [_ for _ in range(axis + 1, out.ndim + 1)]
        )
        return (
            out.reshape(group_size, n // group_size, *shp)
            .transpose(index)
            .reshape(transformed_shape)
        )
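# Worked example of the axis != 0 bookkeeping above (hypothetical shapes): with
# 2 ranks and a per-rank inp of shape (2, 3), ALL_GATHER first yields (4, 3);
# reshape(2, 2, 3) -> transpose([1, 0, 2]) -> reshape(2, 6) moves the gathered
# dimension onto axis 1, so all_gather(inp, axis=1) returns a (2, 6) tensor:
#
#     x = Tensor(np.arange(6).reshape(2, 3))   # same shape on every rank
#     y = all_gather(x, axis=1)                # y.shape == (2, 6) with 2 ranks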
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor:
    r"""Reduce tensors across the specified group by sum and split them along the first
    dimension (or along ``axis`` if specified).

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0, i.e. the data will be split along axis 0.

    Returns:
        Split tensor.

    Examples:
        .. code-block::

            input = Tensor([0 1])
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([0 1])
            output = reduce_scatter_sum(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([2])

            input = Tensor([0 1])
            group = Group([1, 0])
            output = reduce_scatter_sum(input, group)
            # Rank 0 # output: Tensor([2])
            # Rank 1 # output: Tensor([0])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
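# Worked example of the axis != 0 rearrangement above (hypothetical shapes): with
# 2 ranks and a per-rank inp of shape (2, 4), axis=1 first applies
# reshape(2, 2, 2) -> transpose([1, 0, 2]) -> reshape(4, 2), so the split along
# dim 0 performed by REDUCE_SCATTER_SUM hands rank r a (2, 2) block holding the
# summed columns [2*r, 2*r + 2) of the original input.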
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with sum operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            A ``GPU`` device can execute on a different stream by appending a number
            after a colon following the device name; ``:0`` denotes the default stream.

    Returns:
        A tensor with sum operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_sum on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, group=[0, 1]) # doctest: +SKIP
        Tensor([4. 6.], device=xpux:0) # Rank 0
        Tensor([4. 6.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_sum on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, device="gpu0:1") # doctest: +SKIP
        Tensor([4. 6.], device=gpu0:0) # Rank 0
        Tensor([4. 6.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
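# End-to-end sketch of the docstring example above, assuming two GPUs and
# MegEngine's distributed launcher (the decorator signature and n_gpus argument
# are assumptions of this sketch, not part of this module):
#
#     import megengine.distributed as dist
#     import megengine.functional as F
#
#     @dist.launcher(n_gpus=2)
#     def worker():
#         rank = dist.get_rank()
#         x = F.arange(2) + 1 + 2 * rank
#         y = F.distributed.all_reduce_sum(x)
#         print(rank, y.numpy())   # both ranks print [4. 6.]
#
#     worker()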
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with max operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            A ``GPU`` device can execute on a different stream by appending a number
            after a colon following the device name; ``:0`` denotes the default stream.

    Returns:
        A tensor with max operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_max on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, group=[0, 1]) # doctest: +SKIP
        Tensor([3. 4.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_max on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, device="gpu0:1") # doctest: +SKIP
        Tensor([3. 4.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with min operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            A ``GPU`` device can execute on a different stream by appending a number
            after a colon following the device name; ``:0`` denotes the default stream.

    Returns:
        A tensor with min operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_min on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, group=[0, 1]) # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([1. 2.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_min on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, device="gpu0:1") # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([1. 2.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)
def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis for the collective_comm result.

    Returns:
        Gathered tensor if in root process, None in other processes.

    Examples:
        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = gather(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1 0])
    """
    assert (
        axis < inp.ndim
    ), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
        inp.shape
    )
    out = _Gather(group, device)(inp)
    if group.rank == 0:
        if axis == 0:
            return out
        else:
            group_size = group.size
            transformed_shape = list(inp._tuple_shape)
            transformed_shape[axis] *= group_size
            n, *shp = out._tuple_shape
            index = (
                [_ for _ in range(1, axis)]
                + [axis, 0]
                + [_ for _ in range(axis + 1, out.ndim + 1)]
            )
            return (
                out.reshape(group_size, n // group_size, *shp)
                .transpose(index)
                .reshape(transformed_shape)
            )
    else:
        _save_output_for_autodiff(inp, out)
class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)
def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Split the tensor in the root process along the first dimension
    (or along ``axis`` if specified) and scatter the parts to all processes.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Split tensor.

    Examples:
        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = scatter(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([1])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0]) # first rank is root
            output = scatter(input, group)
            # Rank 0 # output: Tensor([3])
            # Rank 1 # output: Tensor([2])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    assert (
        list(inp._tuple_shape)[axis] % group.size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        group_size = group.size
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    out = _Scatter(group, device)(inp)
    return out
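# Worked example of the axis != 0 rearrangement above (hypothetical shapes): with
# 2 ranks and a root inp of shape (2, 4), axis=1 rearranges the input via
# reshape(2, 2, 2) -> transpose([1, 0, 2]) -> reshape(4, 2) before SCATTER, so
# rank r receives a (2, 2) tensor holding columns [2*r, 2*r + 2) of the original.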
def all_to_all(
    inp: Tensor,
    group: Optional[Group] = WORLD,
    device: Optional[str] = None,
    split_axis: int = 0,
    concat_axis: int = 0,
) -> Tensor:
    r"""Each process scatters its input tensor to all processes and returns the gathered tensor.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        split_axis: The axis along which collective_comm will split the data.
            The default axis is 0.
        concat_axis: The axis along which the received pieces are concatenated.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:
        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = all_to_all(input)
            # Rank 0 # output: Tensor([0 2])
            # Rank 1 # output: Tensor([1 3])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0])
            output = all_to_all(input, group)
            # Rank 0 # output: Tensor([0 3])
            # Rank 1 # output: Tensor([2 1])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[split_axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(split_axis)
    origin_shape = inp._tuple_shape
    if split_axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[split_axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[split_axis] //= group_size
        new_shape.insert(split_axis, group_size)
        index = (
            [split_axis]
            + [_ for _ in range(0, split_axis)]
            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.ALL_TO_ALL
    out = collective_comm(inp, mode, group, device)
    if concat_axis == 0:
        return out
    transformed_shape = list(origin_shape)
    transformed_shape[concat_axis] *= group_size
    transformed_shape[split_axis] //= group_size
    n, *shp = out._tuple_shape
    index = (
        [_ for _ in range(1, concat_axis)]
        + [concat_axis, 0]
        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
    )
    return (
        out.reshape(group_size, n // group_size, *shp)
        .transpose(index)
        .reshape(transformed_shape)
    )
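# Worked example of the split/concat axes above (hypothetical shapes): with
# 2 ranks, a per-rank inp of shape (2, 3), split_axis=0 and concat_axis=1, each
# rank sends one (1, 3) row to every peer; the received (2, 3) result is then
# reshaped and transposed into a (1, 6) tensor, i.e. rank r ends up with row r
# of every rank laid side by side along axis 1.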
class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1
class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)[0]

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int):
    r"""Send a tensor to another process.

    Args:
        inp: Tensor to send.
        dest_rank: Rank of the destination process.

    Returns:
        None.

    Examples:
        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    out = _RemoteSend(op)(inp)
    _save_output_for_autodiff(inp, out)
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    r"""Receive a tensor from another process.

    Args:
        src_rank: Rank of the source process.
        device: The specific device to execute this operator.
            None (the default) means the default device will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        inp: Dummy input to determine the received tensor type.

    Returns:
        Received tensor.

    Examples:
        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = _RemoteRecv(op)(inp)
    return ret
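# Minimal point-to-point sketch mirroring the docstring examples above, assuming
# two processes started through MegEngine's distributed launcher (the decorator
# signature and n_gpus argument are assumptions of this sketch):
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import megengine.functional as F
#
#     @dist.launcher(n_gpus=2)
#     def worker():
#         if dist.get_rank() == 0:
#             data = mge.tensor(1)
#             F.distributed.remote_send(data, 1)   # returns None
#         else:
#             data = F.distributed.remote_recv(0)  # Tensor(1)
#             print(data.numpy())
#
#     worker()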