
functional.py

# -*- coding: utf-8 -*-
from typing import Optional

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, RemoteRecv, RemoteSend
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend


def collective_comm(inp, mode, group, device):
    r"""Helper function for applying collective communication functions."""
    # a missing group means nothing to communicate with; check this before the
    # type assert, otherwise the shortcut below could never be reached
    if group is None:
        return inp
    assert isinstance(group, Group)
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    return result


def _save_output_for_autodiff(inp, out):
    # keep ``out`` referenced by every GradManager attached to ``inp`` so that
    # backward can still reach it even when this rank does not return it
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    # the root rank tells the rest of the group whether a gradient exists
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    # the root rank broadcasts the input's shape and dtype so that the other
    # ranks can build matching placeholder tensors
    if group.rank == 0:
        # FIXME in some cases, shape is not available (e.g. the output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    # the root rank broadcasts which GradManagers are attached to ``inp`` so
    # that the other ranks attach the same ones to their placeholder inputs
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    # zero tensor used as a placeholder input on non-root ranks
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp


class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensor data across the specified group by sum.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Reduced tensor if in root process, None in other processes.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = reduce_sum(input)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0])  # first rank is root
            output = reduce_sum(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1])
    """
    out = _ReduceSum(group, device)(inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
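

# The docstring examples above and below assume that the surrounding code is
# already running inside several worker processes. A minimal sketch of how such
# a context is usually created (assuming the ``megengine.distributed.launcher``
# decorator with ``n_gpus`` and two visible devices; adjust to your own setup):
#
#     import megengine.distributed as dist
#     import megengine.functional as F
#     from megengine import Tensor
#
#     @dist.launcher(n_gpus=2)
#     def worker():
#         rank = dist.get_rank()
#         out = F.distributed.reduce_sum(Tensor([rank]))
#         # rank 0 receives Tensor([1]); rank 1 receives None
#
#     worker()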


class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Broadcast tensor data from the root process to the others.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = broadcast(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([0])

            input = Tensor([rank])
            group = Group([1, 0])  # first rank is root
            output = broadcast(input, group)
            # Rank 0 # output: Tensor([1])
            # Rank 1 # output: Tensor([1])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    out = _Broadcast(group, device)(inp)
    return out


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group and concatenate them along ``axis``.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = all_gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: Tensor([0 1])

            input = Tensor([rank])
            group = Group([1, 0])
            output = all_gather(input, group)
            # Rank 0 # output: Tensor([1 0])
            # Rank 1 # output: Tensor([1 0])
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    out = collective_comm(inp, mode, group, device)
    if axis == 0:
        return out
    else:
        group_size = group.size if group is not None else 1
        transformed_shape = list(inp._tuple_shape)
        transformed_shape[axis] *= group_size
        n, *shp = out._tuple_shape
        index = (
            [_ for _ in range(1, axis)]
            + [axis, 0]
            + [_ for _ in range(axis + 1, out.ndim + 1)]
        )
        return (
            out.reshape(group_size, n // group_size, *shp)
            .transpose(index)
            .reshape(transformed_shape)
        )
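

# Illustrative sketch (not part of the module): for ``axis != 0`` the code above
# reshuffles the raw ALL_GATHER output (which is always concatenated along dim 0)
# into a concatenation along ``axis``. A numpy equivalent, assuming 2 ranks and
# axis=1:
#
#     import numpy as np
#     per_rank = [np.arange(6).reshape(2, 3) + 10 * r for r in range(2)]  # rank r's input
#     raw = np.concatenate(per_rank, axis=0)          # what ALL_GATHER itself returns
#     group_size, (n, *shp) = 2, raw.shape
#     reshuffled = (raw.reshape(group_size, n // group_size, *shp)
#                      .transpose(1, 0, 2)
#                      .reshape(2, 3 * group_size))
#     assert (reshuffled == np.concatenate(per_rank, axis=1)).all()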


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor:
    r"""Reduce tensors across the specified group by sum and split them along ``axis``.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0, i.e. the data will be split along axis 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1])
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([0 1])
            output = reduce_scatter_sum(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([2])

            input = Tensor([0 1])
            group = Group([1, 0])
            output = reduce_scatter_sum(input, group)
            # Rank 0 # output: Tensor([2])
            # Rank 1 # output: Tensor([0])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[axis] % group_size == 0
    ), "current axis: {} cannot be divided by group size".format(axis)
    if axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
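

# Illustrative sketch (not part of the module): for ``axis != 0`` the tensor is
# reshuffled *before* REDUCE_SCATTER_SUM so that the axis-0 chunk handed to rank r
# corresponds to the r-th chunk of the original tensor along ``axis``. A numpy
# equivalent of the reshuffle, assuming 2 ranks, axis=1 and an input of shape (2, 4):
#
#     import numpy as np
#     group_size, axis = 2, 1
#     x = np.arange(8).reshape(2, 4)
#     pre = x.reshape(2, group_size, 4 // group_size).transpose(1, 0, 2).reshape(4, 2)
#     # rows [r*2:(r+1)*2] of ``pre`` are exactly x[:, r*2:(r+1)*2], so the usual
#     # axis-0 reduce-scatter of ``pre`` hands rank r its chunk of sum(x) along axis 1
#     assert (pre[0:2] == x[:, 0:2]).all() and (pre[2:4] == x[:, 2:4]).all()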


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with sum operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name, while
            ``:0`` denotes the default stream of GPU; otherwise the default stream is used.

    Returns:
        A tensor with sum operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_sum on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, group=[0, 1]) # doctest: +SKIP
        Tensor([4. 6.], device=xpux:0) # Rank 0
        Tensor([4. 6.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_sum on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_sum(input, device="gpu0:1") # doctest: +SKIP
        Tensor([4. 6.], device=gpu0:0) # Rank 0
        Tensor([4. 6.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with max operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name, while
            ``:0`` denotes the default stream of GPU; otherwise the default stream is used.

    Returns:
        A tensor with max operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_max on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, group=[0, 1]) # doctest: +SKIP
        Tensor([3. 4.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_max on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_max(input, device="gpu0:1") # doctest: +SKIP
        Tensor([3. 4.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Reduce tensors with min operation on each value across the specified group.

    Note:
        ``inp`` tensor must have identical shape in all processes across the group.

    Args:
        inp (Tensor): tensor to be reduced.

    Keyword args:
        group (Group or sequence of ints): the process group to work on. Default: ``WORLD``.
            ``WORLD`` group selects all processes available.
            A list of process ranks as parameter will create a new group to work on.
        device (:attr:`.Tensor.device`): the specific device to execute this operator. Default: ``None``
            ``None`` will select the device of ``inp`` to execute.
            Specially, ``GPU`` device can assign a different stream to execute
            by adding a number right after a colon following the device name, while
            ``:0`` denotes the default stream of GPU; otherwise the default stream is used.

    Returns:
        A tensor with min operation on each value across the group.

        The shape of the output tensor must be the same as ``inp``, and the output
        tensor is going to be bitwise identical in all processes across the group.

    Examples:
        >>> # We execute all_reduce_min on rank 0 and rank 1
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([3. 4.], device=xpux:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, group=[0, 1]) # doctest: +SKIP
        Tensor([1. 2.], device=xpux:0) # Rank 0
        Tensor([1. 2.], device=xpux:0) # Rank 1

        >>> # We execute all_reduce_min on gpu0 with cuda stream 1
        >>> megengine.set_default_device("gpu0") # doctest: +SKIP
        >>> input = F.arange(2) + 1 + 2 * rank # doctest: +SKIP
        >>> input # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([3. 4.], device=gpu0:0) # Rank 1
        >>> F.distributed.all_reduce_min(input, device="gpu0:1") # doctest: +SKIP
        Tensor([1. 2.], device=gpu0:0) # Rank 0
        Tensor([1. 2.], device=gpu0:0) # Rank 1
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Gather tensors across the specified group.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        axis: The concat axis for the collective_comm result.

    Examples:

        .. code-block::

            input = Tensor([rank])
            # Rank 0 # input: Tensor([0])
            # Rank 1 # input: Tensor([1])
            output = gather(input)
            # Rank 0 # output: Tensor([0 1])
            # Rank 1 # output: None

            input = Tensor([rank])
            group = Group([1, 0])  # first rank is root
            output = gather(input, group)
            # Rank 0 # output: None
            # Rank 1 # output: Tensor([1 0])
    """
    assert (
        axis < inp.ndim
    ), "the concat axis exceeds the dim of the tensor, the tensor shape is {}".format(
        inp.shape
    )
    out = _Gather(group, device)(inp)
    if group.rank == 0:
        if axis == 0:
            return out
        else:
            group_size = group.size
            transformed_shape = list(inp._tuple_shape)
            transformed_shape[axis] *= group_size
            n, *shp = out._tuple_shape
            index = (
                [_ for _ in range(1, axis)]
                + [axis, 0]
                + [_ for _ in range(axis + 1, out.ndim + 1)]
            )
            return (
                out.reshape(group_size, n // group_size, *shp)
                .transpose(index)
                .reshape(transformed_shape)
            )
    else:
        _save_output_for_autodiff(inp, out)


class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""Split the tensor in the root process along ``axis`` and scatter the parts to all processes.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        axis: The split axis for the collective_comm result.
            The default axis is 0.

    Returns:
        Split tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = scatter(input)
            # Rank 0 # output: Tensor([0])
            # Rank 1 # output: Tensor([1])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0])  # first rank is root
            output = scatter(input, group)
            # Rank 0 # output: Tensor([3])
            # Rank 1 # output: Tensor([2])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    assert (
        list(inp._tuple_shape)[axis] % group.size == 0
    ), "current axis: {} cannot be divided by group size".format(axis)
    if axis != 0:
        group_size = group.size
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[axis] //= group_size
        new_shape.insert(axis, group_size)
        index = (
            [axis]
            + [_ for _ in range(0, axis)]
            + [_ for _ in range(axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    out = _Scatter(group, device)(inp)
    return out


def all_to_all(
    inp: Tensor,
    group: Optional[Group] = WORLD,
    device: Optional[str] = None,
    split_axis: int = 0,
    concat_axis: int = 0,
) -> Tensor:
    r"""Each process scatters its input tensor to all processes and returns the gathered tensor.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all processes available.
            You can use a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator on.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        split_axis: The axis along which collective_comm will split the data.
            The default axis is 0.
        concat_axis: The axis along which the received chunks are concatenated.
            The default axis is 0.

    Returns:
        Result tensor.

    Examples:

        .. code-block::

            input = Tensor([0 1]) + rank*2
            # Rank 0 # input: Tensor([0 1])
            # Rank 1 # input: Tensor([2 3])
            output = all_to_all(input)
            # Rank 0 # output: Tensor([0 2])
            # Rank 1 # output: Tensor([1 3])

            input = Tensor([0 1]) + rank*2
            group = Group([1, 0])
            output = all_to_all(input, group)
            # Rank 0 # output: Tensor([0 3])
            # Rank 1 # output: Tensor([2 1])
    """
    group_size = group.size if group is not None else 1
    assert (
        list(inp._tuple_shape)[split_axis] % group_size == 0
    ), "current axis: {} cannot be divided by group size".format(split_axis)
    origin_shape = inp._tuple_shape
    if split_axis != 0:
        k_new_shape = list(inp._tuple_shape)
        k_new_shape[split_axis] //= group_size
        k_new_shape[0] *= group_size
        new_shape = list(inp._tuple_shape)
        new_shape[split_axis] //= group_size
        new_shape.insert(split_axis, group_size)
        index = (
            [split_axis]
            + [_ for _ in range(0, split_axis)]
            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
        )
        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.ALL_TO_ALL
    out = collective_comm(inp, mode, group, device)
    if concat_axis == 0:
        return out
    transformed_shape = list(origin_shape)
    transformed_shape[concat_axis] *= group_size
    transformed_shape[split_axis] //= group_size
    n, *shp = out._tuple_shape
    index = (
        [_ for _ in range(1, concat_axis)]
        + [concat_axis, 0]
        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
    )
    return (
        out.reshape(group_size, n // group_size, *shp)
        .transpose(index)
        .reshape(transformed_shape)
    )
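

# Illustrative sketch (not part of the module): with the default axes, rank r ends
# up with the r-th split_axis chunk of every rank's input, concatenated in rank
# order along concat_axis. A numpy sketch of that semantics, assuming 2 ranks:
#
#     import numpy as np
#     inputs = [np.array([0, 1]), np.array([2, 3])]   # inputs[r] is rank r's tensor
#     outputs = [np.concatenate([np.split(x, 2)[r] for x in inputs]) for r in range(2)]
#     # outputs == [array([0, 2]), array([1, 3])], matching the docstring example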


class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        # the receiving side broadcasts whether it produced a gradient for us
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)[0]

    def backward(self, grad):
        # tell the sending side whether there is a gradient to send back
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int):
    r"""Send tensor to another process.

    Args:
        inp: Tensor to send.
        dest_rank: Rank of destination process.

    Returns:
        None.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1)  # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    out = _RemoteSend(op)(inp)
    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    r"""Receive a tensor from another process.

    Args:
        src_rank: Rank of source process.
        device: The specific device to execute this operator on.
            None (the default) means the default device will be used.
            Specify "gpu0:1" to execute this operator on a different cuda stream:
            1 is the stream id, and the default stream id is 0.
        inp: Dummy input to determine the received tensor type.

    Returns:
        Received tensor.

    Examples:

        .. code-block::

            if rank == 0:
                data = mge.tensor(1)
                # Tensor(1)
                F.distributed.remote_send(data, 1)  # return None
            else:
                data = F.distributed.remote_recv(0)
                # Tensor(1)
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = _RemoteRecv(op)(inp)
    return ret


def _remote_send_nobackward(inp: Tensor, dest_rank: int):
    # plain send without any autodiff bookkeeping
    op = RemoteSend()
    op.key = "b{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    apply(op, inp)


def _remote_recv_nobackward(
    src_rank: int, device: Optional[str] = None, inp=None, shape=None, dtype=None,
):
    # plain receive without any autodiff bookkeeping; shape and dtype must be
    # supplied by the caller
    op = RemoteRecv()
    op.key = "b{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    op.cn = device
    if inp is None:
        inp = Tensor(0, device=device)
    assert shape is not None and dtype is not None
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    ret = apply(op, inp)[0]
    return ret