
functional.py 5.1 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(*args, **kargs):
    # Thin wrapper so the symvar-level helper can take and return Tensors.
    return collective_comm_symvar(*args, **kargs)


def reduce_sum(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
    )
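
A minimal usage sketch for reduce_sum (not part of the module above). `local_grad` is a placeholder Tensor and the distributed process group is assumed to be initialized already; every rank calls the function with the same key:

    # Every rank contributes its local value; the summed result is only
    # meaningful on the root rank (rank 0 by default).
    total = reduce_sum(local_grad, "grad_sum")
    if get_rank() == 0:
        print(total.numpy())
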
def broadcast(
    tensor: Tensor,
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create broadcast operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if key is None:
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0

    if is_root:
        inp = tensor
    else:
        # Non-root ranks contribute no data; pass the owner graph so the
        # operator can still be built, with dtype/device given explicitly below.
        inp = tensor._symvar.owner_graph

    return _collective_comm(
        inp,
        key,
        CollParam.Mode.BROADCAST,
        nr_ranks,
        is_root,
        dtype=tensor.dtype,
        device=tensor.device,
    )
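
A sketch of pulling an initial value from the root rank to all others (not part of the module). `weight` is a placeholder Tensor that only the root has filled with meaningful data; non-root ranks still pass a Tensor of matching dtype, whose contents are ignored:

    # All ranks call broadcast with the same key; is_root defaults to
    # (get_rank() == 0), so rank 0 sends and the others receive.
    weight = broadcast(weight, "init_weight")
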
def all_gather(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create all_gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_GATHER, nr_ranks, rank=rank)


def reduce_scatter_sum(
    tensor: Tensor, key: str, nr_ranks: Optional[int] = None, rank: Optional[int] = None
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of this node
    """
    return _collective_comm(
        tensor, key, CollParam.Mode.REDUCE_SCATTER_SUM, nr_ranks, rank=rank,
    )
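
all_gather and reduce_scatter_sum are roughly duals of each other: the first concatenates every rank's tensor so each process ends up with the full result, while the second sums the inputs across ranks and hands each process one slice of the sum. A sketch with placeholder tensors, not part of the module:

    # Each rank holds a local piece; every rank receives the concatenation.
    full = all_gather(local_piece, "gather_features")

    # Each rank holds a full-sized tensor; every rank receives the sum of
    # its assigned slice.
    my_slice = reduce_scatter_sum(full_tensor, "scatter_grads")
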
def all_reduce_sum(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_SUM, nr_ranks)


def all_reduce_max(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MAX, nr_ranks)


def all_reduce_min(tensor: Tensor, key: str, nr_ranks: Optional[int] = None) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    """
    return _collective_comm(tensor, key, CollParam.Mode.ALL_REDUCE_MIN, nr_ranks)
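
The all_reduce_* variants leave the combined result on every rank, which makes them the usual choice for data-parallel gradient synchronization; note that averaging is not done by the operator itself. A sketch with placeholder names (`grad` is a per-rank Tensor, `world_size` stands for the number of participating processes):

    summed = all_reduce_sum(grad, "grad_allreduce")
    averaged = summed / world_size     # divide explicitly to get the mean
    largest = all_reduce_max(local_stat, "stat_max")
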
def bcast_param(
    inp: Union[Buffer, Parameter],
    key: str,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> None:
    """Broadcast parameters among devices

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param is_root: whether this is a root node
    """
    if not is_distributed():
        return
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, is_root)
    # alpha=0 discards the old value, so inp is overwritten in place with
    # the broadcast result.
    add_update(inp, bcast_res, alpha=0)
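
bcast_param is typically called once per parameter right after model construction so that every process starts training from the root rank's weights. A sketch under the assumption that `model` is a megengine Module whose named_parameters() yields (name, Parameter) pairs; the key only needs to be unique per parameter and identical across ranks:

    # Overwrite every parameter in place with rank 0's values.
    for name, param in model.named_parameters():
        bcast_param(param, "bcast_" + name)
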

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.