You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional, Tuple
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core.autodiff.grad import Function, _grad_manager_dict
  13. from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
  14. from ..core.tensor.utils import isscalar, setscalar
  15. from ..device import get_default_device, what_is_xpu
  16. from ..tensor import Tensor
  17. from . import group
  18. from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
# Names exported by ``from <this module> import *``: the collective
# communication functions followed by the point-to-point send/receive pair.
__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]
  33. _device2backend = {
  34. "gpu": "nccl",
  35. "cuda": "nccl",
  36. "rocm": "rccl",
  37. }
  38. def _backend():
  39. if group._sd.backend == "auto":
  40. return _device2backend[what_is_xpu()]
  41. else:
  42. return group._sd.backend
  43. def collective_comm(inp, mode, group, device):
  44. r"""Helper function for applying collective communication functions."""
  45. assert isinstance(group, Group)
  46. if group is None:
  47. return inp
  48. if device is None:
  49. device = ""
  50. addr, port = get_mm_server_addr()
  51. op = CollectiveComm(
  52. key=group.key + _backend(),
  53. nr_devices=group.size,
  54. rank=group.rank,
  55. is_root=(group.rank == 0),
  56. local_grad=False,
  57. addr=addr,
  58. port=port,
  59. mode=mode,
  60. dtype=inp.dtype,
  61. backend=_backend(),
  62. comp_node=device,
  63. )
  64. (result,) = apply(op, inp)
  65. # assume all workers have homogeneous shape
  66. if mode in (
  67. CollectiveComm.Mode.REDUCE_SUM,
  68. CollectiveComm.Mode.BROADCAST,
  69. CollectiveComm.Mode.ALL_REDUCE_SUM,
  70. CollectiveComm.Mode.ALL_REDUCE_MAX,
  71. CollectiveComm.Mode.ALL_REDUCE_MIN,
  72. ):
  73. if isscalar(inp):
  74. setscalar(result)
  75. return result
  76. def _save_output_for_autodiff(inp, out):
  77. for g in _grad_manager_dict.values():
  78. if g._is_attached_to(inp):
  79. g._refkeeper.append(out)
  80. def _bcast_has_grad(group, grad):
  81. if group.rank == 0:
  82. has_grad = grad is not None
  83. get_client().bcast_val(has_grad, group.key, group.size)
  84. else:
  85. has_grad = get_client().bcast_val(None, group.key, group.size)
  86. return has_grad
  87. def _bcast_shape_dtype(group, inp):
  88. if group.rank == 0:
  89. # FIXME in some cases, shape is not available(output of condtake)
  90. shape = inp._tuple_shape
  91. dtype = np.dtype(inp.dtype).name
  92. get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
  93. else:
  94. val = get_client().bcast_val(None, group.key, group.size)
  95. shape = val["shape"]
  96. dtype = val["dtype"]
  97. return shape, dtype
  98. def _bcast_tracer_state(group, inp):
  99. if group.rank == 0:
  100. tracer_keys = []
  101. for n, g in _grad_manager_dict.items():
  102. if g._is_attached_to(inp):
  103. tracer_keys.append(n)
  104. get_client().bcast_val(tracer_keys, group.key, group.size)
  105. else:
  106. tracer_keys = get_client().bcast_val(None, group.key, group.size)
  107. for n in tracer_keys:
  108. g = _grad_manager_dict.get(n)
  109. if g is not None:
  110. g.wrt(inp)
  111. g._refkeeper.append(inp)
  112. def _dummy_input(shape, dtype, device=None):
  113. if device is None:
  114. device = get_default_device()
  115. inp = Tensor(0, dtype=dtype, device=device)
  116. if len(shape) > 0:
  117. inp = inp._broadcast(shape)
  118. return inp
  119. class _ReduceSum(Function):
  120. def __init__(self, group=WORLD, device=None):
  121. self.group = group
  122. self.out_device = device
  123. def forward(self, data):
  124. self.in_device = str(data.device)
  125. return collective_comm(
  126. data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
  127. )
  128. def backward(self, grad):
  129. has_grad = _bcast_has_grad(self.group, grad)
  130. if has_grad:
  131. return broadcast(grad, self.group, self.in_device)
  132. def reduce_sum(
  133. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  134. ) -> Tensor:
  135. r"""Reduce tensor data across the specified group by sum.
  136. Only root process will receive the final result.
  137. Args:
  138. inp: Input tensor.
  139. group: The process group to work on.
  140. The default group is WORLD which means all processes available.
  141. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  142. device: The specific device to execute this operator.
  143. None default device means the device of inp will be used.
  144. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  145. 1 is stream id, and default stream id is 0.
  146. Returns:
  147. Reduced tensor if in root process, None in other processes.
  148. Examples:
  149. .. code-block::
  150. input = Tensor([rank])
  151. # Rank 0 # input: Tensor([0])
  152. # Rank 1 # input: Tensor([1])
  153. output = reduce_sum(input)
  154. # Rank 0 # output: Tensor([1])
  155. # Rank 1 # output: None
  156. input = Tensor([rank])
  157. group = Group([1, 0]) # first rank is root
  158. output = reduce_sum(input, group)
  159. # Rank 0 # output: None
  160. # Rank 1 # output: Tensor([1])
  161. """
  162. op = _ReduceSum(group, device)
  163. (out,) = apply(op, inp)
  164. if group.rank == 0:
  165. return out
  166. else:
  167. _save_output_for_autodiff(inp, out)
  168. class _Broadcast(Function):
  169. def __init__(self, group=WORLD, device=None):
  170. self.group = group
  171. self.out_device = device
  172. def forward(self, data):
  173. self.in_device = str(data.device)
  174. return collective_comm(
  175. data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
  176. )
  177. def backward(self, grad):
  178. # TODO backward with a part of grad
  179. if grad is not None:
  180. return reduce_sum(grad, self.group, self.in_device)
  181. def broadcast(
  182. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  183. ) -> Tensor:
  184. r"""Broadcast tensor data from root process to others.
  185. Args:
  186. inp: Input tensor.
  187. group: The process group to work on.
  188. The default group is WORLD which means all processes available.
  189. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  190. device: The specific device to execute this operator.
  191. None default device means the device of inp will be used.
  192. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  193. 1 is stream id, and default stream id is 0.
  194. Returns:
  195. Result tensor.
  196. Examples:
  197. .. code-block::
  198. input = Tensor([rank])
  199. # Rank 0 # input: Tensor([0])
  200. # Rank 1 # input: Tensor([1])
  201. output = broadcast(input)
  202. # Rank 0 # output: Tensor([0])
  203. # Rank 1 # output: Tensor([0])
  204. input = Tensor([rank])
  205. group = Group([1, 0]) # first rank is root
  206. output = broadcast(input, group)
  207. # Rank 0 # output: Tensor([1])
  208. # Rank 1 # output: Tensor([1])
  209. """
  210. shape, dtype = _bcast_shape_dtype(group, inp)
  211. if group.rank != 0:
  212. # dummy input to infer shape
  213. inp = _dummy_input(shape, dtype, device)
  214. _bcast_tracer_state(group, inp)
  215. op = _Broadcast(group, device)
  216. (out,) = apply(op, inp)
  217. return out
  218. def _bcast_param(
  219. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
  220. ) -> Tensor:
  221. mode = CollectiveComm.Mode.BROADCAST
  222. return collective_comm(inp, mode, group, device)
  223. def all_gather(
  224. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  225. ) -> Tensor:
  226. r"""Gather tensors across the specified group and concat them at first dimension.
  227. Args:
  228. inp: Input tensor.
  229. group: The process group to work on.
  230. The default group is WORLD which means all processes available.
  231. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  232. device: The specific device to execute this operator.
  233. None default device means the device of inp will be used.
  234. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  235. 1 is stream id, and default stream id is 0.
  236. axis: The concat axis for collective_comm result
  237. The default axis is 0
  238. Returns:
  239. Result tensor.
  240. Examples:
  241. .. code-block::
  242. input = Tensor([rank])
  243. # Rank 0 # input: Tensor([0])
  244. # Rank 1 # input: Tensor([1])
  245. output = all_gather(input)
  246. # Rank 0 # output: Tensor([0 1])
  247. # Rank 1 # output: Tensor([0 1])
  248. input = Tensor([rank])
  249. group = Group([1, 0])
  250. output = all_gather(input, group)
  251. # Rank 0 # output: Tensor([1 0])
  252. # Rank 1 # output: Tensor([1 0])
  253. """
  254. mode = CollectiveComm.Mode.ALL_GATHER
  255. out = collective_comm(inp, mode, group, device)
  256. if axis == 0:
  257. return out
  258. else:
  259. group_size = group.size if group is not None else 1
  260. transformed_shape = list(inp._tuple_shape)
  261. transformed_shape[axis] *= group_size
  262. n, *shp = out._tuple_shape
  263. index = (
  264. [_ for _ in range(1, axis)]
  265. + [axis, 0]
  266. + [_ for _ in range(axis + 1, out.ndim + 1)]
  267. )
  268. return (
  269. out.reshape(group_size, n // group_size, *shp)
  270. .transpose(index)
  271. .reshape(transformed_shape)
  272. )
  273. def reduce_scatter_sum(
  274. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
  275. ) -> Tensor:
  276. r"""Reduce tensors across the specified group by sum and split them at first dimension.
  277. Args:
  278. inp: Input tensor.
  279. group: The process group to work on.
  280. The default group is WORLD which means all processes available.
  281. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  282. device: The specific device to execute this operator.
  283. None default device means the device of inp will be used.
  284. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  285. 1 is stream id, and default stream id is 0.
  286. axis: The split axis for collective_comm result
  287. The default axis is 0, the data will split in the 0 axis
  288. Returns:
  289. Split tensor.
  290. Examples:
  291. .. code-block::
  292. input = Tensor([0 1])
  293. # Rank 0 # input: Tensor([0 1])
  294. # Rank 1 # input: Tensor([0 1])
  295. output = reduce_scatter_sum(input)
  296. # Rank 0 # output: Tensor([0])
  297. # Rank 1 # output: Tensor([2])
  298. input = Tensor([0 1])
  299. group = Group([1, 0])
  300. output = reduce_scatter_sum(input, group)
  301. # Rank 0 # output: Tensor([2])
  302. # Rank 1 # output: Tensor([0])
  303. """
  304. group_size = group.size if group is not None else 1
  305. assert (
  306. list(inp._tuple_shape)[axis] % group_size == 0
  307. ), "current axis: {} can't devided by group size".format(axis)
  308. if axis != 0:
  309. k_new_shape = list(inp._tuple_shape)
  310. k_new_shape[axis] //= group_size
  311. k_new_shape[0] *= group_size
  312. new_shape = list(inp._tuple_shape)
  313. new_shape[axis] //= group_size
  314. new_shape.insert(axis, group_size)
  315. index = (
  316. [axis]
  317. + [_ for _ in range(0, axis)]
  318. + [_ for _ in range(axis + 1, inp.ndim + 1)]
  319. )
  320. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  321. mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
  322. return collective_comm(inp, mode, group, device)
  323. def all_reduce_sum(
  324. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  325. ) -> Tensor:
  326. r"""Reduce tensors across the specified group by sum.
  327. Args:
  328. inp: Input tensor.
  329. group: The process group to work on.
  330. The default group is WORLD which means all processes available.
  331. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  332. device: The specific device to execute this operator.
  333. None default device means the device of inp will be used.
  334. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  335. 1 is stream id, and default stream id is 0.
  336. Returns:
  337. Result tensor.
  338. Examples:
  339. .. code-block::
  340. input = Tensor(rank)
  341. # Rank 0 # input: Tensor(0)
  342. # Rank 1 # input: Tensor(1)
  343. output = all_reduce_sum(input)
  344. # Rank 0 # output: Tensor(1)
  345. # Rank 1 # output: Tensor(1)
  346. """
  347. mode = CollectiveComm.Mode.ALL_REDUCE_SUM
  348. return collective_comm(inp, mode, group, device)
  349. def all_reduce_max(
  350. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  351. ) -> Tensor:
  352. r"""Reduce tensors across the specified group by max.
  353. Args:
  354. inp: Input tensor.
  355. group: The process group to work on.
  356. The default group is WORLD which means all processes available.
  357. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  358. device: The specific device to execute this operator.
  359. None default device means the device of inp will be used.
  360. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  361. 1 is stream id, and default stream id is 0.
  362. Returns:
  363. Result tensor.
  364. Examples:
  365. .. code-block::
  366. input = Tensor(rank)
  367. # Rank 0 # input: Tensor(0)
  368. # Rank 1 # input: Tensor(1)
  369. output = all_reduce_max(input)
  370. # Rank 0 # output: Tensor(1)
  371. # Rank 1 # output: Tensor(1)
  372. """
  373. mode = CollectiveComm.Mode.ALL_REDUCE_MAX
  374. return collective_comm(inp, mode, group, device)
  375. def all_reduce_min(
  376. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  377. ) -> Tensor:
  378. r"""Reduce tensors across the specified group by min.
  379. Args:
  380. inp: Input tensor.
  381. group: The process group to work on.
  382. The default group is WORLD which means all processes available.
  383. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  384. device: The specific device to execute this operator.
  385. None default device means the device of inp will be used.
  386. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  387. 1 is stream id, and default stream id is 0.
  388. Returns:
  389. Result tensor.
  390. Examples:
  391. .. code-block::
  392. input = Tensor(rank)
  393. # Rank 0 # input: Tensor(0)
  394. # Rank 1 # input: Tensor(1)
  395. output = all_reduce_min(input)
  396. # Rank 0 # output: Tensor(0)
  397. # Rank 1 # output: Tensor(0)
  398. """
  399. mode = CollectiveComm.Mode.ALL_REDUCE_MIN
  400. return collective_comm(inp, mode, group, device)
  401. class _Gather(Function):
  402. def __init__(self, group=WORLD, device=None):
  403. self.group = group
  404. self.out_device = device
  405. def forward(self, data):
  406. self.in_device = str(data.device)
  407. return collective_comm(
  408. data, CollectiveComm.Mode.GATHER, self.group, self.out_device
  409. )
  410. def backward(self, grad):
  411. has_grad = _bcast_has_grad(self.group, grad)
  412. if has_grad:
  413. return scatter(grad, self.group, self.in_device)
  414. def gather(
  415. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  416. ) -> Tensor:
  417. r"""Gather tensors across the specified group.
  418. Only root process will receive the final result.
  419. Args:
  420. inp: Input tensor.
  421. group: The process group to work on.
  422. The default group is WORLD which means all processes available.
  423. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  424. device: The specific device to execute this operator.
  425. None default device means the device of inp will be used.
  426. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  427. 1 is stream id, and default stream id is 0.
  428. axis: The concat axis for collective_comm result
  429. Examples:
  430. .. code-block::
  431. input = Tensor([rank])
  432. # Rank 0 # input: Tensor([0])
  433. # Rank 1 # input: Tensor([1])
  434. output = gather(input)
  435. # Rank 0 # output: Tensor([0 1])
  436. # Rank 1 # output: None
  437. input = Tensor([rank])
  438. group = Group([1, 0]) # first rank is root
  439. output = gather(input, group)
  440. # Rank 0 # output: None
  441. # Rank 1 # output: Tensor([1 0])
  442. """
  443. assert (
  444. axis < inp.ndim
  445. ), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
  446. inp.shape
  447. )
  448. op = _Gather(group, device)
  449. (out,) = apply(op, inp)
  450. if group.rank == 0:
  451. if axis == 0:
  452. return out
  453. else:
  454. group_size = group.size
  455. transformed_shape = list(inp._tuple_shape)
  456. transformed_shape[axis] *= group_size
  457. n, *shp = out._tuple_shape
  458. index = (
  459. [_ for _ in range(1, axis)]
  460. + [axis, 0]
  461. + [_ for _ in range(axis + 1, out.ndim + 1)]
  462. )
  463. return (
  464. out.reshape(group_size, n // group_size, *shp)
  465. .transpose(index)
  466. .reshape(transformed_shape)
  467. )
  468. else:
  469. _save_output_for_autodiff(inp, out)
  470. class _Scatter(Function):
  471. def __init__(self, group=WORLD, device=None):
  472. self.group = group
  473. self.out_device = device
  474. def forward(self, data):
  475. self.in_device = str(data.device)
  476. return collective_comm(
  477. data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
  478. )
  479. def backward(self, grad):
  480. # TODO backward with a part of grad
  481. if grad is not None:
  482. return gather(grad, self.group, self.in_device)
  483. def scatter(
  484. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  485. ) -> Tensor:
  486. r"""Split tensor in root process at first dimension.
  487. Args:
  488. inp: Input tensor.
  489. group: The process group to work on.
  490. The default group is WORLD which means all processes available.
  491. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  492. device: The specific device to execute this operator.
  493. None default device means the device of inp will be used.
  494. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  495. 1 is stream id, and default stream id is 0.
  496. axis: The concat axis for collective_comm result
  497. The default axis is 0
  498. Returns:
  499. Split tensor.
  500. Examples:
  501. .. code-block::
  502. input = Tensor([0 1]) + rank*2
  503. # Rank 0 # input: Tensor([0 1])
  504. # Rank 1 # input: Tensor([2 3])
  505. output = scatter(input)
  506. # Rank 0 # output: Tensor([0])
  507. # Rank 1 # output: Tensor([1])
  508. input = Tensor([0 1]) + rank*2
  509. group = Group([1, 0]) # first rank is root
  510. output = scatter(input, group)
  511. # Rank 0 # output: Tensor([3])
  512. # Rank 1 # output: Tensor([2])
  513. """
  514. shape, dtype = _bcast_shape_dtype(group, inp)
  515. if group.rank != 0:
  516. # dummy input to infer shape
  517. inp = _dummy_input(shape, dtype, device)
  518. _bcast_tracer_state(group, inp)
  519. assert (
  520. list(inp._tuple_shape)[axis] % group.size == 0
  521. ), "current axis: {} can't devided by group size".format(axis)
  522. if axis != 0:
  523. group_size = group.size
  524. k_new_shape = list(inp._tuple_shape)
  525. k_new_shape[axis] //= group_size
  526. k_new_shape[0] *= group_size
  527. new_shape = list(inp._tuple_shape)
  528. new_shape[axis] //= group_size
  529. new_shape.insert(axis, group_size)
  530. index = (
  531. [axis]
  532. + [_ for _ in range(0, axis)]
  533. + [_ for _ in range(axis + 1, inp.ndim + 1)]
  534. )
  535. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  536. op = _Scatter(group, device)
  537. (out,) = apply(op, inp)
  538. return out
  539. def all_to_all(
  540. inp: Tensor,
  541. group: Optional[Group] = WORLD,
  542. device: Optional[str] = None,
  543. split_axis: int = 0,
  544. concat_axis: int = 0,
  545. ) -> Tensor:
  546. r"""Each process scatter input tensor to all processes and return gathered tensor.
  547. Args:
  548. inp: Input tensor.
  549. group: The process group to work on.
  550. The default group is WORLD which means all processes available.
  551. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  552. device: The specific device to execute this operator.
  553. None default device means the device of inp will be used.
  554. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  555. 1 is stream id, and default stream id is 0.
  556. split_axis: The axis that collectivecomm will split data
  557. the default axis is 0
  558. Returns:
  559. Result tensor.
  560. Examples:
  561. .. code-block::
  562. input = Tensor([0 1]) + rank*2
  563. # Rank 0 # input: Tensor([0 1])
  564. # Rank 1 # input: Tensor([2 3])
  565. output = all_to_all(input)
  566. # Rank 0 # output: Tensor([0 2])
  567. # Rank 1 # output: Tensor([1 3])
  568. input = Tensor([0 1]) + rank*2
  569. group = Group([1, 0])
  570. output = all_to_all(input, group)
  571. # Rank 0 # output: Tensor([0 3])
  572. # Rank 1 # output: Tensor([2 1])
  573. """
  574. group_size = group.size if group is not None else 1
  575. assert (
  576. list(inp._tuple_shape)[split_axis] % group_size == 0
  577. ), "current axis: {} can't devided by group size".format(split_axis)
  578. origin_shape = inp._tuple_shape
  579. if split_axis != 0:
  580. k_new_shape = list(inp._tuple_shape)
  581. k_new_shape[split_axis] //= group_size
  582. k_new_shape[0] *= group_size
  583. new_shape = list(inp._tuple_shape)
  584. new_shape[split_axis] //= group_size
  585. new_shape.insert(split_axis, group_size)
  586. index = (
  587. [split_axis]
  588. + [_ for _ in range(0, split_axis)]
  589. + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
  590. )
  591. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  592. mode = CollectiveComm.Mode.ALL_TO_ALL
  593. out = collective_comm(inp, mode, group, device)
  594. if concat_axis == 0:
  595. return out
  596. transformed_shape = list(origin_shape)
  597. transformed_shape[concat_axis] *= group_size
  598. transformed_shape[split_axis] //= group_size
  599. n, *shp = out._tuple_shape
  600. index = (
  601. [_ for _ in range(1, concat_axis)]
  602. + [concat_axis, 0]
  603. + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
  604. )
  605. return (
  606. out.reshape(group_size, n // group_size, *shp)
  607. .transpose(index)
  608. .reshape(transformed_shape)
  609. )
  610. class _SendRecvGroup:
  611. def __init__(self, rank_from, rank_to):
  612. self.key = "{}->{}".format(rank_from, rank_to)
  613. self.rank_from = rank_from
  614. self.rank_to = rank_to
  615. self.size = 2
  616. @property
  617. def rank(self):
  618. if get_rank() == self.rank_from:
  619. return 0
  620. else:
  621. return 1
  622. class _RemoteSend(Function):
  623. def __init__(self, op: RemoteSend):
  624. self.op = op
  625. def forward(self, data):
  626. self.device = str(data.device)
  627. (self.dummy,) = apply(self.op, data)
  628. return self.dummy
  629. def backward(self, grad):
  630. assert grad is None
  631. has_grad = get_client().bcast_val(None, self.op.key, 2)
  632. if has_grad:
  633. return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)
  634. class _RemoteRecv(Function):
  635. def __init__(self, op: RemoteRecv):
  636. self.op = op
  637. def forward(self, dummy):
  638. return apply(self.op, dummy)
  639. def backward(self, grad):
  640. get_client().bcast_val(grad is not None, self.op.key, 2)
  641. if grad is not None:
  642. remote_send(grad, self.op.rank_from)
  643. def remote_send(inp: Tensor, dest_rank: int):
  644. r"""Send tensor to another process.
  645. Args:
  646. inp: Tensor to send.
  647. dest_rank: Rank of destination process.
  648. Returns:
  649. None.
  650. Examples:
  651. .. code-block::
  652. if rank == 0:
  653. data = mge.tensor(1)
  654. # Tensor(1)
  655. F.distributed.remote_send(data, 1) # return None
  656. else:
  657. data = F.distributed.remote_recv(0)
  658. # Tensor(1)
  659. """
  660. group = _SendRecvGroup(get_rank(), dest_rank)
  661. _bcast_shape_dtype(group, inp)
  662. _bcast_tracer_state(group, inp)
  663. op = RemoteSend()
  664. op.key = group.key
  665. op.addr, op.port = get_mm_server_addr()
  666. op.rank_to = dest_rank
  667. op.backend = _backend()
  668. (out,) = apply(_RemoteSend(op), inp)
  669. _save_output_for_autodiff(inp, out)
  670. def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
  671. r"""Receive a tensor from another process.
  672. Args:
  673. src_rank: Rank of source process.
  674. device: The specific device to execute this operator.
  675. None default device means the device of inp will be used.
  676. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  677. 1 is stream id, and default stream id is 0.
  678. inp: Dummy input to determine received tensor type.
  679. Returns:
  680. Received tensor.
  681. Examples:
  682. .. code-block::
  683. if rank == 0:
  684. data = mge.tensor(1)
  685. # Tensor(1)
  686. F.distributed.remote_send(data, 1) # return None
  687. else:
  688. data = F.distributed.remote_recv(0)
  689. # Tensor(1)
  690. """
  691. group = _SendRecvGroup(src_rank, get_rank())
  692. shape, dtype = _bcast_shape_dtype(group, None)
  693. if device is None:
  694. device = get_default_device()
  695. # dummy input
  696. if inp is None:
  697. inp = Tensor(0, device=device)
  698. _bcast_tracer_state(group, inp)
  699. _isscalar = False
  700. if len(shape) == 0:
  701. shape = (1,)
  702. _isscalar = True
  703. op = RemoteRecv()
  704. op.key = group.key
  705. op.cn = device
  706. op.shape = shape
  707. op.dtype = dtype
  708. op.addr, op.port = get_mm_server_addr()
  709. op.rank_from = src_rank
  710. op.backend = _backend()
  711. (ret,) = apply(_RemoteRecv(op), inp)
  712. if _isscalar:
  713. setscalar(ret)
  714. return ret

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台