You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

helper.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import multiprocessing as mp
  11. from collections import defaultdict
  12. from typing import Callable
  13. from weakref import WeakSet
  14. import numpy as np
  15. from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
  16. from ..core._imperative_rt.core2 import apply
  17. from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
  18. from ..functional.tensor import copy
  19. from ..tensor import Tensor
  20. from ..utils.deprecation import deprecated_func
  21. from ..utils.future import Future
  22. from . import group as _group
  23. from .functional import _bcast_param, all_reduce_sum, broadcast
  24. from .group import WORLD, Group, group_barrier, is_distributed, override_backend
  25. def param_pack_split(inp: Tensor, offsets: list, shapes: list):
  26. r"""Returns split tensor to list of tensors as offsets and shapes described,
  27. only used for ``parampack``.
  28. Args:
  29. inp: input tensor.
  30. offsets: offsets of outputs, length of ``2 * n``,
  31. where ``n`` is the number of tensor you want to split,
  32. format ``[begin0, end0, begin1, end1]``.
  33. shapes: tensor shapes of outputs.
  34. Returns:
  35. splitted tensors.
  36. Examples:
  37. >>> a = F.ones(10)
  38. >>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
  39. >>> b
  40. Tensor([1.], device=xpux:0)
  41. >>> c
  42. Tensor([[1. 1. 1.]
  43. [1. 1. 1.]
  44. [1. 1. 1.]], device=xpux:0)
  45. """
  46. op = ParamPackSplit()
  47. op.offsets = offsets
  48. op.shapes = [s or (1,) for s in shapes]
  49. outputs = apply(op, inp)
  50. return outputs
  51. def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
  52. r"""Returns concated tensor, only used for ``parampack``.
  53. Args:
  54. inps: list of input tensors.
  55. offsets: device value of offsets.
  56. offsets_val: offsets of inputs, length of ``2 * n``,
  57. format ``[begin0, end0, begin1, end1]``.
  58. Returns:
  59. concated tensor.
  60. Examples:
  61. >>> a = F.ones(1)
  62. >>> b = F.ones((3, 3))
  63. >>> offsets_val = [0, 1, 1, 10]
  64. >>> offsets = Tensor(offsets_val)
  65. >>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val) # doctest: +SKIP
  66. Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)
  67. """
  68. op = ParamPackConcat()
  69. op.offsets = offsets_val
  70. return apply(op, *inps, offsets)[0]
  71. def get_offsets(shapes):
  72. offsets = []
  73. offset = 0
  74. for shape in shapes:
  75. offsets.append(offset)
  76. offset += int(np.prod(shape))
  77. offsets.append(offset)
  78. return offsets
  79. _enable_p2p_cache = None
  80. def _check_enable_p2p():
  81. global _enable_p2p_cache
  82. if _enable_p2p_cache is not None:
  83. return _enable_p2p_cache
  84. cmd = ["nvidia-smi", "topo", "-p2p", "w"]
  85. import subprocess
  86. output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
  87. if output.count(b"OK") > 1:
  88. _enable_p2p_cache = True
  89. return True
  90. else:
  91. _enable_p2p_cache = False
  92. return False
  93. def pack_allreduce_split(pack_list, shapes, group, reduce_method):
  94. offsets_val = get_offsets(shapes)
  95. offsets = Tensor(offsets_val)
  96. packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
  97. packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
  98. if reduce_method == "mean":
  99. packed_grads /= group.size
  100. grads = param_pack_split(packed_grads, offsets_val, shapes)
  101. return grads
  102. class TensorFuture(Future):
  103. def device(self):
  104. raise "Sorry, this tensor is not ready"
  105. def numpy(self):
  106. raise "Sorry, this tensor is not ready"
  107. def shape(self):
  108. raise "Sorry, this tensor is not ready"
  109. def dtype(self):
  110. raise "Sorry, this tensor is not ready"
  111. def synchronized(func: Callable):
  112. r"""Decorator. Decorated function will synchronize when finished.
  113. Specifically, we use this to prevent data race during hub.load
  114. """
  115. @functools.wraps(func)
  116. def wrapper(*args, **kwargs):
  117. if not is_distributed():
  118. return func(*args, **kwargs)
  119. ret = func(*args, **kwargs)
  120. group_barrier()
  121. return ret
  122. return wrapper
  123. def _check_device_initialized(device_type: str, rank: int):
  124. try:
  125. test = Tensor(1, device=(device_type + str(rank)))
  126. inited = False
  127. del test
  128. except:
  129. inited = True
  130. errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
  131. if inited:
  132. raise RuntimeError(errmsg)
  133. get_device_count_by_fork = deprecated_func(
  134. "1.5", "megengine.device", "get_device_count", False
  135. )
  136. def bcast_list_(inps: list, group: Group = WORLD):
  137. r"""Broadcast tensors between given group.
  138. Args:
  139. inps: input tensors.
  140. group: communication group.
  141. """
  142. for inp in inps:
  143. inp._reset(_bcast_param(inp, group))
  144. class AllreduceCallback:
  145. r"""Allreduce Callback with tensor fusion optimization.
  146. Args:
  147. reduce_method: the method to reduce gradiants.
  148. group: communication group.
  149. backend: override distributed backend in allreduce
  150. """
  151. def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
  152. reduce_method = reduce_method.lower()
  153. assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
  154. self._reduce_method = reduce_method
  155. self._group = group
  156. self._marked_gm = WeakSet()
  157. self._param_pack_thd = 10 * 1024 * 1024
  158. self._reset()
  159. if backend is None:
  160. assert _group._sd, "please call init_process_group first"
  161. backend = _group._sd.backend
  162. if backend == "auto":
  163. if group.is_single_machine and not _check_enable_p2p():
  164. backend = "shm"
  165. else:
  166. backend = "nccl"
  167. self._backend = backend
  168. def _reset(self):
  169. self._params = []
  170. self._gradients_dict = dict()
  171. self._futures_dict = dict()
  172. self._packing_list = defaultdict(list)
  173. self._packing_size = defaultdict(int)
  174. self._grad_origin_device = dict()
  175. def _pack(self, dtype):
  176. if len(self._packing_list[dtype]) == 0:
  177. return
  178. grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
  179. shapes = [p._tuple_shape for p in self._packing_list[dtype]]
  180. with override_backend(self._backend):
  181. reduced_grads = pack_allreduce_split(
  182. grad_list, shapes, self._group, self._reduce_method
  183. )
  184. for param, grad in zip(self._packing_list[dtype], reduced_grads):
  185. self._gradients_dict[param] = grad
  186. self._packing_list[dtype] = []
  187. self._packing_size[dtype] = 0
  188. def __call__(self, param, grad):
  189. gm = get_backwarding_grad_manager()
  190. assert isinstance(gm, GradManager)
  191. if gm not in self._marked_gm:
  192. gm._register_after_backward_callback(self._flush)
  193. self._marked_gm.add(gm)
  194. self._params.append(param)
  195. self._futures_dict[param] = TensorFuture(ack=False)
  196. self._gradients_dict[param] = grad
  197. self._grad_origin_device[param] = str(grad.device)
  198. dtype_str = str(np.dtype(param.dtype))
  199. dtype_size = np.dtype(param.dtype).itemsize
  200. self._packing_list[dtype_str].append(param)
  201. self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
  202. if self._packing_size[dtype_str] > self._param_pack_thd:
  203. self._pack(dtype_str)
  204. return self._futures_dict[param]
  205. def _flush(self):
  206. for dtype in sorted(self._packing_list.keys()):
  207. self._pack(dtype)
  208. for param in self._params:
  209. grad = self._gradients_dict[param]
  210. grad = copy(grad, self._grad_origin_device[param])
  211. self._futures_dict[param].set(grad)
  212. self._reset()
  213. make_allreduce_cb = AllreduceCallback