
functional.py (16 kB)

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}


def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    # a None group means no communication is needed; check this before the
    # isinstance assertion, which would otherwise make the branch unreachable
    if group is None:
        return inp
    assert isinstance(group, Group)
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shapes
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result


def _save_output_for_autodiff(inp, out):
    # keep a reference to the output so attached GradManagers can reach it
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases the shape is not available (e.g. the output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp


class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
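

# Usage sketch (illustrative, assuming a machine with at least two GPUs, and
# following the launcher pattern used in the docstrings below): reduce_sum
# returns the summed tensor only on rank 0; other ranks return None.
#
#     import megengine as mge
#     import megengine.distributed as dist
#
#     def worker():
#         data = mge.tensor(dist.get_rank())
#         out = mge.functional.distributed.reduce_sum(data)
#         if dist.get_rank() == 0:
#             print(out.item())  # sum of all ranks' values
#
#     dist.launcher(worker)()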


class _Broadcast(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out
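

# Usage sketch (illustrative, same launcher assumptions as above): broadcast
# copies rank 0's tensor to every process; non-root ranks only contribute the
# shape and dtype learned via _bcast_shape_dtype, so their own input value is
# replaced by a dummy tensor.
#
#     import megengine as mge
#     import megengine.distributed as dist
#
#     def worker():
#         data = mge.tensor(dist.get_rank())
#         out = mge.functional.distributed.broadcast(data)
#         print(out.item())  # rank 0's value, i.e. 0, on every rank
#
#     dist.launcher(worker)()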


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)
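

# Usage sketch (illustrative, same launcher assumptions as above): all_gather
# concatenates every rank's tensor along the first dimension and returns the
# result on all processes, so an input of shape (2,) on n ranks yields (2 * n,).
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import numpy as np
#
#     def worker():
#         data = mge.tensor(np.full((2,), dist.get_rank(), dtype="float32"))
#         out = mge.functional.distributed.all_gather(data)
#         print(out.shape)  # (2 * dist.get_world_size(),)
#
#     dist.launcher(worker)()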


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
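

# Usage sketch (illustrative, same launcher assumptions as above):
# reduce_scatter_sum first sums the tensors elementwise across the group, then
# splits the result along the first dimension and hands each rank its slice, so
# the input's first dimension should be divisible by the group size.
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import numpy as np
#
#     def worker():
#         n = dist.get_world_size()
#         data = mge.tensor(np.arange(2 * n, dtype="float32"))
#         out = mge.functional.distributed.reduce_scatter_sum(data)
#         print(out.shape)  # (2,): this rank's slice of the summed tensor
#
#     dist.launcher(worker)()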


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Create all_reduce_sum operator for collective communication.

    This operator sums the tensor data elementwise across the specified group and
    returns a tensor with the same shape as the input tensor on every process.

    Args:
        inp: the tensor data to apply this operator on.
        group: the communication group, an instance of :class:`Group`, to apply this
            operator across. The default group is WORLD, which means all available
            processes. A list of process ranks can also be given to apply this
            operator to specific processes, e.g. [1, 3, 5].
        device: the device, a :class:`str`, on which to execute this operator. The
            default device is None, which means the device of ``inp`` will be used.
            Specify "cpu" or "gpu" to execute this operator on a specific device.

    Returns:
        the reduced-sum tensor of the input tensor data across the specified group.

    Examples:

    .. code-block::

        import megengine as mge
        import megengine.distributed as dist
        import numpy as np
        from warnings import warn

        def func(sum_value):
            # get the rank of this process; the ranks should be 0, 1, 2, 3 for a 4-gpu task
            rank = dist.get_rank()
            data = mge.tensor(rank)
            # the result should be n * (n - 1) / 2 for all processes
            result = mge.functional.distributed.all_reduce_sum(data).item()
            assert result == sum_value

        def main():
            p_num = mge.device.get_device_count("gpu")
            if p_num < 2:
                warn('This opr only works on a group with more than one gpu')
                return
            method = dist.launcher(func)
            method(p_num * (p_num - 1) // 2)

        if __name__ == '__main__':
            main()
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    :returns: reduced tensor.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""
    Create all_reduce_min operator for collective communication.

    This operator computes the elementwise minimum of the tensor data across the
    specified group and returns a tensor with the same shape as the input tensor
    on every process.

    Args:
        inp: the tensor data to apply this operator on.
        group: the communication group, an instance of :class:`Group`, to apply this
            operator across. The default group is WORLD, which means all available
            processes. A list of process ranks can also be given to apply this
            operator to specific processes, e.g. [1, 3, 5].
        device: the device, a :class:`str`, on which to execute this operator. The
            default device is None, which means the device of ``inp`` will be used.
            Specify "cpu" or "gpu" to execute this operator on a specific device.

    Returns:
        the reduced-min tensor of the input tensor data across the specified group.

    Examples:

    .. code-block::

        import megengine as mge
        import megengine.distributed as dist
        import numpy as np
        from warnings import warn

        def func(min_value):
            # get the rank of this process; the ranks should be 0, 1, 2, 3 for a 4-gpu task
            rank = dist.get_rank()
            data = mge.Tensor(rank)
            # the result should be 0 for all processes
            result = mge.functional.distributed.all_reduce_min(data).item()
            assert result == min_value

        def main():
            p_num = dist.helper.get_device_count("gpu")
            if p_num < 2:
                warn('This opr only works on a group with more than one gpu')
                return
            method = dist.launcher(func)
            method(0)

        if __name__ == '__main__':
            main()
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


class _Gather(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
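

# Usage sketch (illustrative, same launcher assumptions as above): gather is the
# rank-0-only counterpart of all_gather; rank 0 receives the concatenation of
# every rank's tensor along the first dimension, while other ranks return None.
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import numpy as np
#
#     def worker():
#         data = mge.tensor(np.full((2,), dist.get_rank(), dtype="float32"))
#         out = mge.functional.distributed.gather(data)
#         if dist.get_rank() == 0:
#             print(out.shape)  # (2 * dist.get_world_size(),)
#
#     dist.launcher(worker)()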


class _Scatter(Function):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
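

# Usage sketch (illustrative, same launcher assumptions as above): scatter is
# the inverse of gather; rank 0's tensor is split along the first dimension and
# each rank receives its own slice, so rank 0's first dimension should be
# divisible by the group size.
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import numpy as np
#
#     def worker():
#         n = dist.get_world_size()
#         data = mge.tensor(np.arange(2 * n, dtype="float32"))
#         out = mge.functional.distributed.scatter(data)
#         print(out)  # this rank's 2-element slice of rank 0's tensor
#
#     dist.launcher(worker)()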


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)
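

# Usage sketch (illustrative, same launcher assumptions as above): all_to_all
# splits every rank's tensor along the first dimension into one chunk per rank
# and exchanges the chunks, so rank i ends up with chunk i from every process.
#
#     import megengine as mge
#     import megengine.distributed as dist
#     import numpy as np
#
#     def worker():
#         n = dist.get_world_size()
#         data = mge.tensor(np.full((n,), dist.get_rank(), dtype="float32"))
#         out = mge.functional.distributed.all_to_all(data)
#         print(out)  # [0, 1, ..., n - 1] on every rank
#
#     dist.launcher(worker)()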


class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int):
    """
    Send a tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a tensor from a remote process.

    :param src_rank: source process rank.
    :param device: the device to place the received tensor on.
    :param inp: dummy input used to determine the received tensor's type.
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
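

# Usage sketch (illustrative, same launcher assumptions as above): remote_send
# and remote_recv form a point-to-point pair; the sender names the destination
# rank, while the receiver names the source rank and learns the shape and dtype
# through the broadcast helpers above.
#
#     import megengine as mge
#     import megengine.distributed as dist
#
#     def worker():
#         if dist.get_rank() == 0:
#             data = mge.tensor(42.0)
#             mge.functional.distributed.remote_send(data, dest_rank=1)
#         elif dist.get_rank() == 1:
#             data = mge.functional.distributed.remote_recv(src_rank=0)
#             print(data.item())  # 42.0
#
#     dist.launcher(worker)()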

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has GPU hardware and the proper drivers installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.