You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper.py 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import multiprocessing as mp
  11. from collections import defaultdict
  12. from typing import Callable
  13. from weakref import WeakSet
  14. import numpy as np
  15. from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
  16. from megengine.device import get_default_device, get_device_count
  17. from ..core._imperative_rt.core2 import apply
  18. from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
  19. from ..functional.tensor import copy
  20. from ..tensor import Tensor
  21. from ..utils.future import Future
  22. from .functional import _bcast_param, all_reduce_sum, broadcast
  23. from .group import WORLD, Group, group_barrier, is_distributed
  24. def param_pack_split(inp: Tensor, offsets: list, shapes: list):
  25. r"""
  26. Returns split tensor to tensor list as offsets and shapes described,
  27. only used for ``parampack``.
  28. :param inp: input tensor.
  29. :param offsets: offsets of outputs, length of `2 * n`,
  30. while n is tensor nums you want to split,
  31. format `[begin0, end0, begin1, end1]`.
  32. :param shapes: tensor shapes of outputs.
  33. :return: splitted tensors.
  34. Examples:
  35. .. testcode::
  36. import numpy as np
  37. from megengine import tensor
  38. from megengine.distributed.helper import param_pack_split
  39. a = tensor(np.ones((10,), np.int32))
  40. b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
  41. print(b.numpy())
  42. print(c.numpy())
  43. Outputs:
  44. .. testoutput::
  45. [1]
  46. [[1 1 1]
  47. [1 1 1]
  48. [1 1 1]]
  49. """
  50. op = ParamPackSplit()
  51. op.offsets = offsets
  52. op.shapes = [s or (1,) for s in shapes]
  53. outputs = apply(op, inp)
  54. for s, x in zip(shapes, outputs):
  55. if not s:
  56. x._setscalar()
  57. return outputs
  58. def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
  59. r"""
  60. Returns concated tensor, only used for ``parampack``.
  61. :param inps: input tensors.
  62. :param offsets: device value of offsets.
  63. :param offsets_val: offsets of inputs, length of `2 * n`,
  64. format `[begin0, end0, begin1, end1]`.
  65. :return: concated tensor.
  66. Examples:
  67. .. testcode::
  68. import numpy as np
  69. from megengine import tensor
  70. from megengine.distributed.helper import param_pack_concat
  71. a = tensor(np.ones((1,), np.int32))
  72. b = tensor(np.ones((3, 3), np.int32))
  73. offsets_val = [0, 1, 1, 10]
  74. offsets = tensor(offsets_val, np.int32)
  75. c = param_pack_concat([a, b], offsets, offsets_val)
  76. print(c.numpy())
  77. Outputs:
  78. .. testoutput::
  79. [1 1 1 1 1 1 1 1 1 1]
  80. """
  81. op = ParamPackConcat()
  82. op.offsets = offsets_val
  83. return apply(op, *inps, offsets)[0]
  84. def get_offsets(shapes):
  85. offsets = []
  86. offset = 0
  87. for shape in shapes:
  88. offsets.append(offset)
  89. offset += int(np.prod(shape))
  90. offsets.append(offset)
  91. return offsets
  92. def pack_allreduce_split(pack_list, shapes, group, reduce_method):
  93. offsets_val = get_offsets(shapes)
  94. offsets = Tensor(offsets_val)
  95. packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
  96. packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
  97. if reduce_method == "mean":
  98. packed_grads /= group.size
  99. grads = param_pack_split(packed_grads, offsets_val, shapes)
  100. return grads
  101. class TensorFuture(Future):
  102. def device(self):
  103. raise "Sorry, this tensor is not ready"
  104. def numpy(self):
  105. raise "Sorry, this tensor is not ready"
  106. def shape(self):
  107. raise "Sorry, this tensor is not ready"
  108. def dtype(self):
  109. raise "Sorry, this tensor is not ready"
  110. def synchronized(func: Callable):
  111. """
  112. Decorator. Decorated function will synchronize when finished.
  113. Specifically, we use this to prevent data race during hub.load"""
  114. @functools.wraps(func)
  115. def wrapper(*args, **kwargs):
  116. if not is_distributed():
  117. return func(*args, **kwargs)
  118. ret = func(*args, **kwargs)
  119. group_barrier()
  120. return ret
  121. return wrapper
  122. def _get_device_count_worker(queue, device_type):
  123. num = get_device_count(device_type)
  124. queue.put(num)
  125. def get_device_count_by_fork(device_type: str):
  126. """
  127. Get device count in fork thread.
  128. See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
  129. for more information.
  130. """
  131. q = mp.Queue()
  132. p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
  133. p.start()
  134. p.join()
  135. return q.get()
  136. def bcast_list_(inps: list, group: Group = WORLD):
  137. """
  138. Broadcast tensors between given group.
  139. :param inps: input tensors.
  140. :param group: communication group.
  141. """
  142. for inp in inps:
  143. inp._reset(_bcast_param(inp, group))
  144. class AllreduceCallback:
  145. """
  146. Allreduce Callback with tensor fusion optimization.
  147. :param reduce_method: the method to reduce gradiants.
  148. :param group: communication group.
  149. """
  150. def __init__(self, reduce_method: str, group: Group = WORLD):
  151. reduce_method = reduce_method.lower()
  152. assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
  153. self._reduce_method = reduce_method
  154. self._group = group
  155. self._marked_gm = WeakSet()
  156. self._param_pack_thd = 10 * 1024 * 1024
  157. self._reset()
  158. def _reset(self):
  159. self._params = []
  160. self._gradients_dict = dict()
  161. self._futures_dict = dict()
  162. self._packing_list = defaultdict(list)
  163. self._packing_size = defaultdict(int)
  164. self._grad_origin_device = dict()
  165. def _pack(self, dtype):
  166. if len(self._packing_list[dtype]) == 0:
  167. return
  168. grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
  169. shapes = [p._tuple_shape for p in self._packing_list[dtype]]
  170. reduced_grads = pack_allreduce_split(
  171. grad_list, shapes, self._group, self._reduce_method
  172. )
  173. for param, grad in zip(self._packing_list[dtype], reduced_grads):
  174. self._gradients_dict[param] = grad
  175. self._packing_list[dtype] = []
  176. self._packing_size[dtype] = 0
  177. def __call__(self, param, grad):
  178. gm = get_backwarding_grad_manager()
  179. assert isinstance(gm, GradManager)
  180. if gm not in self._marked_gm:
  181. gm._register_after_backward_callback(self._flush)
  182. self._marked_gm.add(gm)
  183. self._params.append(param)
  184. self._futures_dict[param] = TensorFuture(ack=False)
  185. self._gradients_dict[param] = grad
  186. self._grad_origin_device[param] = str(grad.device)
  187. dtype_str = str(np.dtype(param.dtype))
  188. dtype_size = np.dtype(param.dtype).itemsize
  189. self._packing_list[dtype_str].append(param)
  190. self._packing_size[dtype_str] += int(np.prod(param._tuple_shape)) * dtype_size
  191. if self._packing_size[dtype_str] > self._param_pack_thd:
  192. self._pack(dtype_str)
  193. return self._futures_dict[param]
  194. def _flush(self):
  195. for dtype in sorted(self._packing_list.keys()):
  196. self._pack(dtype)
  197. for param in self._params:
  198. grad = self._gradients_dict[param]
  199. grad = copy(grad, self._grad_origin_device[param])
  200. self._futures_dict[param].set(grad)
  201. self._reset()
# Alias of AllreduceCallback under a factory-style name; presumably kept so
# existing callers of ``make_allreduce_cb`` keep working.
make_allreduce_cb = AllreduceCallback

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。