
functional.py 7.5 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..core.graph import get_default_graph
from ..functional import add_update
from .util import (
    get_backend,
    get_master_ip,
    get_master_port,
    get_rank,
    get_world_size,
    is_distributed,
)


@wrap_io_tensor
def _collective_comm(
    inp: Union[Tensor, mgb.CompGraph],
    key: str,
    op: CollParam.Mode,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
    dtype: Optional[type] = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> Tensor:
    """Helper function for creating collective_comm operators

    :param inp: tensor or comp_graph
    :param key: unique identifier for collective communication
    :param op: mode of collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    :param dtype: output data type, use dtype of inp as default
    :param device: output comp node, use comp node of inp as default
    :param comp_graph: output comp graph, use comp graph of inp as default
    """
    return mgb.opr.collective_comm(
        inp,
        key=str(key),
        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
        rank=rank if rank is not None else get_rank(),
        root=root,
        server_addr=get_master_ip(),
        port=get_master_port(),
        param=CollParam(mode=op),
        dtype=dtype,
        backend=get_backend(),
        comp_node=device,
        comp_graph=comp_graph,
    )


def reduce_sum(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    return _collective_comm(
        tensor,
        key,
        CollParam.Mode.REDUCE_SUM,
        nr_ranks,
        rank,
        root,
        device=tensor.device,
    )


def broadcast(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> Tensor:
    """Create broadcast operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    if key is None:
        key = tensor._symvar.name
    if rank is None:
        rank = get_rank()

    if rank == root:
        return _collective_comm(
            tensor,
            key,
            CollParam.Mode.BROADCAST,
            nr_ranks,
            rank,
            root,
            device=tensor.device,
        )
    else:
        return _collective_comm(
            get_default_graph(),
            key,
            CollParam.Mode.BROADCAST,
            nr_ranks,
            rank,
            root,
            dtype=tensor._symvar.dtype,
            device=tensor.device,
        )


def all_gather(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0)


def reduce_scatter_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank
    )


def all_reduce_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank)


def all_reduce_max(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank)


def all_reduce_min(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank)


def bcast_param(
    inp: Union[Buffer, Parameter],
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> None:
    """Broadcast parameters among devices

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    if not is_distributed():
        return
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, rank, root)
    add_update(inp, bcast_res, alpha=0)
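For context, here is a minimal usage sketch (not part of the file above) showing how these helpers are typically combined for data-parallel training: parameters are synchronized from the root rank with bcast_param before training, and gradients are averaged with all_reduce_sum. It assumes this file lives at megengine.distributed.functional, that the model follows the usual Module API with named_parameters(), and that a distributed process group has already been initialized on every worker (the exact initialization call may differ across MegEngine versions).

from megengine.core import Tensor
from megengine.distributed.functional import all_reduce_sum, bcast_param
from megengine.distributed.util import get_world_size, is_distributed


def sync_parameters(model) -> None:
    # Copy every parameter from the root rank (rank 0 by default) to all
    # other ranks before training starts.  Each key only needs to be unique
    # per parameter and identical across ranks.
    for name, param in model.named_parameters():  # assumed Module API
        bcast_param(param, "bcast_{}".format(name))


def average_gradient(grad: Tensor, key: str) -> Tensor:
    # Sum the gradient over all ranks, then divide by the world size so
    # every worker ends up with the same averaged gradient.
    if not is_distributed():
        return grad
    return all_reduce_sum(grad, key) / get_world_size()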

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
