
functional.py 31 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend
def collective_comm(inp, mode, group, device):
    r"""Helper function for applying collective communication functions."""
    # check for None before the type assert, otherwise the None path is unreachable
    if group is None:
        return inp
    assert isinstance(group, Group)
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result
def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of condtake)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp
class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)
def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensor data across the specified group by sum.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.

    Returns:
        Reduced tensor in the root process, None in other processes.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = reduce_sum(input)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = reduce_sum(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1])
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)
def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Broadcast tensor data from the root process to the others.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = broadcast(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([0])

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = broadcast(input, group)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: Tensor([1])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out
def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group and concatenate them at the first dimension.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis of the collective_comm result.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = all_gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: Tensor([0 1])

            input = Tensor([rank])
            group = Group([1, 0])
            output = all_gather(input, group)
            # Rank 0 # output: Tensor([1 0])
            # Rank 1 # output: Tensor([1 0])
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    out = collective_comm(inp, mode, group, device)
    if axis == 0:
        return out
    else:
        group_size = group.size if group is not None else 1
        transformed_shape = list(inp._tuple_shape)
        transformed_shape[axis] *= group_size
        n, *shp = out._tuple_shape
        index = (
            [_ for _ in range(1, axis)]
            + [axis, 0]
            + [_ for _ in range(axis + 1, out.ndim + 1)]
        )
        return (
            out.reshape(group_size, n // group_size, *shp)
            .transpose(index)
            .reshape(transformed_shape)
        )
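# A minimal numpy sketch of the axis handling above (illustrative only; the helper
# name below is hypothetical and not part of MegEngine's API). With 2 ranks and
# axis=1, the reshape/transpose/reshape turns the axis-0 gather result into the
# concatenation of the per-rank chunks along axis 1.
def _all_gather_axis_sketch():
    chunks = [np.arange(6).reshape(2, 3), np.arange(6, 12).reshape(2, 3)]
    gathered = np.concatenate(chunks, axis=0)  # what ALL_GATHER returns (axis 0)
    n, *shp = gathered.shape
    reordered = (
        gathered.reshape(2, n // 2, *shp)  # split back into per-rank chunks
        .transpose(1, 0, 2)                # index computed above for axis=1: [1, 0, 2]
        .reshape(2, 6)                     # inp shape with axis 1 scaled by group size
    )
    assert (reordered == np.concatenate(chunks, axis=1)).all()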
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor:
    r"""Reduce tensors across the specified group by sum and split them at the first dimension.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        axis: The split axis of the collective_comm result.
            The default axis is 0; the data will be split along axis 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1])
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([0 1])
            output = reduce_scatter_sum(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([2])

            input = Tensor([0 1])
            group = Group([1, 0])
            output = reduce_scatter_sum(input, group)
            # Rank 0 # output: Tensor([2])
            # Rank 1 # output: Tensor([0])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
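# A matching numpy sketch for the pre-reshuffle above (illustrative only; the helper
# name is hypothetical). With 2 ranks and axis=1, the reshape/transpose/reshape lays
# the data out so that splitting the result along axis 0 (what REDUCE_SCATTER_SUM does)
# is equivalent to splitting the original tensor along axis 1.
def _reduce_scatter_axis_sketch():
    inp = np.arange(8).reshape(2, 4)
    rearranged = (
        inp.reshape(2, 2, 2)   # new_shape for axis=1, group_size=2
        .transpose(1, 0, 2)    # index computed above for axis=1: [1, 0, 2]
        .reshape(4, 2)         # k_new_shape
    )
    # rank g receives rows [g*2:(g+1)*2], i.e. columns [g*2:(g+1)*2] of the input
    assert (rearranged[0:2] == inp[:, 0:2]).all()
    assert (rearranged[2:4] == inp[:, 2:4]).all()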
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with sum operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name;
            ``:0`` denotes the default stream of the GPU, which is also used when no stream is given.

    Returns:
        A tensor with sum operation on each value across the group.
        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_sum on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, group=[0, 1])
        Tensor([4. 6.], device=xpux:0) # Rank 0
        Tensor([4. 6.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_sum on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0")
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, device="gpu0:1")
        Tensor([4. 6.], device=gpu0:0) # Rank 0
        Tensor([4. 6.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with max operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name;
            ``:0`` denotes the default stream of the GPU, which is also used when no stream is given.

    Returns:
        A tensor with max operation on each value across the group.
        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_max on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, group=[0, 1])
        Tensor([3. 4.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_max on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0")
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, device="gpu0:1")
        Tensor([3. 4.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with min operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as the parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``.
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name;
            ``:0`` denotes the default stream of the GPU, which is also used when no stream is given.

    Returns:
        A tensor with min operation on each value across the group.
        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_min on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, group=[0, 1])
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([1. 2.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_min on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0")
        >>> input = F.arange(2) + 1 + 2 * rank
        >>> input
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, device="gpu0:1")
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([1. 2.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)
def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis of the collective_comm result.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = gather(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1 0])
    """
    assert (
        axis < inp.ndim
    ), "the concat axis exceeds the number of dimensions of the tensor, the tensor shape is {}".format(
        inp.shape
    )
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        if axis == 0:
            return out
        else:
            group_size = group.size
            transformed_shape = list(inp._tuple_shape)
            transformed_shape[axis] *= group_size
            n, *shp = out._tuple_shape
            index = (
                [_ for _ in range(1, axis)]
                + [axis, 0]
                + [_ for _ in range(axis + 1, out.ndim + 1)]
            )
            return (
                out.reshape(group_size, n // group_size, *shp)
                .transpose(index)
                .reshape(transformed_shape)
            )
    else:
        _save_output_for_autodiff(inp, out)
class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)
def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Split the tensor in the root process along the first dimension and scatter the parts to all processes.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        axis: The split axis of the collective_comm result.
            The default axis is 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = scatter(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([1])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0]) # first rank is root
            output = scatter(input, group)
            # Rank 0 # output: Tensor([3])
            # Rank 1 # output: Tensor([2])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    assert (
        list(inp._tuple_shape)[axis] % group.size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        group_size = group.size
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
def all_to_all(
    inp: Tensor,
    group: Optional[Group] = WORLD,
    device: Optional[str] = None,
    split_axis: int = 0,
    concat_axis: int = 0,
) -> Tensor:
    r"""Each process scatters its input tensor to all processes and returns the gathered tensor.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        split_axis: The axis along which collective_comm will split the data.
            The default axis is 0.
        concat_axis: The axis along which the received parts are concatenated.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = all_to_all(input)
            # Rank 0 # output: Tensor([0 2])
            # Rank 1 # output: Tensor([1 3])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0])
            output = all_to_all(input, group)
            # Rank 0 # output: Tensor([0 3])
            # Rank 1 # output: Tensor([2 1])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[split_axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(split_axis)
    origin_shape = inp._tuple_shape
    if split_axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[split_axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[split_axis] //= group_size
        new_shape.insert(split_axis, group_size)
        index = (
            [split_axis]
            + [_ for _ in range(0, split_axis)]
            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.ALL_TO_ALL
    out = collective_comm(inp, mode, group, device)
    if concat_axis == 0:
        return out
    transformed_shape = list(origin_shape)
    transformed_shape[concat_axis] *= group_size
    transformed_shape[split_axis] //= group_size
    n, *shp = out._tuple_shape
    index = (
        [_ for _ in range(1, concat_axis)]
        + [concat_axis, 0]
        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
    )
    return (
        out.reshape(group_size, n // group_size, *shp)
        .transpose(index)
        .reshape(transformed_shape)
    )
class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int):
    r"""Send a tensor to another process.

    Args:
        inp: Tensor to send.
        dest_rank: Rank of the destination process.

    Returns:
        None.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    r"""Receive a tensor from another process.

    Args:
        src_rank: Rank of the source process.
        device: The specific device to execute this operator.
            None (the default) means the default device will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            1 is the stream id, and the default stream id is 0.
        inp: Dummy input to determine the received tensor type.

    Returns:
        Received tensor.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
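# A minimal end-to-end usage sketch (illustrative only, assuming a machine with 2 GPUs):
# megengine.distributed.launcher spawns one process per device, and each process then
# calls the collectives defined above through the megengine.distributed namespace.
if __name__ == "__main__":
    import megengine.distributed as dist
    from megengine import Tensor

    @dist.launcher(n_gpus=2)
    def _worker():
        data = Tensor([dist.get_rank()])            # Rank 0: Tensor([0]), Rank 1: Tensor([1])
        out = dist.functional.all_reduce_sum(data)  # both ranks receive Tensor([1])
        print(dist.get_rank(), out.numpy())

    _worker()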