
helper.py 8.2 kB

# -*- coding: utf-8 -*-
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager

from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed, override_backend


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""Splits the input tensor into a list of tensors described by ``offsets``
    and ``shapes``; only used for ``parampack``.

    Args:
        inp: input tensor.
        offsets: offsets of outputs, length of ``2 * n``,
            where ``n`` is the number of tensors to split into,
            format ``[begin0, end0, begin1, end1]``.
        shapes: tensor shapes of outputs.

    Returns:
        split tensors.

    Examples:
        >>> a = F.ones(10)
        >>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        >>> b
        Tensor([1.], device=xpux:0)
        >>> c
        Tensor([[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]], device=xpux:0)
    """
    op = ParamPackSplit()
    op.offsets = offsets
    # Scalar (empty) shapes are stored as (1,) so every output has a valid shape.
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""Returns the concatenated tensor; only used for ``parampack``.

    Args:
        inps: list of input tensors.
        offsets: offsets as a tensor on the device.
        offsets_val: offsets of inputs, length of ``2 * n``,
            format ``[begin0, end0, begin1, end1]``.

    Returns:
        concatenated tensor.

    Examples:
        >>> a = F.ones(1)
        >>> b = F.ones((3, 3))
        >>> offsets_val = [0, 1, 1, 10]
        >>> offsets = Tensor(offsets_val)
        >>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val)  # doctest: +SKIP
        Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)
    """
    op = ParamPackConcat()
    op.offsets = offsets_val
    return apply(op, *inps, offsets)[0]


def get_offsets(shapes):
    # Build the flat ``[begin0, end0, begin1, end1, ...]`` offsets (counted in
    # number of elements) for tensors of the given shapes packed back to back.
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets
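
# Illustrative example (not part of the original module): for the shapes used in
# the docstrings above, get_offsets([(1,), (3, 3)]) returns [0, 1, 1, 10], the
# same ``[begin0, end0, begin1, end1]`` layout consumed by param_pack_split and
# param_pack_concat.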


_enable_p2p_cache = None


def _check_enable_p2p():
    # Ask ``nvidia-smi topo -p2p w`` whether GPU peer-to-peer (write) access is
    # available, and cache the answer for subsequent calls.
    global _enable_p2p_cache
    if _enable_p2p_cache is not None:
        return _enable_p2p_cache
    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
    import subprocess

    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    if output.count(b"OK") > 1:
        _enable_p2p_cache = True
        return True
    else:
        _enable_p2p_cache = False
        return False


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    # Fuse the gradients in ``pack_list`` into one flat tensor, allreduce it in a
    # single collective call, then split the result back to the original shapes.
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)

    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size

    grads = param_pack_split(packed_grads, offsets_val, shapes)
    return grads
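
# Usage sketch (assumes an initialized process group; ``grads`` is a hypothetical
# list of same-dtype gradient tensors):
#
#     shapes = [g._tuple_shape for g in grads]
#     fused = pack_allreduce_split(grads, shapes, WORLD, "mean")
#
# ``fused[i]`` is the allreduced (and, for "mean", size-averaged) gradient
# corresponding to ``grads[i]``, produced by one fused all_reduce_sum call.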


class TensorFuture(Future):
    # Placeholder returned by AllreduceCallback for each gradient; the real tensor
    # is filled in via ``set()`` once the fused allreduce has run, so tensor-like
    # accessors are invalid until then.
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    r"""Decorator. The decorated function synchronizes all processes in the group
    (via ``group_barrier``) when it finishes.

    Specifically, this is used to prevent data races during ``hub.load``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper
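
# Usage sketch (``prepare_dataset`` is a hypothetical function): in a distributed
# run, every process executes the body and then waits on group_barrier(), so no
# rank can continue before all ranks have finished the call.
#
#     @synchronized
#     def prepare_dataset(path):
#         ...  # e.g. download / unpack files shared by all ranks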


def _check_device_initialized(device_type: str, rank: int):
    # Probe the device by creating a tiny tensor on it; if that fails, the CUDA
    # context was already initialized before this process was forked.
    try:
        test = Tensor(1, device=(device_type + str(rank)))
        inited = False
        del test
    except:
        inited = True
    errmsg = "The cuda env is set before the forked thread starts. Please do not use any cuda function or variable before forking."
    if inited:
        raise RuntimeError(errmsg)


def _check_interpreter_status():
    from ..core._imperative_rt.core2 import get_option

    # Read an interpreter option just to confirm the imperative runtime is alive.
    _ = get_option("async_level")


get_device_count_by_fork = deprecated_func(
    "1.5", "megengine.device", "get_device_count", False
)


def bcast_list_(inps: list, group: Group = WORLD):
    r"""Broadcasts tensors in place within the given group.

    Args:
        inps: input tensors.
        group: communication group.
    """
    for inp in inps:
        inp._reset(_bcast_param(inp, group))
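
# Usage sketch (``net`` is a hypothetical megengine.module.Module): broadcasting a
# module's parameters so that every process starts data-parallel training from
# identical weights; each tensor is reset in place via _bcast_param.
#
#     bcast_list_(list(net.parameters()), WORLD)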


class AllreduceCallback:
    r"""Allreduce callback with tensor fusion optimization.

    Args:
        reduce_method: the method used to reduce gradients, ``"sum"`` or ``"mean"``.
        group: communication group.
        backend: override the distributed backend used for the allreduce.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024  # fuse gradients up to 10 MiB per pack
        self._reset()
        if backend is None:
            assert _group._sd, "please call init_process_group first"
            backend = _group._sd.backend
        if backend == "auto":
            # Prefer shared memory on a single machine without GPU peer-to-peer access.
            if group.is_single_machine and not _check_enable_p2p():
                backend = "shm"
            else:
                backend = "nccl"
        self._backend = backend

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        if len(self._packing_list[dtype]) == 0:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
        with override_backend(self._backend):
            reduced_grads = pack_allreduce_split(
                grad_list, shapes, self._group, self._reduce_method
            )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        # Group gradients by dtype and flush a pack as soon as it exceeds the threshold.
        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        # Reduce any remaining packs, then resolve each parameter's future with its
        # gradient copied back to the device it originally came from.
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
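
# Usage sketch (illustrative; ``model``, ``data``, ``label`` and ``loss_fn`` are
# hypothetical): registering the fused-allreduce callback with a GradManager for
# data-parallel training. The callback hands GradManager a TensorFuture per
# parameter and fills in the real, allreduced gradient in _flush().
#
#     gm = GradManager()
#     gm.attach(model.parameters(), callbacks=[make_allreduce_cb("mean", WORLD)])
#     with gm:
#         loss = loss_fn(model(data), label)
#         gm.backward(loss)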