
functional.py 13 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
    if group is None:
        return inp
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result
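
# Editorial note (not part of the original file): the public wrappers below assume they
# run inside an initialized process group, one Python process per rank. A minimal way
# to obtain such a group is the megengine.distributed.launcher decorator, roughly as
# sketched here (the exact launcher signature may differ between versions):
#
#     import megengine.distributed as dist
#
#     @dist.launcher(n_gpus=2)      # assumed 2-GPU launch; adjust to your setup
#     def worker():
#         print(dist.get_rank(), dist.get_world_size())
#
#     worker()
#
# The usage sketches after each public function below assume this kind of 2-rank
# context, with rank = dist.get_rank().
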
def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME in some cases, shape is not available (output of condtake)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]
    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=""):
    if device == "":
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp
class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
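
# Usage sketch (assumes the 2-rank context described above; not part of the original
# file):
#
#     inp = Tensor([rank])          # rank 0 holds [0], rank 1 holds [1]
#     out = reduce_sum(inp)
#     # rank 0: out is Tensor([1]), the elementwise sum 0 + 1
#     # rank 1: out is None (only the root rank receives the reduced value)
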
class _Broadcast(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create broadcast operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out
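
# Usage sketch (2-rank context assumed; not part of the original file):
#
#     inp = Tensor([rank])          # rank 0 holds [0], rank 1 holds [1]
#     out = broadcast(inp)
#     # every rank: out is Tensor([0]); the root rank's value is replicated, and
#     # non-root inputs only contribute shape/dtype via the dummy-input path above
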
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)
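
# Usage sketch (2-rank context assumed; in this version the gathered pieces appear to
# be concatenated along the first axis; not part of the original file):
#
#     inp = Tensor([rank])          # shape (1,) on every rank
#     out = all_gather(inp)
#     # every rank: out is Tensor([0, 1]) with shape (2,)
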
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
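
# Usage sketch (2-rank context assumed; the summed result appears to be split along
# the first axis, one slice per rank; not part of the original file):
#
#     inp = Tensor([rank, rank])    # rank 0: [0, 0]; rank 1: [1, 1]
#     out = reduce_scatter_sum(inp)
#     # elementwise sum is [1, 1]; rank 0 gets Tensor([1]), rank 1 gets Tensor([1])
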
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
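
# Usage sketch covering the all_reduce_* family (2-rank context assumed; not part of
# the original file):
#
#     inp = Tensor([rank, 2 - rank])   # rank 0: [0, 2]; rank 1: [1, 1]
#     all_reduce_sum(inp)              # every rank: Tensor([1, 3])
#     all_reduce_max(inp)              # every rank: Tensor([1, 2])
#     all_reduce_min(inp)              # every rank: Tensor([0, 1])
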
class _Gather(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create gather operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)
    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
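
# Usage sketch (2-rank context assumed; pieces appear to be concatenated along the
# first axis on the root rank; not part of the original file):
#
#     inp = Tensor([rank])
#     out = gather(inp)
#     # rank 0: out is Tensor([0, 1])
#     # rank 1: out is None (only the root rank receives the gathered tensor)
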
class _Scatter(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO backward with a part of grad
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create scatter operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)
    _bcast_tracer_state(group, inp)
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
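
# Usage sketch (2-rank context assumed; the root tensor appears to be split along the
# first axis, one slice per rank; not part of the original file):
#
#     inp = Tensor([10, 20])        # only rank 0's value is used
#     out = scatter(inp)
#     # rank 0: out is Tensor([10])
#     # rank 1: out is Tensor([20])
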
def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Create all_to_all operator for collective communication.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)
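
# Usage sketch (2-rank context assumed; each rank appears to split its tensor along
# the first axis and exchange the i-th slice with rank i, like a combined
# scatter/gather; not part of the original file):
#
#     inp = Tensor([10 * rank, 10 * rank + 1])   # rank 0: [0, 1]; rank 1: [10, 11]
#     out = all_to_all(inp)
#     # rank 0: out is Tensor([0, 10])
#     # rank 1: out is Tensor([1, 11])
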
class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """
    Send a Tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)
    _bcast_tracer_state(group, inp)
    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
    """
    Receive a Tensor from a remote process.

    :param src_rank: source process rank.
    :param device: the device to place the received tensor.
    :param inp: dummy input to determine the received tensor type.
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)
    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)
    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True
    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = get_backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
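
# Usage sketch for the point-to-point pair (2-rank context assumed; not part of the
# original file):
#
#     if rank == 0:
#         x = Tensor([1.0, 2.0])
#         remote_send(x, dest_rank=1)    # returns None; output is kept for autodiff
#     else:
#         y = remote_recv(src_rank=0)    # rank 1 receives Tensor([1.0, 2.0])
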
