GitOrigin-RevId: 73d53eeba1
@@ -1,19 +1,39 @@
+from collections import defaultdict
 from contextlib import contextmanager
 from ..core.autodiff.grad import Grad
+from ..distributed.util import Future
 from ..tensor import tensor
+backwarding_grad_manager = None
+def get_backwarding_grad_manager():
+    return backwarding_grad_manager
 class GradManager:
     def __init__(self):
-        self._call_back_pair = []
+        self._call_back_dict = defaultdict(list)
+        self._param_dict = dict()
         self._recording = False
         self._grad = None
+        self._after_backward_callback = []
+        self._gradients = dict()
+    def register(self, params, callbacks=[]):
+        for p in params:
+            self._param_dict[id(p)] = p
+            for cb in callbacks:
+                self._call_back_dict[id(p)].append(cb)
-    def register(self, params, callbacks=None):
-        self._call_back_pair.append([list(params), callbacks or []])
+    def register_after_backward_callback(self, callback):
+        self._after_backward_callback.append(callback)
     def backward(self, ys, dys=None):
+        global backwarding_grad_manager
+        cache = backwarding_grad_manager
+        backwarding_grad_manager = self
         if not self._recording:
             raise RuntimeError(
                 "no computation history. "
@@ -29,8 +49,20 @@ class GradManager:
             dys = [dys]
         try:
             self._grad(ys, dys)
+            for callback in self._after_backward_callback:
+                callback()
+            for p, grad in self._gradients.items():
+                if isinstance(grad, Future):
+                    grad = grad.get()
+                param = self._param_dict[p]
+                if getattr(param, "grad", None) is None:
+                    param.grad = grad
+                else:
+                    param.grad += grad
         finally:
             self._grad = None
+            self._gradients = dict()
+            backwarding_grad_manager = cache
     def record(self):
         @contextmanager
@@ -41,20 +73,24 @@ class GradManager:
             try:
                 self._recording = True
                 self._grad = grad
-                for params, callbacks in self._call_back_pair:
-                    for p in params:
+                for param_id in self._param_dict.keys():
+                    param_wrapper = self._param_dict[param_id]
+                    callbacks = self._call_back_dict[param_id]
-                        def callback(param, grad, callbacks=callbacks, p=p):
-                            ret = grad
-                            for cb in callbacks:
-                                ret = cb(param, ret)
-                            p.grad = ret
+                    def callback(
+                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
+                    ):
+                        ret = grad
+                        for cb in callbacks:
+                            ret = cb(param, ret)
+                        gm._gradients[id(p)] = ret
-                        grad.wrt(p, callback=callback)
+                    grad.wrt(param_wrapper, callback=callback)
                 with grad:
                     yield
             finally:
                 self._recording = False
                 self._grad = None
+                self._gradients = dict()
         return recorder()
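
A minimal usage sketch of the reworked GradManager flow above. It assumes megengine's public Parameter/tensor/functional APIs of this release; the parameter, the halving callback, and the loss expression are illustrative, not taken from the diff.

    import numpy as np
    import megengine.functional as F
    from megengine import Parameter, tensor
    from megengine.autodiff.grad_manager import GradManager

    w = Parameter(np.random.randn(3, 4).astype("float32"))

    def halve(param, grad):
        # per-parameter callback: receives (param, grad) and returns the
        # gradient that backward() will store into param.grad
        return grad * 0.5

    gm = GradManager()
    gm.register([w], callbacks=[halve])

    with gm.record():                 # context manager returned by record()
        x = tensor(np.random.randn(2, 3).astype("float32"))
        loss = F.matmul(x, w).sum()
        gm.backward(loss)             # runs callbacks, then fills w.grad

    print(w.grad.shape)               # same shape as w

Note that backward() now also resolves Future-valued gradients and runs the after-backward callbacks, which is what the distributed allreduce hook below relies on.
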
@@ -8,12 +8,33 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+from collections import defaultdict
 from typing import Callable
-from megengine.device import get_device_count
+import numpy as np
+from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager
+from megengine.device import get_default_device, get_device_count
+from ..functional.param_pack import get_offsets, pack_allreduce_split
+from ..functional.utils import copy
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
+from .util import Future
+class FakeTensor(Future):
+    def device(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+    def numpy(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+    def shape(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
+    def dtype(self):
+        raise RuntimeError("Sorry, this tensor is not ready")
 def synchronized(func: Callable):
@@ -52,14 +73,58 @@ def bcast_params_(params, group):
 class AllreduceCallback:
     def __init__(self, reduce_method, group=WORLD):
+        reduce_method = reduce_method.lower()
+        assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
+        self._gm_set = set()
+        self._param_pack_thd = 10 * 1024 * 1024
+        self._reset()
+    def _reset(self):
+        self._params = []
+        self._gradients_dict = dict()
+        self._futures_dict = dict()
+        self._packing_list = defaultdict(list)
+        self._packing_size = defaultdict(int)
+    def _pack(self, dtype):
+        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
+        shapes = [p.shape for p in self._packing_list[dtype]]
+        reduced_grads = pack_allreduce_split(
+            grad_list, shapes, self._group, self._reduce_method
+        )
+        for param, grad in zip(self._packing_list[dtype], reduced_grads):
+            self._gradients_dict[param] = grad
+        self._packing_list[dtype] = []
+        self._packing_size[dtype] = 0
     def __call__(self, param, grad):
-        ret = all_reduce_sum(grad, self._group)
-        if self._reduce_method == "MEAN":
-            ret = ret / self._group.size
-        return ret
+        gm = get_backwarding_grad_manager()
+        assert isinstance(gm, GradManager)
+        if gm not in self._gm_set:
+            gm.register_after_backward_callback(self._flush)
+            self._gm_set.add(gm)
+        self._params.append(param)
+        self._futures_dict[param] = FakeTensor(ack=False)
+        self._gradients_dict[param] = grad
+        self._packing_list[param.dtype].append(param)
+        self._packing_size[param.dtype] += (
+            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
+        )
+        if self._packing_size[param.dtype] > self._param_pack_thd:
+            self._pack(param.dtype)
+        return self._futures_dict[param]
+    def _flush(self):
+        for dtype in self._packing_list.keys():
+            self._pack(dtype)
+        for param in self._params:
+            grad = self._gradients_dict[param]
+            grad = copy(grad, get_default_device())
+            self._futures_dict[param].set(grad)
+        self._reset()
 make_allreduce_cb = AllreduceCallback
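
For context, a sketch of how the new callback is meant to be wired into a data-parallel GradManager (the helper name make_dp_grad_manager is hypothetical; it assumes the process group was already initialised by the usual launcher and that params is an iterable of Parameters):

    from megengine.autodiff.grad_manager import GradManager
    from megengine.distributed.group import WORLD
    from megengine.distributed.helper import make_allreduce_cb

    def make_dp_grad_manager(params):
        # one shared AllreduceCallback: gradients are grouped per dtype and
        # packed/allreduced once roughly 10 MiB accumulates (_param_pack_thd)
        cb = make_allreduce_cb("mean", WORLD)
        gm = GradManager()
        gm.register(list(params), callbacks=[cb])
        return gm

The callback hands back a FakeTensor immediately; the real reduced gradient is only filled in by _flush (registered via register_after_backward_callback), which is why backward() resolves Future instances before writing param.grad.
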
@@ -16,25 +16,7 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer
 from ..core._imperative_rt.utils import create_mm_server
-from .util import get_free_ports
-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
+from .util import Future, get_free_ports
 class Methods:
@@ -8,9 +8,28 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
+import threading
 from typing import List
+class Future:
+    def __init__(self, ack=True):
+        self.ready = threading.Event()
+        self.ack = threading.Event() if ack else None
+    def set(self, value):
+        self.value = value
+        self.ready.set()
+        if self.ack:
+            self.ack.wait()
+    def get(self):
+        self.ready.wait()
+        if self.ack:
+            self.ack.set()
+        return self.value
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """
@@ -7,7 +7,6 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # pylint: disable=redefined-builtin
-from . import distributed
 from .elemwise import *
 from .graph import add_update
 from .loss import (
@@ -27,6 +26,8 @@ from .quantized import conv_bias_activation
 from .tensor import *
 from .utils import accuracy, copy, zero_grad
+from . import distributed  # isort:skip
 # delete namespace
 # pylint: disable=undefined-variable
 # del elemwise, graph, loss, math, nn, tensor # type: ignore[name-defined]
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+from ..functional.distributed import all_reduce_sum
+from ..tensor import Tensor
+from .tensor import param_pack_concat, param_pack_split
+def get_offsets(shapes):
+    offsets = []
+    offset = 0
+    for shape in shapes:
+        offsets.append(offset)
+        offset += int(np.prod(shape))
+        offsets.append(offset)
+    return offsets
+def pack_allreduce_split(pack_list, shapes, group, reduce_method):
+    offsets_val = get_offsets(shapes)
+    offsets = Tensor(offsets_val)
+    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
+    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
+    if reduce_method == "mean":
+        packed_grads /= group.size
+    grads = param_pack_split(packed_grads, offsets_val, shapes)
+    return grads
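
To make the offset layout concrete, a quick check with illustrative shapes (the import path mirrors the ..functional.param_pack import used in helper.py above):

    from megengine.functional.param_pack import get_offsets

    # each parameter contributes a (start, end) pair of flattened-element
    # offsets, so consecutive pairs delimit its slice in the packed buffer
    print(get_offsets([(3, 3), (5,), (2, 4)]))   # [0, 9, 9, 14, 14, 22]

param_pack_concat lays the gradients out contiguously using these offsets, and param_pack_split cuts the reduced buffer back into the original shapes.
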
@@ -1,79 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-from ..functional import param_pack_concat, param_pack_split
-from ..functional.distributed import all_reduce_sum
-from ..tensor import Tensor
-def get_offsets(shapes):
-    offsets = []
-    offset = 0
-    for shape in shapes:
-        offsets.append(offset)
-        offset += int(np.prod(shape))
-        offsets.append(offset)
-    return offsets
-def get_pack_list(param_group, param_pack_thd):
-    pack_list = dict()
-    shape_list = dict()
-    pack_sum = dict()
-    pack_ret, shape_ret = [], []
-    ignore_first = 8
-    ignore_last = 0
-    orders_len = len(param_group["orders"])
-    for i, idx in enumerate(param_group["orders"]):
-        param = param_group["params"][idx]
-        dtype = str(np.dtype(param.dtype))
-        dtype_size = np.dtype(param.dtype).itemsize
-        shape = param.shape
-        if ignore_first > 0:
-            ignore_first -= 1
-            pack_ret.append([idx])
-            shape_ret.append([shape])
-            continue
-        if dtype in pack_list.keys():
-            pack_list[dtype].append(idx)
-            shape_list[dtype].append(shape)
-            pack_sum[dtype] += int(np.prod(shape))
-        else:
-            pack_list[dtype] = [idx]
-            shape_list[dtype] = [shape]
-            pack_sum[dtype] = int(np.prod(shape))
-        if (
-            pack_sum[dtype] * dtype_size > param_pack_thd
-            or i + ignore_last > orders_len
-        ):
-            pack_ret.append(pack_list[dtype])
-            shape_ret.append(shape_list[dtype])
-            pack_list[dtype] = []
-            shape_list[dtype] = []
-            pack_sum[dtype] = 0
-    for key in sorted(pack_list.keys()):
-        if len(pack_list[key]) > 0:
-            pack_ret.append(pack_list[key])
-            shape_ret.append(shape_list[key])
-    return pack_ret, shape_ret
-def pack_allreduce_split(group, pack, shapes, reduce_method):
-    dist_group = group["dist_group"]
-    grads = [group["grads"][idx] for idx in pack]
-    offsets_val = get_offsets(shapes)
-    offsets = Tensor(offsets_val)
-    packed_grads = param_pack_concat(grads, offsets, offsets_val)
-    packed_grads = all_reduce_sum(packed_grads, dist_group, dist_group.comp_node)
-    if reduce_method == "mean":
-        packed_grads /= dist_group.size
-    grads = param_pack_split(packed_grads, offsets_val, shapes)
-    for i, grad in enumerate(grads):
-        group["grads"][pack[i]] = grad