
functional.py 23 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank
__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    if group is None:
        return inp
    assert isinstance(group, Group)
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result


def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of condtake)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp
class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Reduce tensor data across the specified group by sum.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Reduced tensor if in the root process, None in other processes.

    Examples:

    .. code-block::

        input = Tensor([rank])
        # Rank 0 # input: Tensor([0])
        # Rank 1 # input: Tensor([1])
        output = reduce_sum(input)
        # Rank 0 # output: Tensor([1])
        # Rank 1 # output: None

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = reduce_sum(input, group)
        # Rank 0 # output: None
        # Rank 1 # output: Tensor([1])
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Broadcast tensor data from the root process to the others.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor([rank])
        # Rank 0 # input: Tensor([0])
        # Rank 1 # input: Tensor([1])
        output = broadcast(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([0])

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = broadcast(input, group)
        # Rank 0 # output: Tensor([1])
        # Rank 1 # output: Tensor([1])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Gather tensors across the specified group and concatenate them at the first dimension.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor([rank])
        # Rank 0 # input: Tensor([0])
        # Rank 1 # input: Tensor([1])
        output = all_gather(input)
        # Rank 0 # output: Tensor([0 1])
        # Rank 1 # output: Tensor([0 1])

        input = Tensor([rank])
        group = Group([1, 0])
        output = all_gather(input, group)
        # Rank 0 # output: Tensor([1 0])
        # Rank 1 # output: Tensor([1 0])
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Reduce tensors across the specified group by sum and split them at the first dimension.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Split tensor.

    Examples:

    .. code-block::

        input = Tensor([0 1])
        # Rank 0 # input: Tensor([0 1])
        # Rank 1 # input: Tensor([0 1])
        output = reduce_scatter_sum(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([2])

        input = Tensor([0 1])
        group = Group([1, 0])
        output = reduce_scatter_sum(input, group)
        # Rank 0 # output: Tensor([2])
        # Rank 1 # output: Tensor([0])
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Reduce tensors across the specified group by sum.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor(rank)
        # Rank 0 # input: Tensor(0)
        # Rank 1 # input: Tensor(1)
        output = all_reduce_sum(input)
        # Rank 0 # output: Tensor(1)
        # Rank 1 # output: Tensor(1)
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Reduce tensors across the specified group by max.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor(rank)
        # Rank 0 # input: Tensor(0)
        # Rank 1 # input: Tensor(1)
        output = all_reduce_max(input)
        # Rank 0 # output: Tensor(1)
        # Rank 1 # output: Tensor(1)
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Reduce tensors across the specified group by min.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor(rank)
        # Rank 0 # input: Tensor(0)
        # Rank 1 # input: Tensor(1)
        output = all_reduce_min(input)
        # Rank 0 # output: Tensor(0)
        # Rank 1 # output: Tensor(0)
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Gather tensors across the specified group.
    Only the root process will receive the final result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor if in the root process, None in other processes.

    Examples:

    .. code-block::

        input = Tensor([rank])
        # Rank 0 # input: Tensor([0])
        # Rank 1 # input: Tensor([1])
        output = gather(input)
        # Rank 0 # output: Tensor([0 1])
        # Rank 1 # output: None

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = gather(input, group)
        # Rank 0 # output: None
        # Rank 1 # output: Tensor([1 0])
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)


class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Split the tensor in the root process at the first dimension and scatter it to all processes.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Split tensor.

    Examples:

    .. code-block::

        input = Tensor([0 1]) + rank*2
        # Rank 0 # input: Tensor([0 1])
        # Rank 1 # input: Tensor([2 3])
        output = scatter(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([1])

        input = Tensor([0 1]) + rank*2
        group = Group([1, 0]) # first rank is root
        output = scatter(input, group)
        # Rank 0 # output: Tensor([3])
        # Rank 1 # output: Tensor([2])
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Each process scatters its input tensor to all processes and returns the gathered result.

    Args:
        inp: Input tensor.
        group: The process group to work on.
            The default group is WORLD, which means all available processes.
            You can pass a list of process ranks to create a new group to work on, e.g. [1, 3, 5].
        device: The specific device to execute this operator.
            None (the default) means the device of inp will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.

    Returns:
        Result tensor.

    Examples:

    .. code-block::

        input = Tensor([0 1]) + rank*2
        # Rank 0 # input: Tensor([0 1])
        # Rank 1 # input: Tensor([2 3])
        output = all_to_all(input)
        # Rank 0 # output: Tensor([0 2])
        # Rank 1 # output: Tensor([1 3])

        input = Tensor([0 1]) + rank*2
        group = Group([1, 0])
        output = all_to_all(input, group)
        # Rank 0 # output: Tensor([0 3])
        # Rank 1 # output: Tensor([2 1])
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)
class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int):
    r"""
    Send a tensor to another process.

    Args:
        inp: Tensor to send.
        dest_rank: Rank of the destination process.

    Returns:
        None.

    Examples:

    .. code-block::

        if rank == 0:
            data = mge.tensor(1)
            # Tensor(1)
            F.distributed.remote_send(data, 1) # return None
        else:
            data = F.distributed.remote_recv(0)
            # Tensor(1)
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    r"""
    Receive a tensor from another process.

    Args:
        src_rank: Rank of the source process.
        device: The specific device to execute this operator.
            None (the default) means the default device will be used.
            Specify "gpu0:1" to execute this operator on a different CUDA stream:
            1 is the stream id, and the default stream id is 0.
        inp: Dummy input to determine the received tensor type.

    Returns:
        Received tensor.

    Examples:

    .. code-block::

        if rank == 0:
            data = mge.tensor(1)
            # Tensor(1)
            F.distributed.remote_send(data, 1) # return None
        else:
            data = F.distributed.remote_recv(0)
            # Tensor(1)
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
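The docstring snippets above assume the worker processes already exist and that `rank` is defined. Below is a minimal end-to-end sketch of how such a collective is typically driven, assuming a multi-GPU machine and MegEngine's `dist.launcher` helper (launcher arguments and behaviour may vary across versions):

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F


@dist.launcher  # assumed to spawn one worker process per visible GPU
def worker():
    rank = dist.get_rank()
    # Each process contributes a tensor holding its own rank.
    inp = mge.tensor([rank])
    # Every process receives the sum over all ranks.
    out = F.distributed.all_reduce_sum(inp)
    print("rank", rank, "->", out.numpy())


if __name__ == "__main__":
    worker()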

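For point-to-point communication, the `remote_send`/`remote_recv` docstring example can be completed in the same way; this is only a sketch under the same launcher assumption:

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F


@dist.launcher  # assumed launcher helper, as above
def worker():
    rank = dist.get_rank()
    if rank == 0:
        # Rank 0 sends a scalar tensor to rank 1; remote_send returns None.
        F.distributed.remote_send(mge.tensor(1), 1)
    elif rank == 1:
        # Rank 1 blocks until the tensor from rank 0 arrives.
        data = F.distributed.remote_recv(0)
        print("rank 1 received:", data.numpy())


if __name__ == "__main__":
    worker()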