GitOrigin-RevId: 73d53eeba1
@@ -1,19 +1,39 @@
from collections import defaultdict
from contextlib import contextmanager

from ..core.autodiff.grad import Grad
from ..distributed.util import Future
from ..tensor import tensor

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager


class GradManager:
    def __init__(self):
        self._call_back_pair = []
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()

    def register(self, params, callbacks=[]):
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)

    def register(self, params, callbacks=None):
        self._call_back_pair.append([list(params), callbacks or []])

    def register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)

    def backward(self, ys, dys=None):
        global backwarding_grad_manager
        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
@@ -29,8 +49,20 @@ class GradManager:
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if getattr(param, "grad", None) is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self._grad = None
            self._gradients = dict()
            backwarding_grad_manager = cache

    def record(self):
        @contextmanager
@@ -41,20 +73,24 @@ class GradManager:
            try:
                self._recording = True
                self._grad = grad
                for params, callbacks in self._call_back_pair:
                    for p in params:
                for param_id in self._param_dict.keys():
                    param_wrapper = self._param_dict[param_id]
                    callbacks = self._call_back_dict[param_id]

                        def callback(param, grad, callbacks=callbacks, p=p):
                            ret = grad
                            for cb in callbacks:
                                ret = cb(param, ret)
                            p.grad = ret

                    def callback(
                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
                    ):
                        ret = grad
                        for cb in callbacks:
                            ret = cb(param, ret)
                        gm._gradients[id(p)] = ret

                        grad.wrt(p, callback=callback)
                    grad.wrt(param_wrapper, callback=callback)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None
                self._gradients = dict()

        return recorder()
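
A minimal usage sketch (not part of the diff), assuming megengine is installed and that this refactored GradManager is importable from megengine.autodiff.grad_manager; the parameter, input, and loss below are illustrative placeholders.

import numpy as np

import megengine.functional as F
from megengine import Parameter, tensor
from megengine.autodiff.grad_manager import GradManager

w = Parameter(np.array([1.0, 2.0], dtype="float32"))
gm = GradManager()
gm.register([w])                  # no callbacks: gradients land on w.grad

with gm.record():                 # record() yields a context manager
    x = tensor(np.array([3.0, 4.0], dtype="float32"))
    loss = F.sum(w * x)
    gm.backward(loss)             # fills w.grad via the per-parameter callback

print(w.grad)

Note that register() must be called before record(), since the recorder walks the registered parameter dict when attaching wrt callbacks.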
@@ -8,12 +8,33 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import multiprocessing as mp
from collections import defaultdict
from typing import Callable

from megengine.device import get_device_count

import numpy as np

from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
from megengine.device import get_default_device, get_device_count

from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed
from .util import Future


class FakeTensor(Future):
    def device(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def numpy(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def shape(self):
        raise RuntimeError("Sorry, this tensor is not ready")

    def dtype(self):
        raise RuntimeError("Sorry, this tensor is not ready")


def synchronized(func: Callable):
@@ -52,14 +73,58 @@ def bcast_params_(params, group):
class AllreduceCallback:
    def __init__(self, reduce_method, group=WORLD):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"]
        self._reduce_method = reduce_method
        self._group = group
        self._gm_set = set()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()

    def _reset(self):
        self._params = []
        self._gradients_dict = dict()
        self._futures_dict = dict()
        self._packing_list = defaultdict(list)
        self._packing_size = defaultdict(int)

    def _pack(self, dtype):
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p.shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []
        self._packing_size[dtype] = 0

    def __call__(self, param, grad):
        ret = all_reduce_sum(grad, self._group)
        if self._reduce_method == "mean":
            ret = ret / self._group.size
        return ret
        gm = get_backwarding_grad_manager()
        assert isinstance(gm, GradManager)
        if gm not in self._gm_set:
            gm.register_after_backward_callback(self._flush)
            self._gm_set.add(gm)
        self._params.append(param)
        self._futures_dict[param] = FakeTensor(ack=False)
        self._gradients_dict[param] = grad
        self._packing_list[param.dtype].append(param)
        self._packing_size[param.dtype] += (
            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
        )
        if self._packing_size[param.dtype] > self._param_pack_thd:
            self._pack(param.dtype)
        return self._futures_dict[param]

    def _flush(self):
        for dtype in self._packing_list.keys():
            self._pack(dtype)
        for param in self._params:
            grad = self._gradients_dict[param]
            grad = copy(grad, get_default_device())
            self._futures_dict[param].set(grad)
        self._reset()


make_allreduce_cb = AllreduceCallback
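
A hedged wiring sketch (not part of the diff): the callback is registered per parameter, returns a FakeTensor future immediately, and _flush(), registered as an after-backward callback, fills in the all-reduced gradient once the per-dtype packs have been reduced. It assumes a process group is already initialized; `params` and `compute_loss` are hypothetical placeholders.

from megengine.autodiff.grad_manager import GradManager
from megengine.distributed.helper import make_allreduce_cb  # path per this diff

gm = GradManager()
gm.register(params, callbacks=[make_allreduce_cb("mean")])  # params: your Parameter list

with gm.record():
    loss = compute_loss()   # hypothetical forward pass
    gm.backward(loss)       # grads are packed per dtype, all-reduced, then copied back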
@@ -16,25 +16,7 @@ from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer

from ..core._imperative_rt.utils import create_mm_server
from .util import get_free_ports


class Future:
    def __init__(self, ack=True):
        self.ready = threading.Event()
        self.ack = threading.Event() if ack else None

    def set(self, value):
        self.value = value
        self.ready.set()
        if self.ack:
            self.ack.wait()

    def get(self):
        self.ready.wait()
        if self.ack:
            self.ack.set()
        return self.value


from .util import Future, get_free_ports


class Methods:
@@ -8,9 +8,28 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import socket
import threading
from typing import List


class Future:
    def __init__(self, ack=True):
        self.ready = threading.Event()
        self.ack = threading.Event() if ack else None

    def set(self, value):
        self.value = value
        self.ready.set()
        if self.ack:
            self.ack.wait()

    def get(self):
        self.ready.wait()
        if self.ack:
            self.ack.set()
        return self.value


def get_free_ports(num: int) -> List[int]:
    """Get one or more free ports.
    """
@@ -7,7 +7,6 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from . import distributed
from .elemwise import *
from .graph import add_update
from .loss import (
@@ -27,6 +26,8 @@ from .quantized import conv_bias_activation
from .tensor import *
from .utils import accuracy, copy, zero_grad

from . import distributed  # isort:skip

# delete namespace
# pylint: disable=undefined-variable
# del elemwise, graph, loss, math, nn, tensor  # type: ignore[name-defined]
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional.distributed import all_reduce_sum
from ..tensor import Tensor
from .tensor import param_pack_concat, param_pack_split


def get_offsets(shapes):
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets


def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size
    grads = param_pack_split(packed_grads, offsets_val, shapes)
    return grads
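
A hedged illustration (not part of the diff): get_offsets emits a (begin, end) pair per shape, which is the flat layout that param_pack_concat / param_pack_split consume. The import path follows this diff.

from megengine.functional.param_pack import get_offsets

print(get_offsets([(2, 3), (4,)]))
# [0, 6, 6, 10] -> tensor 0 spans elements [0, 6), tensor 1 spans [6, 10)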
@@ -1,79 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from ..functional import param_pack_concat, param_pack_split
from ..functional.distributed import all_reduce_sum
from ..tensor import Tensor


def get_offsets(shapes):
    offsets = []
    offset = 0
    for shape in shapes:
        offsets.append(offset)
        offset += int(np.prod(shape))
        offsets.append(offset)
    return offsets


def get_pack_list(param_group, param_pack_thd):
    pack_list = dict()
    shape_list = dict()
    pack_sum = dict()
    pack_ret, shape_ret = [], []
    ignore_first = 8
    ignore_last = 0
    orders_len = len(param_group["orders"])
    for i, idx in enumerate(param_group["orders"]):
        param = param_group["params"][idx]
        dtype = str(np.dtype(param.dtype))
        dtype_size = np.dtype(param.dtype).itemsize
        shape = param.shape
        if ignore_first > 0:
            ignore_first -= 1
            pack_ret.append([idx])
            shape_ret.append([shape])
            continue
        if dtype in pack_list.keys():
            pack_list[dtype].append(idx)
            shape_list[dtype].append(shape)
            pack_sum[dtype] += int(np.prod(shape))
        else:
            pack_list[dtype] = [idx]
            shape_list[dtype] = [shape]
            pack_sum[dtype] = int(np.prod(shape))
        if (
            pack_sum[dtype] * dtype_size > param_pack_thd
            or i + ignore_last > orders_len
        ):
            pack_ret.append(pack_list[dtype])
            shape_ret.append(shape_list[dtype])
            pack_list[dtype] = []
            shape_list[dtype] = []
            pack_sum[dtype] = 0
    for key in sorted(pack_list.keys()):
        if len(pack_list[key]) > 0:
            pack_ret.append(pack_list[key])
            shape_ret.append(shape_list[key])
    return pack_ret, shape_ret


def pack_allreduce_split(group, pack, shapes, reduce_method):
    dist_group = group["dist_group"]
    grads = [group["grads"][idx] for idx in pack]
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(grads, offsets, offsets_val)
    packed_grads = all_reduce_sum(packed_grads, dist_group, dist_group.comp_node)
    if reduce_method == "mean":
        packed_grads /= dist_group.size
    grads = param_pack_split(packed_grads, offsets_val, shapes)
    for i, grad in enumerate(grads):
        group["grads"][pack[i]] = grad