
functional.py 7.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(*args, **kargs):
    return collective_comm_symvar(*args, **kargs)


def reduce_sum(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
    )


def gather(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    return _collective_comm(
        tensor,
        key,
        CollParam.Mode.GATHER,
        nr_ranks,
        is_root,
        rank,
        device=tensor.device,
    )


def broadcast(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create broadcast operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if key is None:
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0
    if is_root:
        inp = tensor
    else:
        inp = tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        CollParam.Mode.BROADCAST,
        nr_ranks,
        is_root,
        dtype=tensor.dtype,
        device=tensor.device,
    )


def scatter(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create scatter operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    if key is None:
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0
    if is_root:
        inp = tensor
    else:
        inp = tensor._symvar.owner_graph
    return _collective_comm(
        inp,
        key,
        CollParam.Mode.SCATTER,
        nr_ranks,
        is_root,
        rank,
        dtype=tensor.dtype,
        device=tensor.device,
    )


def all_to_all(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_to_all operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_TO_ALL, nr_ranks, rank=rank)


def all_gather(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)


def reduce_scatter_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
    )


def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)


def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)


def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)


def bcast_param(
    inp: Union[Buffer, Parameter],
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> None:
    """Broadcast parameters among devices

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if not is_distributed():
        return
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, is_root)
    add_update(inp, bcast_res, alpha=0)
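
Taken together, these helpers cover the collective primitives used in data-parallel training: bcast_param keeps parameters identical across workers, and all_reduce_sum can be used to average gradients. Below is a minimal usage sketch, not part of functional.py itself; it assumes a process group has already been initialized through the utilities in util.py, that these functions are re-exported from the megengine.distributed package, and that the model object and gradient names are hypothetical placeholders.

    # Minimal usage sketch (assumptions noted above; not part of functional.py).
    import megengine.distributed as dist


    def sync_parameters(model, world_size):
        # Broadcast every parameter from the root rank so that all workers
        # start from identical weights; the parameter name serves as the key.
        for name, param in model.named_parameters():  # hypothetical model object
            dist.bcast_param(param, key="sync_" + name, nr_ranks=world_size)


    def average_gradient(grad, name, world_size):
        # Sum a gradient tensor over all ranks, then divide by the rank count
        # to obtain the averaged gradient used for the optimizer step.
        total = dist.all_reduce_sum(grad, key="grad_" + name, nr_ranks=world_size)
        return total / world_size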

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.
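
A quick way to verify that the bundled CUDA runtime can actually reach a GPU is to query MegEngine before launching GPU work. The sketch below assumes megengine.is_cuda_available(), which current MegEngine releases expose; adjust if your installed version differs.

    import megengine as mge

    # Report whether MegEngine's bundled CUDA runtime can see a usable GPU;
    # without one, computation runs on the CPU device instead.
    if mge.is_cuda_available():
        print("GPU detected, CUDA kernels will be used")
    else:
        print("No usable GPU found, running on CPU")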