
functional.py 5.9 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(*args, **kargs):
    return collective_comm_symvar(*args, **kargs)


def reduce_sum(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    return _collective_comm(
        tensor,
        key,
        CollParam.Mode.REDUCE_SUM,
        nr_ranks,
        rank,
        root,
        device=tensor.device,
    )


def broadcast(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> Tensor:
    """Create broadcast operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    if key is None:
        key = tensor._symvar.name
    if rank is None:
        rank = get_rank()

    if rank == root:
        inp = tensor
    else:
        inp = tensor._symvar.owner_graph

    return _collective_comm(
        inp,
        key,
        CollParam.Mode.BROADCAST,
        nr_ranks,
        rank,
        root,
        dtype=tensor.dtype,
        device=tensor.device,
    )


def all_gather(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank, 0)


def reduce_scatter_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank
    )


def all_reduce_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks, rank)


def all_reduce_max(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks, rank)


def all_reduce_min(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks, rank)


def bcast_param(
    inp: Union[Buffer, Parameter],
    key: str,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
) -> None:
    """Broadcast parameters among devices

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    """
    if not is_distributed():
        return
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, rank, root)
    add_update(inp, bcast_res, alpha=0)
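
For context, here is a minimal usage sketch of the collective helpers above. It assumes the distributed process group has already been initialized elsewhere and that these functions (together with util.is_distributed and util.get_world_size) are re-exported under megengine.distributed; the key names and the sync_model / average_gradient helpers are illustrative, not part of the module.

# Illustrative only: synchronize parameters from the root rank and average
# gradients across ranks. Assumes megengine.distributed re-exports the helpers
# defined above and that the process group has already been initialized.
import megengine.distributed as dist

def sync_model(named_params):
    # Push rank 0's parameter values into every replica, in place.
    # bcast_param is a no-op when not running in distributed mode.
    for name, param in named_params:
        dist.bcast_param(param, "bcast_" + name)

def average_gradient(grad, name):
    # Sum the gradient over all ranks, then divide by world size for the mean.
    if dist.is_distributed():
        grad = dist.all_reduce_sum(grad, "grad_" + name) / dist.get_world_size()
    return grad

Each call site must use a key that is unique per tensor but identical across ranks, since the key is how the ranks pair up the corresponding collective operators.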

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
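
As a quick sanity check before running GPU code, something like the following can be used. This is a sketch that assumes megengine.get_device_count and megengine.set_default_device are available, as in recent MegEngine releases.

# Sketch: select a GPU device if MegEngine detects one, otherwise fall back
# to CPU. Device strings such as "gpu0" and "cpu0" follow MegEngine's naming.
import megengine as mge

def pick_device():
    if mge.get_device_count("gpu") > 0:
        return "gpu0"
    return "cpu0"

mge.set_default_device(pick_device())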