
functional.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as Param

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(*args, **kargs):
    return collective_comm_symvar(*args, **kargs)


def _group_check(*args):
    """Return True when the arguments are all None or all not None."""
    l = [val is None for val in args]
    return len(set(l)) <= 1


def reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create a reduce_sum operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
    )


def gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create a gather operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
    )


def broadcast(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create a broadcast operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    if is_root is None:
        is_root = get_rank() == 0
    if is_root:
        inp = tensor
    else:
        # Non-root ranks supply no input data, only the owning computing graph.
        inp = tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        Param.Mode.BROADCAST,
        nr_ranks,
        is_root,
        dtype=tensor.dtype,
        device=tensor.device,
    )


def scatter(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create a scatter operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    if key is None:
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0
    if is_root:
        inp = tensor
    else:
        # Non-root ranks supply no input data, only the owning computing graph.
        inp = tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        Param.Mode.SCATTER,
        nr_ranks,
        is_root,
        rank,
        dtype=tensor.dtype,
        device=tensor.device,
    )


def all_to_all(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create an all_to_all operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether to use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
    )


def all_gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create an all_gather operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether to use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
    )


def reduce_scatter_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create a reduce_scatter_sum operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    :param local_grad: whether to use local grad
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor,
        key,
        Param.Mode.REDUCE_SCATTER_SUM,
        nr_ranks,
        rank=rank,
        local_grad=local_grad,
    )


def all_reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create an all_reduce_sum operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether to use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
    )


def all_reduce_max(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create an all_reduce_max operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether to use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
    )


def all_reduce_min(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create an all_reduce_min operator for collective communication.

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param local_grad: whether to use local grad
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
    )


def bcast_param(
    inp: Union[Buffer, Parameter],
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> None:
    """Broadcast parameters among devices.

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if not is_distributed():
        return
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, is_root)
    # alpha=0 discards the old local value, so inp is overwritten in place
    # with the broadcast result from the root rank.
    add_update(inp, bcast_res, alpha=0)
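
The collectives above are meant to be called from per-process worker code after the distributed group has been initialized. The minimal sketch below illustrates the typical call pattern; all_reduce_sum and bcast_param are the functions defined in this file, while the init_process_group setup call, its keyword names, and the tensor/Parameter factory calls are assumptions about the surrounding MegEngine API of this version, not something defined here.

    # Minimal usage sketch (not part of functional.py): each process runs
    # `worker` with its own rank and a shared world_size.
    import numpy as np

    import megengine as mge
    import megengine.distributed as dist
    import megengine.distributed.functional as dist_fn


    def worker(rank, world_size):
        # Hypothetical group setup: master address/port, group size, this
        # process's rank, and the local device index. Exact names may differ.
        dist.init_process_group(
            master_ip="localhost",
            master_port=23456,
            world_size=world_size,
            rank=rank,
            dev=rank,
        )

        # Every rank contributes its local value; all ranks receive the sum.
        x = mge.tensor(np.array([rank], dtype="float32"))
        summed = dist_fn.all_reduce_sum(x)

        # Keep a parameter consistent across ranks by overwriting it in place
        # with the root (rank 0) copy.
        p = mge.Parameter(np.ones((1,), dtype="float32"))
        dist_fn.bcast_param(p)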

The MegEngine package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.