# -*- coding: utf-8 -*-
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend


def collective_comm(inp, mode, group, device):
    r"""Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
    if group is None:
        return inp
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    return result


def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (e.g. the output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp


class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensor data across the specified group by sum.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.

    Returns:
        Reduced tensor if in root process, None in other processes.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = reduce_sum(input)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = reduce_sum(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1])
    """
    out = _ReduceSum(group, device)(inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
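

# --- illustration only, not part of the library API -------------------------
# A minimal sketch of how the docstring examples above are meant to be run. It
# assumes two devices and that ``megengine.distributed.launcher`` spawns one
# worker process per device (adjust ``n_gpus`` to your machine); the helper name
# ``_example_reduce_sum_usage`` is purely illustrative.
def _example_reduce_sum_usage():  # illustrative sketch
    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        rank = get_rank()
        out = reduce_sum(Tensor([rank]))  # only the root (rank 0) receives a result
        if rank == 0:
            print(out)  # Tensor([1.]) == 0 + 1

    worker()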


class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Broadcast tensor data from root process to others.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = broadcast(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([0])

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = broadcast(input, group)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: Tensor([1])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    out = _Broadcast(group, device)(inp)
    return out


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group and concat them along the given axis (first dimension by default).

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        axis: The concat axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = all_gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: Tensor([0 1])

            input = Tensor([rank])
            group = Group([1, 0])
            output = all_gather(input, group)
            # Rank 0 # output: Tensor([1 0])
            # Rank 1 # output: Tensor([1 0])
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    out = collective_comm(inp, mode, group, device)
    if axis == 0:
        return out
    else:
        group_size = group.size if group is not None else 1
        transformed_shape = list(inp._tuple_shape)
        transformed_shape[axis] *= group_size
        n, *shp = out._tuple_shape
        index = (
            [_ for _ in range(1, axis)]
            + [axis, 0]
            + [_ for _ in range(axis + 1, out.ndim + 1)]
        )
        return (
            out.reshape(group_size, n // group_size, *shp)
            .transpose(index)
            .reshape(transformed_shape)
        )
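

# --- illustration only, not part of the library API -------------------------
# The reshape/transpose/reshape in ``all_gather`` rearranges the ALL_GATHER
# result (per-rank chunks stacked on axis 0) into a concatenation along the
# requested ``axis``.  This pure-NumPy sketch checks that equivalence for two
# simulated ranks; the helper name is purely illustrative.
def _example_all_gather_axis_trick():  # illustrative sketch
    group_size, axis = 2, 1
    chunks = [np.arange(6).reshape(2, 3) + 10 * r for r in range(group_size)]
    out = np.concatenate(chunks, axis=0)  # what ALL_GATHER produces (axis 0)
    n, *shp = out.shape
    index = list(range(1, axis)) + [axis, 0] + list(range(axis + 1, out.ndim + 1))
    rearranged = (
        out.reshape(group_size, n // group_size, *shp)
        .transpose(index)
        .reshape(2, 3 * group_size)
    )
    assert np.array_equal(rearranged, np.concatenate(chunks, axis=axis))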


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor:
    r"""Reduce tensors across the specified group by sum and split them along the given axis (first dimension by default).

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0; the data will be split along axis 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1])
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([0 1])
            output = reduce_scatter_sum(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([2])

            input = Tensor([0 1])
            group = Group([1, 0])
            output = reduce_scatter_sum(input, group)
            # Rank 0 # output: Tensor([2])
            # Rank 1 # output: Tensor([0])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
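

# --- illustration only, not part of the library API -------------------------
# For ``axis != 0``, ``reduce_scatter_sum`` pre-permutes the input so that the
# backend's "sum, then split on axis 0" matches "sum, then split on ``axis``".
# This pure-NumPy sketch checks that for two simulated ranks; the helper name is
# purely illustrative.
def _example_reduce_scatter_axis_trick():  # illustrative sketch
    group_size, axis = 2, 1
    inputs = [np.arange(8).reshape(2, 4) * (r + 1) for r in range(group_size)]

    def permute(x):
        new_shape = [2, 4 // group_size]
        new_shape.insert(axis, group_size)  # (2, group_size, 2)
        k_new_shape = [2 * group_size, 4 // group_size]
        index = [axis] + list(range(0, axis)) + list(range(axis + 1, x.ndim + 1))
        return x.reshape(new_shape).transpose(index).reshape(k_new_shape)

    summed = sum(permute(x) for x in inputs)  # what REDUCE_SCATTER_SUM sums
    for r in range(group_size):  # rank r keeps rows [r*2, (r+1)*2) of the sum
        expected = sum(inputs)[:, r * 2:(r + 1) * 2]
        assert np.array_equal(summed[r * 2:(r + 1) * 2], expected)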


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with sum operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            Passing a list of process ranks will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            In particular, a ``GPU`` device can run the operator on a different stream
            by appending the stream id after a colon following the device name;
            ``:0`` denotes the default stream of the GPU and is used when no stream is given.

    Returns:
        A tensor with sum operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_sum on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, group=[0, 1]) # doctest: +SKIP
        Tensor([4. 6.], device=xpux:0) # Rank 0
        Tensor([4. 6.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_sum on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, device="gpu0:1") # doctest: +SKIP
        Tensor([4. 6.], device=gpu0:0) # Rank 0
        Tensor([4. 6.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
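

# --- illustration only, not part of the library API -------------------------
# The common data-parallel use of ``all_reduce_sum`` is to average a value (for
# example a gradient or a loss) over all workers.  ``n_workers`` is assumed to
# be the size of the group the tensor is reduced over; the helper name is
# purely illustrative.
def _example_average_across_workers(x: Tensor, n_workers: int) -> Tensor:  # illustrative sketch
    # every process receives the same averaged tensor
    return all_reduce_sum(x) / n_workers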


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with max operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            Passing a list of process ranks will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            In particular, a ``GPU`` device can run the operator on a different stream
            by appending the stream id after a colon following the device name;
            ``:0`` denotes the default stream of the GPU and is used when no stream is given.

    Returns:
        A tensor with max operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_max on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, group=[0, 1]) # doctest: +SKIP
        Tensor([3. 4.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_max on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, device="gpu0:1") # doctest: +SKIP
        Tensor([3. 4.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with min operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            Passing a list of process ranks will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            In particular, a ``GPU`` device can run the operator on a different stream
            by appending the stream id after a colon following the device name;
            ``:0`` denotes the default stream of the GPU and is used when no stream is given.

    Returns:
        A tensor with min operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:

        >>> # We execute all_reduce_min on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, group=[0, 1]) # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([1. 2.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_min on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, device="gpu0:1") # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([1. 2.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
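

# --- illustration only, not part of the library API -------------------------
# A cheap consistency check built from the reductions above: if the element-wise
# max and min across the group agree, every worker holds an identical copy of
# ``x`` (useful for asserting replicated parameters).  The helper name is purely
# illustrative.
def _example_assert_same_across_workers(x: Tensor) -> None:  # illustrative sketch
    hi = all_reduce_max(x)
    lo = all_reduce_min(x)
    assert np.array_equal(hi.numpy(), lo.numpy()), "tensor differs across workers"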


class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        axis: The concat axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Result tensor if in root process, None in other processes.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0]) # first rank is root
            output = gather(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1 0])
    """
    assert (
        axis < inp.ndim
    ), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
        inp.shape
    )
    out = _Gather(group, device)(inp)
    if group.rank == 0:
        if axis == 0:
            return out
        else:
            group_size = group.size
            transformed_shape = list(inp._tuple_shape)
            transformed_shape[axis] *= group_size
            n, *shp = out._tuple_shape
            index = (
                [_ for _ in range(1, axis)]
                + [axis, 0]
                + [_ for _ in range(axis + 1, out.ndim + 1)]
            )
            return (
                out.reshape(group_size, n // group_size, *shp)
                .transpose(index)
                .reshape(transformed_shape)
            )
    else:
        _save_output_for_autodiff(inp, out)


class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Split the tensor in the root process along the given axis (first dimension by default) and scatter the parts to all processes.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = scatter(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([1])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0]) # first rank is root
            output = scatter(input, group)
            # Rank 0 # output: Tensor([3])
            # Rank 1 # output: Tensor([2])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    assert (
        list(inp._tuple_shape)[axis] % group.size == 0
    ), "current axis: {} can't be divided by group size".format(axis)
    if axis != 0:
        group_size = group.size
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    out = _Scatter(group, device)(inp)
    return out


def all_to_all(
    inp: Tensor,
    group: Optional[Group] = WORLD,
    device: Optional[str] = None,
    split_axis: int = 0,
    concat_axis: int = 0,
) -> Tensor:
    r"""Each process scatters its input tensor to all processes and returns the gathered tensor.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        split_axis: The axis along which collective_comm will split the data.
            The default axis is 0.
        concat_axis: The axis along which collective_comm will concat the result.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = all_to_all(input)
            # Rank 0 # output: Tensor([0 2])
            # Rank 1 # output: Tensor([1 3])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0])
            output = all_to_all(input, group)
            # Rank 0 # output: Tensor([0 3])
            # Rank 1 # output: Tensor([2 1])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[split_axis] % group_size == 0
    ), "current axis: {} can't be divided by group size".format(split_axis)
    origin_shape = inp._tuple_shape
    if split_axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[split_axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[split_axis] //= group_size
        new_shape.insert(split_axis, group_size)
        index = (
            [split_axis]
            + [_ for _ in range(0, split_axis)]
            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.ALL_TO_ALL
    out = collective_comm(inp, mode, group, device)
    if concat_axis == 0:
        return out
    transformed_shape = list(origin_shape)
    transformed_shape[concat_axis] *= group_size
    transformed_shape[split_axis] //= group_size
    n, *shp = out._tuple_shape
    index = (
        [_ for _ in range(1, concat_axis)]
        + [concat_axis, 0]
        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
    )
    return (
        out.reshape(group_size, n // group_size, *shp)
        .transpose(index)
        .reshape(transformed_shape)
    )
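

# --- illustration only, not part of the library API -------------------------
# Pure-NumPy simulation of ``all_to_all`` with the default split_axis=0 and
# concat_axis=0 for two simulated ranks: rank r receives the r-th chunk of every
# rank's input, concatenated in rank order (matching the docstring example).
# The helper name is purely illustrative.
def _example_all_to_all_semantics():  # illustrative sketch
    group_size = 2
    inputs = [np.array([0, 1]) + 2 * r for r in range(group_size)]
    chunks = [np.split(x, group_size) for x in inputs]  # split_axis=0
    outputs = [
        np.concatenate([chunks[src][dst] for src in range(group_size)])  # concat_axis=0
        for dst in range(group_size)
    ]
    assert np.array_equal(outputs[0], [0, 2]) and np.array_equal(outputs[1], [1, 3])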


class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)[0]

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int):
    r"""Send tensor to another process.

    Args:
        inp: Tensor to send.
        dest_rank: Rank of destination process.

    Returns:
        None.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    out = _RemoteSend(op)(inp)
    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    r"""Receive a tensor from another process.

    Args:
        src_rank: Rank of source process.
        device: The specific device to execute this operator.
            None (the default) means the current default device will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream,
            where 1 is the stream id and the default stream id is 0.
        inp: Dummy input to determine the received tensor type.

    Returns:
        Received tensor.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1) # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = _RemoteRecv(op)(inp)
    return ret
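

# --- illustration only, not part of the library API -------------------------
# A minimal end-to-end sketch of the point-to-point primitives above.  It assumes
# two devices and that ``megengine.distributed.launcher`` spawns one worker per
# device (as in the sketch after ``reduce_sum``); the tensor value and shape are
# arbitrary, and the helper name is purely illustrative.
def _example_send_recv_pipeline():  # illustrative sketch
    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        if get_rank() == 0:
            remote_send(Tensor([1.0, 2.0]), dest_rank=1)  # returns None
        else:
            data = remote_recv(src_rank=0)
            print(data)  # Tensor([1. 2.])

    worker()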