
helper.py 8.0 kB

# -*- coding: utf-8 -*-
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable
from weakref import WeakSet

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager

from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed, override_backend


def param_pack_split(inp: Tensor, offsets: list, shapes: list):
    r"""Splits the input tensor into a list of tensors, as described by
    ``offsets`` and ``shapes``; only used for ``parampack``.

    Args:
        inp: input tensor.
        offsets: offsets of outputs, length of ``2 * n``,
            where ``n`` is the number of tensors you want to split into,
            format ``[begin0, end0, begin1, end1]``.
        shapes: tensor shapes of outputs.

    Returns:
        split tensors.

    Examples:
        >>> a = F.ones(10)
        >>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
        >>> b
        Tensor([1.], device=xpux:0)
        >>> c
        Tensor([[1. 1. 1.]
         [1. 1. 1.]
         [1. 1. 1.]], device=xpux:0)
    """
    op = ParamPackSplit()
    op.offsets = offsets
    # an empty shape denotes a scalar; give it shape (1,) so it occupies one slot
    op.shapes = [s or (1,) for s in shapes]
    outputs = apply(op, inp)
    return outputs


def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
    r"""Returns the concatenated tensor, only used for ``parampack``.

    Args:
        inps: list of input tensors.
        offsets: device value of offsets.
        offsets_val: offsets of inputs, length of ``2 * n``,
            format ``[begin0, end0, begin1, end1]``.

    Returns:
        concatenated tensor.

    Examples:
        >>> a = F.ones(1)
        >>> b = F.ones((3, 3))
        >>> offsets_val = [0, 1, 1, 10]
        >>> offsets = Tensor(offsets_val)
        >>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val)  # doctest: +SKIP
        Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)
    """
    op = ParamPackConcat()
    op.offsets = offsets_val
    return apply(op, *inps, offsets)[0]


def get_offsets(shapes):
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets
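
# Illustration (editor's note, not part of the original module): for
# shapes = [(1,), (3, 3)], get_offsets returns [0, 1, 1, 10] -- the
# [begin0, end0, begin1, end1] layout that param_pack_concat and
# param_pack_split expect, i.e. the begin/end of each tensor inside
# the flattened buffer.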

_enable_p2p_cache = None


def _check_enable_p2p():
    # check (and cache) whether GPU peer-to-peer write access is available,
    # judged by the number of "OK" entries reported by `nvidia-smi topo -p2p w`
    global _enable_p2p_cache
    if _enable_p2p_cache is not None:
        return _enable_p2p_cache
    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
    import subprocess

    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    if output.count(b"OK") > 1:
        _enable_p2p_cache = True
        return True
    else:
        _enable_p2p_cache = False
        return False


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    # concatenate the gradients into one flat buffer, all-reduce it with a
    # single collective, then split it back into per-tensor gradients
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)

    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size

    grads = param_pack_split(packed_grads, offsets_val, shapes)
    return grads


class TensorFuture(Future):
    # placeholder handed back to the GradManager while the fused allreduce is
    # still in flight; using it as a real tensor before it is set is an error
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
    r"""Decorator. The decorated function will synchronize when finished.
    Specifically, we use this to prevent data races during hub.load.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not is_distributed():
            return func(*args, **kwargs)

        ret = func(*args, **kwargs)
        group_barrier()
        return ret

    return wrapper


def _check_device_initialized(device_type: str, rank: int):
    # creating a tensor on the target device fails if the CUDA context was
    # already initialized in the parent process before forking
    try:
        test = Tensor(1, device=(device_type + str(rank)))
        inited = False
        del test
    except:
        inited = True
    errmsg = (
        "The CUDA environment was set up before the forked process started. "
        "Please do not use any CUDA function or variable before forking."
    )
    if inited:
        raise RuntimeError(errmsg)


get_device_count_by_fork = deprecated_func(
    "1.5", "megengine.device", "get_device_count", False
)


def bcast_list_(inps: list, group: Group = WORLD):
    r"""Broadcasts tensors in-place across the given group.

    Args:
        inps: input tensors.
        group: communication group.
    """
    for inp in inps:
        inp._reset(_bcast_param(inp, group))


class AllreduceCallback:
    r"""Allreduce callback with tensor fusion optimization.

    Args:
        reduce_method: the method used to reduce gradients, ``"sum"`` or ``"mean"``.
        group: communication group.
        backend: override the distributed backend used in allreduce.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        self._marked_gm = WeakSet()
        # threshold (in bytes) for flushing a per-dtype bucket of gradients
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

        if backend is None:
            assert _group._sd, "please call init_process_group first"
            backend = _group._sd.backend
        if backend == "auto":
            if group.is_single_machine and not _check_enable_p2p():
                backend = "shm"
            else:
                backend = "nccl"
        self._backend = backend

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)
        self._grad_origin_device = dict()

    def _pack(self, dtype):
        # all-reduce the currently buffered gradients of the given dtype
        if len(self._packing_list[dtype]) == 0:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
        with override_backend(self._backend):
            reduced_grads = pack_allreduce_split(
                grad_list, shapes, self._group, self._reduce_method
            )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._marked_gm:
            gm._register_after_backward_callback(self._flush)
            self._marked_gm.add(gm)
        self._params.append(param)
        self._futures_dict[param] = TensorFuture(ack=False)
        self._gradients_dict[param] = grad
        self._grad_origin_device[param] = str(grad.device)

        # bucket gradients by dtype; flush a bucket once it exceeds the threshold
        dtype_str = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        self._packing_list[dtype_str].append(param)
        self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
        if self._packing_size[dtype_str] > self._param_pack_thd:
            self._pack(dtype_str)
        return self._futures_dict[param]

    def _flush(self):
        # reduce any remaining buckets, then fulfil the futures with the
        # reduced gradients copied back to their original devices
        for dtype in sorted(self._packing_list.keys()):
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, self._grad_origin_device[param])
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
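
For reference, here is a minimal data-parallel training sketch showing where this callback plugs in. It is not part of helper.py: the two-GPU launcher setting, the toy ``Linear`` model and the random data are illustrative assumptions, while the GradManager, launcher and optimizer calls follow MegEngine's public API.

import numpy as np

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager


@dist.launcher(n_gpus=2)
def train():
    model = M.Linear(32, 10)
    # keep the initial parameters identical on every rank (bcast_list_ above)
    dist.bcast_list_(model.parameters())
    gm = GradManager().attach(
        model.parameters(),
        # every gradient produced during backward goes through the fused
        # allreduce callback defined in helper.py
        callbacks=[dist.make_allreduce_cb("mean", dist.WORLD)],
    )
    opt = optim.SGD(model.parameters(), lr=0.01)
    for _ in range(10):
        data = mge.tensor(np.random.randn(8, 32).astype("float32"))
        label = mge.tensor(np.random.randint(0, 10, size=(8,)).astype("int32"))
        with gm:
            loss = F.nn.cross_entropy(model(data), label)
            gm.backward(loss)
        opt.step().clear_grad()


if __name__ == "__main__":
    train()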