From b5016b9d29650896b57bdccb28c9c19779c6a026 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 7 Sep 2020 19:41:46 +0800
Subject: [PATCH] feat(mge/parampack): add parampack in allreduce callback

GitOrigin-RevId: 73d53eeba16a792a427f57783f382267f3d33d43
---
 .../python/megengine/autodiff/grad_manager.py      | 58 +++++++++++++---
 imperative/python/megengine/distributed/helper.py  | 75 ++++++++++++++++++--
 imperative/python/megengine/distributed/server.py  | 20 +-----
 imperative/python/megengine/distributed/util.py    | 19 ++++++
 imperative/python/megengine/functional/__init__.py |  3 +-
 .../python/megengine/functional/param_pack.py      | 34 ++++++++++
 .../python/megengine/optimizer/param_pack.py       | 79 ----------------------
 7 files changed, 173 insertions(+), 115 deletions(-)
 create mode 100644 imperative/python/megengine/functional/param_pack.py
 delete mode 100644 imperative/python/megengine/optimizer/param_pack.py

diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index 489067b5..07d716c1 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -1,19 +1,39 @@
+from collections import defaultdict
 from contextlib import contextmanager
 
 from ..core.autodiff.grad import Grad
+from ..distributed.util import Future
 from ..tensor import tensor
 
+backwarding_grad_manager = None
+
+
+def get_backwarding_grad_manager():
+    return backwarding_grad_manager
+
 
 class GradManager:
     def __init__(self):
-        self._call_back_pair = []
+        self._call_back_dict = defaultdict(list)
+        self._param_dict = dict()
         self._recording = False
         self._grad = None
+        self._after_backward_callback = []
+        self._gradients = dict()
+
+    def register(self, params, callbacks=[]):
+        for p in params:
+            self._param_dict[id(p)] = p
+            for cb in callbacks:
+                self._call_back_dict[id(p)].append(cb)
 
-    def register(self, params, callbacks=None):
-        self._call_back_pair.append([list(params), callbacks or []])
+    def register_after_backward_callback(self, callback):
+        self._after_backward_callback.append(callback)
 
     def backward(self, ys, dys=None):
+        global backwarding_grad_manager
+        cache = backwarding_grad_manager
+        backwarding_grad_manager = self
         if not self._recording:
             raise RuntimeError(
                 "no computation history. "
@@ -29,8 +49,20 @@ class GradManager:
             dys = [dys]
         try:
             self._grad(ys, dys)
+            for callback in self._after_backward_callback:
+                callback()
+            for p, grad in self._gradients.items():
+                if isinstance(grad, Future):
+                    grad = grad.get()
+                param = self._param_dict[p]
+                if getattr(param, "grad", None) is None:
+                    param.grad = grad
+                else:
+                    param.grad += grad
         finally:
             self._grad = None
+            self._gradients = dict()
+            backwarding_grad_manager = cache
 
     def record(self):
         @contextmanager
@@ -37,24 +69,28 @@ class GradManager:
         def recorder():
             grad = Grad()
             if self._recording:
                 raise RuntimeError("already recording!")
             try:
                 self._recording = True
                 self._grad = grad
-                for params, callbacks in self._call_back_pair:
-                    for p in params:
+                for param_id in self._param_dict.keys():
+                    param_wrapper = self._param_dict[param_id]
+                    callbacks = self._call_back_dict[param_id]
 
-                        def callback(param, grad, callbacks=callbacks, p=p):
-                            ret = grad
-                            for cb in callbacks:
-                                ret = cb(param, ret)
-                            p.grad = ret
+                    def callback(
+                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
+                    ):
+                        ret = grad
+                        for cb in callbacks:
+                            ret = cb(param, ret)
+                        gm._gradients[id(p)] = ret
 
-                        grad.wrt(p, callback=callback)
+                    grad.wrt(param_wrapper, callback=callback)
                 with grad:
                     yield
             finally:
                 self._recording = False
                 self._grad = None
+                self._gradients = dict()
 
         return recorder()
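
For reference, the reworked GradManager is used like this (a minimal single-process sketch, not part of the patch; the parameter, the values, and the explicitly supplied dy are made up for illustration):

    import numpy as np

    from megengine import Parameter, tensor
    from megengine.autodiff.grad_manager import GradManager

    w = Parameter(np.ones(3, dtype="float32"))
    gm = GradManager()
    gm.register([w])  # no per-parameter callbacks: the raw gradient lands in w.grad

    with gm.record():
        y = w * 2.0
        # pass dy explicitly so the sketch needs no reduction op
        gm.backward(y, tensor(np.ones(3, dtype="float32")))

    print(w.grad)  # 2.0 for every element of w

Note that `param.grad` is now written only after `backward` has run every registered after-backward callback and resolved any `Future` gradients, which is what lets the allreduce callback below hand back placeholders.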
" @@ -29,8 +49,20 @@ class GradManager: dys = [dys] try: self._grad(ys, dys) + for callback in self._after_backward_callback: + callback() + for p, grad in self._gradients.items(): + if isinstance(grad, Future): + grad = grad.get() + param = self._param_dict[p] + if getattr(param, "grad", None) is None: + param.grad = grad + else: + param.grad += grad finally: self._grad = None + self._gradients = dict() + backwarding_grad_manager = cache def record(self): @contextmanager @@ -41,20 +73,24 @@ class GradManager: try: self._recording = True self._grad = grad - for params, callbacks in self._call_back_pair: - for p in params: + for param_id in self._param_dict.keys(): + param_wrapper = self._param_dict[param_id] + callbacks = self._call_back_dict[param_id] - def callback(param, grad, callbacks=callbacks, p=p): - ret = grad - for cb in callbacks: - ret = cb(param, ret) - p.grad = ret + def callback( + param, grad, callbacks=callbacks, p=param_wrapper, gm=self + ): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + gm._gradients[id(p)] = ret - grad.wrt(p, callback=callback) + grad.wrt(param_wrapper, callback=callback) with grad: yield finally: self._recording = False self._grad = None + self._gradients = dict() return recorder() diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 8b787b7a..2d3b64c6 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -8,12 +8,33 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import functools import multiprocessing as mp +from collections import defaultdict from typing import Callable -from megengine.device import get_device_count +import numpy as np +from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager +from megengine.device import get_default_device, get_device_count + +from ..functional.param_pack import get_offsets, pack_allreduce_split +from ..functional.utils import copy from .functional import all_reduce_sum, broadcast from .group import WORLD, group_barrier, is_distributed +from .util import Future + + +class FakeTensor(Future): + def device(self): + raise "Sorry, this tensor is not ready" + + def numpy(self): + raise "Sorry, this tensor is not ready" + + def shape(self): + raise "Sorry, this tensor is not ready" + + def dtype(self): + raise "Sorry, this tensor is not ready" def synchronized(func: Callable): @@ -52,14 +73,58 @@ def bcast_params_(params, group): class AllreduceCallback: def __init__(self, reduce_method, group=WORLD): + reduce_method = reduce_method.lower() + assert reduce_method in ["sum", "mean"] self._reduce_method = reduce_method self._group = group + self._gm_set = set() + self._param_pack_thd = 10 * 1024 * 1024 + self._reset() + + def _reset(self): + self._params = [] + self._gradients_dict = dict() + self._futures_dict = dict() + self._packing_list = defaultdict(list) + self._packing_size = defaultdict(int) + + def _pack(self, dtype): + grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]] + shapes = [p.shape for p in self._packing_list[dtype]] + reduced_grads = pack_allreduce_split( + grad_list, shapes, self._group, self._reduce_method + ) + for param, grad in zip(self._packing_list[dtype], reduced_grads): + self._gradients_dict[param] = grad + self._packing_list[dtype] = [] + self._packing_size[dtype] = 0 def __call__(self, param, grad): - ret = all_reduce_sum(grad, self._group) - if 
diff --git a/imperative/python/megengine/distributed/server.py b/imperative/python/megengine/distributed/server.py
index d3d81120..d6e9ba44 100644
--- a/imperative/python/megengine/distributed/server.py
+++ b/imperative/python/megengine/distributed/server.py
@@ -16,25 +16,7 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer
 
 from ..core._imperative_rt.utils import create_mm_server
-from .util import get_free_ports
-
-
-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
+from .util import Future, get_free_ports
 
 
 class Methods:
diff --git a/imperative/python/megengine/distributed/util.py b/imperative/python/megengine/distributed/util.py
index b3a0a2aa..9f5be3fa 100644
--- a/imperative/python/megengine/distributed/util.py
+++ b/imperative/python/megengine/distributed/util.py
@@ -8,9 +8,28 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
+import threading
 from typing import List
 
 
+class Future:
+    def __init__(self, ack=True):
+        self.ready = threading.Event()
+        self.ack = threading.Event() if ack else None
+
+    def set(self, value):
+        self.value = value
+        self.ready.set()
+        if self.ack:
+            self.ack.wait()
+
+    def get(self):
+        self.ready.wait()
+        if self.ack:
+            self.ack.set()
+        return self.value
+
+
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """
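
The relocated `Future` is a two-event handshake: `set()` publishes the value and, when constructed with `ack=True`, blocks until a consumer has taken it via `get()`. A standalone sanity check (not part of the patch):

    import threading

    from megengine.distributed.util import Future

    f = Future(ack=True)

    def producer():
        f.set(42)  # sets `ready`, then waits on `ack`

    t = threading.Thread(target=producer)
    t.start()
    assert f.get() == 42  # waits on `ready`, sets `ack`, releasing the producer
    t.join()

The allreduce callback above constructs `FakeTensor(ack=False)` so that `set()` does not block: there, producer and consumer are the same thread, and an ack would deadlock it.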
diff --git a/imperative/python/megengine/functional/__init__.py b/imperative/python/megengine/functional/__init__.py
index cc999e2a..0a2b037d 100644
--- a/imperative/python/megengine/functional/__init__.py
+++ b/imperative/python/megengine/functional/__init__.py
@@ -7,7 +7,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=redefined-builtin
-from . import distributed
 from .elemwise import *
 from .graph import add_update
 from .loss import (
@@ -27,6 +26,8 @@ from .quantized import conv_bias_activation
 from .tensor import *
 from .utils import accuracy, copy, zero_grad
 
+from . import distributed  # isort:skip
+
 # delete namespace
 # pylint: disable=undefined-variable
 # del elemwise, graph, loss, math, nn, tensor  # type: ignore[name-defined]
diff --git a/imperative/python/megengine/functional/param_pack.py b/imperative/python/megengine/functional/param_pack.py
new file mode 100644
index 00000000..0b25e08b
--- /dev/null
+++ b/imperative/python/megengine/functional/param_pack.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from ..functional.distributed import all_reduce_sum
+from ..tensor import Tensor
+from .tensor import param_pack_concat, param_pack_split
+
+
+def get_offsets(shapes):
+    offsets = []
+    offset = 0
+    for shape in shapes:
+        offsets.append(offset)
+        offset += int(np.prod(shape))
+        offsets.append(offset)
+    return offsets
+
+
+def pack_allreduce_split(pack_list, shapes, group, reduce_method):
+    offsets_val = get_offsets(shapes)
+    offsets = Tensor(offsets_val)
+    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
+    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
+    if reduce_method == "mean":
+        packed_grads /= group.size
+    grads = param_pack_split(packed_grads, offsets_val, shapes)
+    return grads
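
A quick check of the offset arithmetic (not part of the patch): `get_offsets` emits a (begin, end) pair of flattened element offsets for every shape, which is the layout `param_pack_concat` and `param_pack_split` consume.

    from megengine.functional.param_pack import get_offsets

    # tensors of 6, 4 and 25 elements pack into one 35-element buffer,
    # each owning its (begin, end) slice of that buffer
    assert get_offsets([(2, 3), (4,), (5, 5)]) == [0, 6, 6, 10, 10, 35]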
diff --git a/imperative/python/megengine/optimizer/param_pack.py b/imperative/python/megengine/optimizer/param_pack.py
deleted file mode 100644
index ea117aa8..00000000
--- a/imperative/python/megengine/optimizer/param_pack.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-
-from ..functional import param_pack_concat, param_pack_split
-from ..functional.distributed import all_reduce_sum
-from ..tensor import Tensor
-
-
-def get_offsets(shapes):
-    offsets = []
-    offset = 0
-    for shape in shapes:
-        offsets.append(offset)
-        offset += int(np.prod(shape))
-        offsets.append(offset)
-    return offsets
-
-
-def get_pack_list(param_group, param_pack_thd):
-    pack_list = dict()
-    shape_list = dict()
-    pack_sum = dict()
-    pack_ret, shape_ret = [], []
-    ignore_first = 8
-    ignore_last = 0
-    orders_len = len(param_group["orders"])
-    for i, idx in enumerate(param_group["orders"]):
-        param = param_group["params"][idx]
-        dtype = str(np.dtype(param.dtype))
-        dtype_size = np.dtype(param.dtype).itemsize
-        shape = param.shape
-        if ignore_first > 0:
-            ignore_first -= 1
-            pack_ret.append([idx])
-            shape_ret.append([shape])
-            continue
-        if dtype in pack_list.keys():
-            pack_list[dtype].append(idx)
-            shape_list[dtype].append(shape)
-            pack_sum[dtype] += int(np.prod(shape))
-        else:
-            pack_list[dtype] = [idx]
-            shape_list[dtype] = [shape]
-            pack_sum[dtype] = int(np.prod(shape))
-        if (
-            pack_sum[dtype] * dtype_size > param_pack_thd
-            or i + ignore_last > orders_len
-        ):
-            pack_ret.append(pack_list[dtype])
-            shape_ret.append(shape_list[dtype])
-            pack_list[dtype] = []
-            shape_list[dtype] = []
-            pack_sum[dtype] = 0
-    for key in sorted(pack_list.keys()):
-        if len(pack_list[key]) > 0:
-            pack_ret.append(pack_list[key])
-            shape_ret.append(shape_list[key])
-    return pack_ret, shape_ret
-
-
-def pack_allreduce_split(group, pack, shapes, reduce_method):
-    dist_group = group["dist_group"]
-    grads = [group["grads"][idx] for idx in pack]
-    offsets_val = get_offsets(shapes)
-    offsets = Tensor(offsets_val)
-    packed_grads = param_pack_concat(grads, offsets, offsets_val)
-    packed_grads = all_reduce_sum(packed_grads, dist_group, dist_group.comp_node)
-    if reduce_method == "mean":
-        packed_grads /= dist_group.size
-    grads = param_pack_split(packed_grads, offsets_val, shapes)
-    for i, grad in enumerate(grads):
-        group["grads"][pack[i]] = grad