You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional.py 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import Optional, Tuple
  10. import numpy as np
  11. from ..core._imperative_rt.core2 import apply
  12. from ..core.autodiff.grad import Function, _grad_manager_dict
  13. from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
  14. from ..core.tensor.utils import isscalar, setscalar
  15. from ..device import get_default_device, what_is_xpu
  16. from ..tensor import Tensor
  17. from . import group
  18. from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
  19. __all__ = [
  20. "reduce_sum",
  21. "broadcast",
  22. "all_gather",
  23. "reduce_scatter_sum",
  24. "all_reduce_sum",
  25. "all_reduce_max",
  26. "all_reduce_min",
  27. "gather",
  28. "scatter",
  29. "all_to_all",
  30. "remote_send",
  31. "remote_recv",
  32. ]
  33. _device2backend = {
  34. "gpu": "nccl",
  35. "cuda": "nccl",
  36. "rocm": "rccl",
  37. }
  38. def _backend():
  39. if group._sd.backend == "auto":
  40. return _device2backend[what_is_xpu()]
  41. else:
  42. return group._sd.backend
  43. def collective_comm(inp, mode, group, device):
  44. """Helper function for applying collective communication functions."""
  45. assert isinstance(group, Group)
  46. if group is None:
  47. return inp
  48. if device is None:
  49. device = ""
  50. addr, port = get_mm_server_addr()
  51. op = CollectiveComm(
  52. key=group.key + _backend(),
  53. nr_devices=group.size,
  54. rank=group.rank,
  55. is_root=(group.rank == 0),
  56. local_grad=False,
  57. addr=addr,
  58. port=port,
  59. mode=mode,
  60. dtype=inp.dtype,
  61. backend=_backend(),
  62. comp_node=device,
  63. )
  64. (result,) = apply(op, inp)
  65. # assume all workers have homogeneous shape
  66. if mode in (
  67. CollectiveComm.Mode.REDUCE_SUM,
  68. CollectiveComm.Mode.BROADCAST,
  69. CollectiveComm.Mode.ALL_REDUCE_SUM,
  70. CollectiveComm.Mode.ALL_REDUCE_MAX,
  71. CollectiveComm.Mode.ALL_REDUCE_MIN,
  72. ):
  73. if isscalar(inp):
  74. setscalar(result)
  75. return result
  76. def _save_output_for_autodiff(inp, out):
  77. for g in _grad_manager_dict.values():
  78. if g._is_attached_to(inp):
  79. g._refkeeper.append(out)
  80. def _bcast_has_grad(group, grad):
  81. if group.rank == 0:
  82. has_grad = grad is not None
  83. get_client().bcast_val(has_grad, group.key, group.size)
  84. else:
  85. has_grad = get_client().bcast_val(None, group.key, group.size)
  86. return has_grad
  87. def _bcast_shape_dtype(group, inp):
  88. if group.rank == 0:
  89. # FIXME in some cases, shape is not available(output of condtake)
  90. shape = inp._tuple_shape
  91. dtype = np.dtype(inp.dtype).name
  92. get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
  93. else:
  94. val = get_client().bcast_val(None, group.key, group.size)
  95. shape = val["shape"]
  96. dtype = val["dtype"]
  97. return shape, dtype
  98. def _bcast_tracer_state(group, inp):
  99. if group.rank == 0:
  100. tracer_keys = []
  101. for n, g in _grad_manager_dict.items():
  102. if g._is_attached_to(inp):
  103. tracer_keys.append(n)
  104. get_client().bcast_val(tracer_keys, group.key, group.size)
  105. else:
  106. tracer_keys = get_client().bcast_val(None, group.key, group.size)
  107. for n in tracer_keys:
  108. g = _grad_manager_dict.get(n)
  109. if g is not None:
  110. g.wrt(inp)
  111. g._refkeeper.append(inp)
  112. def _dummy_input(shape, dtype, device=None):
  113. if device is None:
  114. device = get_default_device()
  115. inp = Tensor(0, dtype=dtype, device=device)
  116. if len(shape) > 0:
  117. inp = inp._broadcast(shape)
  118. return inp
  119. class _ReduceSum(Function):
  120. def __init__(self, group=WORLD, device=None):
  121. self.group = group
  122. self.out_device = device
  123. def forward(self, data):
  124. self.in_device = str(data.device)
  125. return collective_comm(
  126. data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
  127. )
  128. def backward(self, grad):
  129. has_grad = _bcast_has_grad(self.group, grad)
  130. if has_grad:
  131. return broadcast(grad, self.group, self.in_device)
  132. def reduce_sum(
  133. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  134. ) -> Tensor:
  135. r"""
  136. Reduce tensor data across the specified group by sum.
  137. Only root process will receive the final result.
  138. Args:
  139. inp: Input tensor.
  140. group: The process group to work on.
  141. The default group is WORLD which means all processes available.
  142. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  143. device: The specific device to execute this operator.
  144. None default device means the device of inp will be used.
  145. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  146. 1 is stream id, and default stream id is 0.
  147. Returns:
  148. Reduced tensor if in root process, None in other processes.
  149. Examples:
  150. .. code-block::
  151. input = Tensor([rank])
  152. # Rank 0 # input: Tensor([0])
  153. # Rank 1 # input: Tensor([1])
  154. output = reduce_sum(input)
  155. # Rank 0 # output: Tensor([1])
  156. # Rank 1 # output: None
  157. input = Tensor([rank])
  158. group = Group([1, 0]) # first rank is root
  159. output = reduce_sum(input, group)
  160. # Rank 0 # output: None
  161. # Rank 1 # output: Tensor([1])
  162. """
  163. op = _ReduceSum(group, device)
  164. (out,) = apply(op, inp)
  165. if group.rank == 0:
  166. return out
  167. else:
  168. _save_output_for_autodiff(inp, out)
  169. class _Broadcast(Function):
  170. def __init__(self, group=WORLD, device=None):
  171. self.group = group
  172. self.out_device = device
  173. def forward(self, data):
  174. self.in_device = str(data.device)
  175. return collective_comm(
  176. data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
  177. )
  178. def backward(self, grad):
  179. # TODO backward with a part of grad
  180. if grad is not None:
  181. return reduce_sum(grad, self.group, self.in_device)
  182. def broadcast(
  183. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  184. ) -> Tensor:
  185. r"""
  186. Broadcast tensor data from root process to others.
  187. Args:
  188. inp: Input tensor.
  189. group: The process group to work on.
  190. The default group is WORLD which means all processes available.
  191. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  192. device: The specific device to execute this operator.
  193. None default device means the device of inp will be used.
  194. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  195. 1 is stream id, and default stream id is 0.
  196. Returns:
  197. Result tensor.
  198. Examples:
  199. .. code-block::
  200. input = Tensor([rank])
  201. # Rank 0 # input: Tensor([0])
  202. # Rank 1 # input: Tensor([1])
  203. output = broadcast(input)
  204. # Rank 0 # output: Tensor([0])
  205. # Rank 1 # output: Tensor([0])
  206. input = Tensor([rank])
  207. group = Group([1, 0]) # first rank is root
  208. output = broadcast(input, group)
  209. # Rank 0 # output: Tensor([1])
  210. # Rank 1 # output: Tensor([1])
  211. """
  212. shape, dtype = _bcast_shape_dtype(group, inp)
  213. if group.rank != 0:
  214. # dummy input to infer shape
  215. inp = _dummy_input(shape, dtype, device)
  216. _bcast_tracer_state(group, inp)
  217. op = _Broadcast(group, device)
  218. (out,) = apply(op, inp)
  219. return out
  220. def _bcast_param(
  221. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
  222. ) -> Tensor:
  223. mode = CollectiveComm.Mode.BROADCAST
  224. return collective_comm(inp, mode, group, device)
  225. def all_gather(
  226. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  227. ) -> Tensor:
  228. r"""
  229. Gather tensors across the specified group and concat them at first dimension.
  230. Args:
  231. inp: Input tensor.
  232. group: The process group to work on.
  233. The default group is WORLD which means all processes available.
  234. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  235. device: The specific device to execute this operator.
  236. None default device means the device of inp will be used.
  237. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  238. 1 is stream id, and default stream id is 0.
  239. axis: The concat axis for collective_comm result
  240. The default axis is 0
  241. Returns:
  242. Result tensor.
  243. Examples:
  244. .. code-block::
  245. input = Tensor([rank])
  246. # Rank 0 # input: Tensor([0])
  247. # Rank 1 # input: Tensor([1])
  248. output = all_gather(input)
  249. # Rank 0 # output: Tensor([0 1])
  250. # Rank 1 # output: Tensor([0 1])
  251. input = Tensor([rank])
  252. group = Group([1, 0])
  253. output = all_gather(input, group)
  254. # Rank 0 # output: Tensor([1 0])
  255. # Rank 1 # output: Tensor([1 0])
  256. """
  257. mode = CollectiveComm.Mode.ALL_GATHER
  258. out = collective_comm(inp, mode, group, device)
  259. if axis == 0:
  260. return out
  261. else:
  262. group_size = group.size if group is not None else 1
  263. transformed_shape = list(inp._tuple_shape)
  264. transformed_shape[axis] *= group_size
  265. n, *shp = out._tuple_shape
  266. index = (
  267. [_ for _ in range(1, axis)]
  268. + [axis, 0]
  269. + [_ for _ in range(axis + 1, out.ndim + 1)]
  270. )
  271. return (
  272. out.reshape(group_size, n // group_size, *shp)
  273. .transpose(index)
  274. .reshape(transformed_shape)
  275. )
  276. def reduce_scatter_sum(
  277. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
  278. ) -> Tensor:
  279. r"""
  280. Reduce tensors across the specified group by sum and split them at first dimension.
  281. Args:
  282. inp: Input tensor.
  283. group: The process group to work on.
  284. The default group is WORLD which means all processes available.
  285. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  286. device: The specific device to execute this operator.
  287. None default device means the device of inp will be used.
  288. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  289. 1 is stream id, and default stream id is 0.
  290. axis: The split axis for collective_comm result
  291. The default axis is 0, the data will split in the 0 axis
  292. Returns:
  293. Split tensor.
  294. Examples:
  295. .. code-block::
  296. input = Tensor([0 1])
  297. # Rank 0 # input: Tensor([0 1])
  298. # Rank 1 # input: Tensor([0 1])
  299. output = reduce_scatter_sum(input)
  300. # Rank 0 # output: Tensor([0])
  301. # Rank 1 # output: Tensor([2])
  302. input = Tensor([0 1])
  303. group = Group([1, 0])
  304. output = reduce_scatter_sum(input, group)
  305. # Rank 0 # output: Tensor([2])
  306. # Rank 1 # output: Tensor([0])
  307. """
  308. group_size = group.size if group is not None else 1
  309. assert (
  310. list(inp._tuple_shape)[axis] % group_size == 0
  311. ), "current axis: {} can't devided by group size".format(axis)
  312. if axis != 0:
  313. k_new_shape = list(inp._tuple_shape)
  314. k_new_shape[axis] //= group_size
  315. k_new_shape[0] *= group_size
  316. new_shape = list(inp._tuple_shape)
  317. new_shape[axis] //= group_size
  318. new_shape.insert(axis, group_size)
  319. index = (
  320. [axis]
  321. + [_ for _ in range(0, axis)]
  322. + [_ for _ in range(axis + 1, inp.ndim + 1)]
  323. )
  324. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  325. mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
  326. return collective_comm(inp, mode, group, device)
  327. def all_reduce_sum(
  328. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  329. ) -> Tensor:
  330. r"""
  331. Reduce tensors across the specified group by sum.
  332. Args:
  333. inp: Input tensor.
  334. group: The process group to work on.
  335. The default group is WORLD which means all processes available.
  336. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  337. device: The specific device to execute this operator.
  338. None default device means the device of inp will be used.
  339. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  340. 1 is stream id, and default stream id is 0.
  341. Returns:
  342. Result tensor.
  343. Examples:
  344. .. code-block::
  345. input = Tensor(rank)
  346. # Rank 0 # input: Tensor(0)
  347. # Rank 1 # input: Tensor(1)
  348. output = all_reduce_sum(input)
  349. # Rank 0 # output: Tensor(1)
  350. # Rank 1 # output: Tensor(1)
  351. """
  352. mode = CollectiveComm.Mode.ALL_REDUCE_SUM
  353. return collective_comm(inp, mode, group, device)
  354. def all_reduce_max(
  355. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  356. ) -> Tensor:
  357. r"""
  358. Reduce tensors across the specified group by max.
  359. Args:
  360. inp: Input tensor.
  361. group: The process group to work on.
  362. The default group is WORLD which means all processes available.
  363. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  364. device: The specific device to execute this operator.
  365. None default device means the device of inp will be used.
  366. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  367. 1 is stream id, and default stream id is 0.
  368. Returns:
  369. Result tensor.
  370. Examples:
  371. .. code-block::
  372. input = Tensor(rank)
  373. # Rank 0 # input: Tensor(0)
  374. # Rank 1 # input: Tensor(1)
  375. output = all_reduce_max(input)
  376. # Rank 0 # output: Tensor(1)
  377. # Rank 1 # output: Tensor(1)
  378. """
  379. mode = CollectiveComm.Mode.ALL_REDUCE_MAX
  380. return collective_comm(inp, mode, group, device)
  381. def all_reduce_min(
  382. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
  383. ) -> Tensor:
  384. r"""
  385. Reduce tensors across the specified group by min.
  386. Args:
  387. inp: Input tensor.
  388. group: The process group to work on.
  389. The default group is WORLD which means all processes available.
  390. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  391. device: The specific device to execute this operator.
  392. None default device means the device of inp will be used.
  393. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  394. 1 is stream id, and default stream id is 0.
  395. Returns:
  396. Result tensor.
  397. Examples:
  398. .. code-block::
  399. input = Tensor(rank)
  400. # Rank 0 # input: Tensor(0)
  401. # Rank 1 # input: Tensor(1)
  402. output = all_reduce_min(input)
  403. # Rank 0 # output: Tensor(0)
  404. # Rank 1 # output: Tensor(0)
  405. """
  406. mode = CollectiveComm.Mode.ALL_REDUCE_MIN
  407. return collective_comm(inp, mode, group, device)
  408. class _Gather(Function):
  409. def __init__(self, group=WORLD, device=None):
  410. self.group = group
  411. self.out_device = device
  412. def forward(self, data):
  413. self.in_device = str(data.device)
  414. return collective_comm(
  415. data, CollectiveComm.Mode.GATHER, self.group, self.out_device
  416. )
  417. def backward(self, grad):
  418. has_grad = _bcast_has_grad(self.group, grad)
  419. if has_grad:
  420. return scatter(grad, self.group, self.in_device)
  421. def gather(
  422. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  423. ) -> Tensor:
  424. r"""
  425. Gather tensors across the specified group.
  426. Only root process will receive the final result.
  427. Args:
  428. inp: Input tensor.
  429. group: The process group to work on.
  430. The default group is WORLD which means all processes available.
  431. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  432. device: The specific device to execute this operator.
  433. None default device means the device of inp will be used.
  434. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  435. 1 is stream id, and default stream id is 0.
  436. axis: The concat axis for collective_comm result
  437. The default axis is 0
  438. Returns:
  439. Result tensor if in root process, None if in other process
  440. Examples:
  441. .. code-block::
  442. input = Tensor([rank])
  443. # Rank 0 # input: Tensor([0])
  444. # Rank 1 # input: Tensor([1])
  445. output = gather(input)
  446. # Rank 0 # output: Tensor([0 1])
  447. # Rank 1 # output: None
  448. input = Tensor([rank])
  449. group = Group([1, 0]) # first rank is root
  450. output = gather(input, group)
  451. # Rank 0 # output: None
  452. # Rank 1 # output: Tensor([1 0])
  453. """
  454. assert (
  455. axis < inp.ndim
  456. ), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
  457. inp.shape
  458. )
  459. op = _Gather(group, device)
  460. (out,) = apply(op, inp)
  461. if group.rank == 0:
  462. if axis == 0:
  463. return out
  464. else:
  465. group_size = group.size
  466. transformed_shape = list(inp._tuple_shape)
  467. transformed_shape[axis] *= group_size
  468. n, *shp = out._tuple_shape
  469. index = (
  470. [_ for _ in range(1, axis)]
  471. + [axis, 0]
  472. + [_ for _ in range(axis + 1, out.ndim + 1)]
  473. )
  474. return (
  475. out.reshape(group_size, n // group_size, *shp)
  476. .transpose(index)
  477. .reshape(transformed_shape)
  478. )
  479. else:
  480. _save_output_for_autodiff(inp, out)
  481. class _Scatter(Function):
  482. def __init__(self, group=WORLD, device=None):
  483. self.group = group
  484. self.out_device = device
  485. def forward(self, data):
  486. self.in_device = str(data.device)
  487. return collective_comm(
  488. data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
  489. )
  490. def backward(self, grad):
  491. # TODO backward with a part of grad
  492. if grad is not None:
  493. return gather(grad, self.group, self.in_device)
  494. def scatter(
  495. inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
  496. ) -> Tensor:
  497. r"""
  498. Split tensor in root process at first dimension.
  499. Args:
  500. inp: Input tensor.
  501. group: The process group to work on.
  502. The default group is WORLD which means all processes available.
  503. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  504. device: The specific device to execute this operator.
  505. None default device means the device of inp will be used.
  506. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  507. 1 is stream id, and default stream id is 0.
  508. axis: The concat axis for collective_comm result
  509. The default axis is 0
  510. Returns:
  511. Split tensor.
  512. Examples:
  513. .. code-block::
  514. input = Tensor([0 1]) + rank*2
  515. # Rank 0 # input: Tensor([0 1])
  516. # Rank 1 # input: Tensor([2 3])
  517. output = scatter(input)
  518. # Rank 0 # output: Tensor([0])
  519. # Rank 1 # output: Tensor([1])
  520. input = Tensor([0 1]) + rank*2
  521. group = Group([1, 0]) # first rank is root
  522. output = scatter(input, group)
  523. # Rank 0 # output: Tensor([3])
  524. # Rank 1 # output: Tensor([2])
  525. """
  526. shape, dtype = _bcast_shape_dtype(group, inp)
  527. if group.rank != 0:
  528. # dummy input to infer shape
  529. inp = _dummy_input(shape, dtype, device)
  530. _bcast_tracer_state(group, inp)
  531. assert (
  532. list(inp._tuple_shape)[axis] % group.size == 0
  533. ), "current axis: {} can't devided by group size".format(axis)
  534. if axis != 0:
  535. group_size = group.size
  536. k_new_shape = list(inp._tuple_shape)
  537. k_new_shape[axis] //= group_size
  538. k_new_shape[0] *= group_size
  539. new_shape = list(inp._tuple_shape)
  540. new_shape[axis] //= group_size
  541. new_shape.insert(axis, group_size)
  542. index = (
  543. [axis]
  544. + [_ for _ in range(0, axis)]
  545. + [_ for _ in range(axis + 1, inp.ndim + 1)]
  546. )
  547. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  548. op = _Scatter(group, device)
  549. (out,) = apply(op, inp)
  550. return out
  551. def all_to_all(
  552. inp: Tensor,
  553. group: Optional[Group] = WORLD,
  554. device: Optional[str] = None,
  555. split_axis: int = 0,
  556. concat_axis: int = 0,
  557. ) -> Tensor:
  558. r"""
  559. Each process scatter input tensor to all processes and return gathered tensor.
  560. Args:
  561. inp: Input tensor.
  562. group: The process group to work on.
  563. The default group is WORLD which means all processes available.
  564. You can use a list of process ranks to create new group to work on it, e.g. [1, 3, 5].
  565. device: The specific device to execute this operator.
  566. None default device means the device of inp will be used.
  567. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  568. 1 is stream id, and default stream id is 0.
  569. split_axis: The axis that collectivecomm will split data
  570. the default axis is 0
  571. split_axis: The axis that collectivecomm will concat data
  572. the default axis is 0
  573. Returns:
  574. Result tensor.
  575. Examples:
  576. .. code-block::
  577. input = Tensor([0 1]) + rank*2
  578. # Rank 0 # input: Tensor([0 1])
  579. # Rank 1 # input: Tensor([2 3])
  580. output = all_to_all(input)
  581. # Rank 0 # output: Tensor([0 2])
  582. # Rank 1 # output: Tensor([1 3])
  583. input = Tensor([0 1]) + rank*2
  584. group = Group([1, 0])
  585. output = all_to_all(input, group)
  586. # Rank 0 # output: Tensor([0 3])
  587. # Rank 1 # output: Tensor([2 1])
  588. """
  589. group_size = group.size if group is not None else 1
  590. assert (
  591. list(inp._tuple_shape)[split_axis] % group_size == 0
  592. ), "current axis: {} can't devided by group size".format(split_axis)
  593. origin_shape = inp._tuple_shape
  594. if split_axis != 0:
  595. k_new_shape = list(inp._tuple_shape)
  596. k_new_shape[split_axis] //= group_size
  597. k_new_shape[0] *= group_size
  598. new_shape = list(inp._tuple_shape)
  599. new_shape[split_axis] //= group_size
  600. new_shape.insert(split_axis, group_size)
  601. index = (
  602. [split_axis]
  603. + [_ for _ in range(0, split_axis)]
  604. + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
  605. )
  606. inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
  607. mode = CollectiveComm.Mode.ALL_TO_ALL
  608. out = collective_comm(inp, mode, group, device)
  609. if concat_axis == 0:
  610. return out
  611. transformed_shape = list(origin_shape)
  612. transformed_shape[concat_axis] *= group_size
  613. transformed_shape[split_axis] //= group_size
  614. n, *shp = out._tuple_shape
  615. index = (
  616. [_ for _ in range(1, concat_axis)]
  617. + [concat_axis, 0]
  618. + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
  619. )
  620. return (
  621. out.reshape(group_size, n // group_size, *shp)
  622. .transpose(index)
  623. .reshape(transformed_shape)
  624. )
  625. class _SendRecvGroup:
  626. def __init__(self, rank_from, rank_to):
  627. self.key = "{}->{}".format(rank_from, rank_to)
  628. self.rank_from = rank_from
  629. self.rank_to = rank_to
  630. self.size = 2
  631. @property
  632. def rank(self):
  633. if get_rank() == self.rank_from:
  634. return 0
  635. else:
  636. return 1
  637. class _RemoteSend(Function):
  638. def __init__(self, op: RemoteSend):
  639. self.op = op
  640. def forward(self, data):
  641. self.device = str(data.device)
  642. (self.dummy,) = apply(self.op, data)
  643. return self.dummy
  644. def backward(self, grad):
  645. assert grad is None
  646. has_grad = get_client().bcast_val(None, self.op.key, 2)
  647. if has_grad:
  648. return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)
  649. class _RemoteRecv(Function):
  650. def __init__(self, op: RemoteRecv):
  651. self.op = op
  652. def forward(self, dummy):
  653. return apply(self.op, dummy)
  654. def backward(self, grad):
  655. get_client().bcast_val(grad is not None, self.op.key, 2)
  656. if grad is not None:
  657. remote_send(grad, self.op.rank_from)
  658. def remote_send(inp: Tensor, dest_rank: int):
  659. r"""
  660. Send tensor to another process.
  661. Args:
  662. inp: Tensor to send.
  663. dest_rank: Rank of destination process.
  664. Returns:
  665. None.
  666. Examples:
  667. .. code-block::
  668. if rank == 0:
  669. data = mge.tensor(1)
  670. # Tensor(1)
  671. F.distributed.remote_send(data, 1) # return None
  672. else:
  673. data = F.distributed.remote_recv(0)
  674. # Tensor(1)
  675. """
  676. group = _SendRecvGroup(get_rank(), dest_rank)
  677. _bcast_shape_dtype(group, inp)
  678. _bcast_tracer_state(group, inp)
  679. op = RemoteSend()
  680. op.key = group.key
  681. op.addr, op.port = get_mm_server_addr()
  682. op.rank_to = dest_rank
  683. op.backend = _backend()
  684. (out,) = apply(_RemoteSend(op), inp)
  685. _save_output_for_autodiff(inp, out)
  686. def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
  687. r"""
  688. Receive a tensor from another process.
  689. Args:
  690. src_rank: Rank of source process.
  691. device: The specific device to execute this operator.
  692. None default device means the device of inp will be used.
  693. Specify "gpu0:1" to execute this operator on diffrent cuda stream,
  694. 1 is stream id, and default stream id is 0.
  695. inp: Dummy input to determine received tensor type.
  696. Returns:
  697. Received tensor.
  698. Examples:
  699. .. code-block::
  700. if rank == 0:
  701. data = mge.tensor(1)
  702. # Tensor(1)
  703. F.distributed.remote_send(data, 1) # return None
  704. else:
  705. data = F.distributed.remote_recv(0)
  706. # Tensor(1)
  707. """
  708. group = _SendRecvGroup(src_rank, get_rank())
  709. shape, dtype = _bcast_shape_dtype(group, None)
  710. if device is None:
  711. device = get_default_device()
  712. # dummy input
  713. if inp is None:
  714. inp = Tensor(0, device=device)
  715. _bcast_tracer_state(group, inp)
  716. _isscalar = False
  717. if len(shape) == 0:
  718. shape = (1,)
  719. _isscalar = True
  720. op = RemoteRecv()
  721. op.key = group.key
  722. op.cn = device
  723. op.shape = shape
  724. op.dtype = dtype
  725. op.addr, op.port = get_mm_server_addr()
  726. op.rank_from = src_rank
  727. op.backend = _backend()
  728. (ret,) = apply(_RemoteRecv(op), inp)
  729. if _isscalar:
  730. setscalar(ret)
  731. return ret

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。