|
|
@@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count |
|
|
|
|
|
|
|
from ..functional.param_pack import get_offsets, pack_allreduce_split |
|
|
|
from ..functional.utils import copy |
|
|
|
from ..utils.future import Future |
|
|
|
from .functional import all_reduce_sum, broadcast |
|
|
|
from .group import WORLD, group_barrier, is_distributed |
|
|
|
from .util import Future |
|
|
|
|
|
|
|
|
|
|
|
class FakeTensor(Future): |
|
|
@@ -77,7 +77,7 @@ class AllreduceCallback: |
|
|
|
assert reduce_method in ["sum", "mean"] |
|
|
|
self._reduce_method = reduce_method |
|
|
|
self._group = group |
|
|
|
self._gm_set = set() |
|
|
|
self._marked_gm = set() |
|
|
|
self._param_pack_thd = 10 * 1024 * 1024 |
|
|
|
self._reset() |
|
|
|
|
|
|
@@ -87,6 +87,7 @@ class AllreduceCallback: |
|
|
|
self._futures_dict = dict() |
|
|
|
self._packing_list = defaultdict(list) |
|
|
|
self._packing_size = defaultdict(int) |
|
|
|
self._grad_origin_device = dict() |
|
|
|
|
|
|
|
def _pack(self, dtype): |
|
|
|
grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] |
|
|
@@ -102,27 +103,28 @@ class AllreduceCallback: |
|
|
|
def __call__(self, param, grad):
    """Queue ``grad`` of ``param`` for a packed all-reduce and return a future.

    The grad manager that is currently running backward is looked up with
    ``get_backwarding_grad_manager()``; the first time a given manager is
    seen, ``self._flush`` is registered on it as an after-backward callback.

    Args:
        param: the parameter the gradient belongs to. Used as a dictionary
            key and read for its ``dtype`` and ``shape``.
        grad: the gradient computed for ``param``; read for its ``device``.

    Returns:
        The ``FakeTensor`` future stored for ``param``. It is fulfilled
        later, when ``_flush`` calls ``set`` on it.
    """
    gm = get_backwarding_grad_manager()
    assert isinstance(gm, GradManager)
    # Register the flush hook only once per grad manager.
    if gm not in self._marked_gm:
        gm.register_after_backward_callback(self._flush)
        self._marked_gm.add(gm)

    self._params.append(param)
    self._futures_dict[param] = FakeTensor(ack=False)
    self._gradients_dict[param] = grad
    # Remember where the gradient lived so it can be copied back there
    # once the reduction has finished.
    self._grad_origin_device[param] = str(grad.device)

    # Buckets are keyed by the dtype's string name rather than by the dtype
    # object itself, and each parameter is accounted for exactly once.
    dtype_str = str(np.dtype(param.dtype))
    dtype_size = np.dtype(param.dtype).itemsize
    self._packing_list[dtype_str].append(param)
    self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
    # Hand the bucket to _pack as soon as its accumulated byte size
    # exceeds the packing threshold.
    if self._packing_size[dtype_str] > self._param_pack_thd:
        self._pack(dtype_str)
    return self._futures_dict[param]
|
|
|
|
|
|
|
def _flush(self):
    """Pack every remaining bucket, fulfil all pending futures, then reset.

    Registered by ``__call__`` as an after-backward callback on the grad
    manager.
    """
    # Visit the dtype buckets in sorted key order so the packing order does
    # not depend on dictionary insertion order.
    # NOTE(review): presumably this keeps the order identical on every
    # rank -- confirm against the collective-communication requirements.
    for dtype in sorted(self._packing_list.keys()):
        self._pack(dtype)
    for param in self._params:
        grad = self._gradients_dict[param]
        # Copy the gradient to the device recorded for it in __call__
        # (a single copy, not a detour through the default device).
        grad = copy(grad, self._grad_origin_device[param])
        self._futures_dict[param].set(grad)
    self._reset()
|
|
|
|
|
|
|