GitOrigin-RevId: 673e11c5b6
tags/v1.0.0-rc1
@@ -2,8 +2,8 @@ from collections import defaultdict
 from contextlib import contextmanager

 from ..core.autodiff.grad import Grad
-from ..distributed.util import Future
 from ..tensor import tensor
+from ..utils.future import Future


 backwarding_grad_manager = None
@@ -26,6 +26,7 @@ class GradManager:
             self._param_dict[id(p)] = p
             for cb in callbacks:
                 self._call_back_dict[id(p)].append(cb)
+        return self

     def register_after_backward_callback(self, callback):
         self._after_backward_callback.append(callback)
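
Note on the hunk above: attach() now returns self, so a manager can be constructed and attached in one expression. A minimal usage sketch, assuming GradManager is exported from megengine.autodiff and that record() is the context manager this file builds with the imported contextmanager; net, inputs, labels, loss_fn and allreduce_cb are placeholders:

    from megengine.autodiff import GradManager

    # attach() returning self lets construction and registration be chained.
    gm = GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
    with gm.record():                       # record the forward pass
        loss = loss_fn(net(inputs), labels)
        gm.backward(loss)                   # uses the default dy seed (next hunk)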
@@ -45,7 +46,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1.0) for y in ys]
+            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
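
The default gradient seed for each output changes from a 0-dim tensor(1.0) to ones broadcast to that output's shape. A rough numpy-only illustration of what the new seed looks like (numpy stands in for MegEngine tensors here):

    import numpy as np

    # For an output y of shape (2, 3), the seed used to be the scalar 1.0;
    # it is now an all-ones array with y's shape, analogous to
    # tensor(1.0).broadcast(y.shape).
    y_shape = (2, 3)
    dy_seed = np.broadcast_to(np.float32(1.0), y_shape)
    assert dy_seed.shape == y_shape and float(dy_seed.sum()) == 6.0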
@@ -178,8 +178,6 @@ class Grad:
         assert len(ys) == len(dys)

         ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
-        if len(ids) == 0:
-            return

         ys = [y for i, y in enumerate(ys) if i in ids]
         dys = [dy for i, dy in enumerate(dys) if i in ids]
@@ -18,9 +18,9 @@ from megengine.device import get_default_device, get_device_count
 from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
+from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
 from .group import WORLD, group_barrier, is_distributed
-from .util import Future


 class FakeTensor(Future):
@@ -77,7 +77,7 @@ class AllreduceCallback:
         assert reduce_method in ["sum", "mean"]
         self._reduce_method = reduce_method
         self._group = group
-        self._gm_set = set()
+        self._marked_gm = set()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
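
For scale, the unchanged _param_pack_thd line above is a 10 MiB byte budget per packing bucket; a quick check in plain Python (float32 is assumed purely for illustration):

    # 10 MiB threshold divided by 4 bytes per float32 element:
    thd = 10 * 1024 * 1024
    print(thd // 4)   # 2621440, i.e. ~2.6M float32 gradient elements per pack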
@@ -87,6 +87,7 @@ class AllreduceCallback:
         self._futures_dict = dict()
         self._packing_list = defaultdict(list)
         self._packing_size = defaultdict(int)
+        self._grad_origin_device = dict()

     def _pack(self, dtype):
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
@@ -102,27 +103,28 @@ class AllreduceCallback:
     def __call__(self, param, grad):
         gm = get_backwarding_grad_manager()
         assert isinstance(gm, GradManager)
-        if gm not in self._gm_set:
+        if gm not in self._marked_gm:
             gm.register_after_backward_callback(self._flush)
-            self._gm_set.add(gm)
+            self._marked_gm.add(gm)
         self._params.append(param)
         self._futures_dict[param] = FakeTensor(ack=False)
         self._gradients_dict[param] = grad
-        self._packing_list[param.dtype].append(param)
-        self._packing_size[param.dtype] += (
-            int(np.prod(list(param.shape))) * np.dtype(param.dtype).itemsize
-        )
-        if self._packing_size[param.dtype] > self._param_pack_thd:
-            self._pack(param.dtype)
+        self._grad_origin_device[param] = str(grad.device)
+        dtype_str = str(np.dtype(param.dtype))
+        dtype_size = np.dtype(param.dtype).itemsize
+        self._packing_list[dtype_str].append(param)
+        self._packing_size[dtype_str] += int(np.prod(param.shape)) * dtype_size
+        if self._packing_size[dtype_str] > self._param_pack_thd:
+            self._pack(dtype_str)
         return self._futures_dict[param]

     def _flush(self):
-        for dtype in self._packing_list.keys():
+        for dtype in sorted(self._packing_list.keys()):
             self._pack(dtype)
         for param in self._params:
             grad = self._gradients_dict[param]
-            grad = copy(grad, get_default_device())
+            grad = copy(grad, self._grad_origin_device[param])
             self._futures_dict[param].set(grad)
         self._reset()
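
The reworked hunk above does three things: it records each gradient's original device so _flush copies the reduced result back there instead of to get_default_device(), it keys the packing buckets by the dtype's string name rather than the dtype object, and it walks the buckets in sorted() order during _flush, presumably so every rank issues the packed collectives in the same sequence. A standalone sketch of the per-dtype bucketing bookkeeping (plain Python plus numpy; the helper names are illustrative, not the callback's real methods):

    from collections import defaultdict

    import numpy as np

    PACK_THRESHOLD = 10 * 1024 * 1024           # same 10 MiB budget as _param_pack_thd

    packing_list = defaultdict(list)            # dtype string -> buffered parameter names
    packing_size = defaultdict(int)             # dtype string -> buffered bytes

    def flush_bucket(dtype_str):
        # Stand-in for _pack(): just report what would be packed and reset the bucket.
        print("pack", dtype_str, packing_list.pop(dtype_str), packing_size.pop(dtype_str))

    def add_grad(name, shape, dtype):
        dtype_str = str(np.dtype(dtype))        # stable, hashable key such as "float32"
        nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
        packing_list[dtype_str].append(name)
        packing_size[dtype_str] += nbytes
        if packing_size[dtype_str] > PACK_THRESHOLD:
            flush_bucket(dtype_str)

    add_grad("w1", (1024, 1024), "float32")     # 4 MiB, stays buffered
    add_grad("w2", (2048, 1024), "float32")     # total 12 MiB > 10 MiB, bucket is packed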
@@ -16,7 +16,8 @@ from xmlrpc.client import ServerProxy
 from xmlrpc.server import SimpleXMLRPCServer

 from ..core._imperative_rt.utils import create_mm_server
-from .util import Future, get_free_ports
+from ..utils.future import Future
+from .util import get_free_ports


 class Methods:
@@ -8,28 +8,9 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import socket
-import threading
 from typing import List


-class Future:
-    def __init__(self, ack=True):
-        self.ready = threading.Event()
-        self.ack = threading.Event() if ack else None
-
-    def set(self, value):
-        self.value = value
-        self.ready.set()
-        if self.ack:
-            self.ack.wait()
-
-    def get(self):
-        self.ready.wait()
-        if self.ack:
-            self.ack.set()
-        return self.value
-
-
 def get_free_ports(num: int) -> List[int]:
     """Get one or more free ports.
     """
@@ -8,8 +8,8 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

+from ..functional.distributed import all_reduce_sum
 from ..tensor import Tensor
-from .distributed import all_reduce_sum
 from .tensor import param_pack_concat, param_pack_split
@@ -29,6 +29,6 @@ def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
-        packed_grads /= dist_group.size
+        packed_grads /= group.size
     grads = param_pack_split(packed_grads, offsets_val, shapes)
     return grads
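
The dist_group -> group change matches the signature in the hunk header: pack_allreduce_split takes a parameter named group and no dist_group appears in the file's imports, so the old name looks like a leftover. A tiny numpy check of what the "mean" branch computes, namely that the summed gradient divided by the group size equals the per-rank average (a 4-process group is assumed for illustration):

    import numpy as np

    per_rank = [np.full((3,), float(r), dtype="float32") for r in range(4)]
    summed = np.sum(per_rank, axis=0)      # what all_reduce_sum would return
    mean = summed / 4                      # packed_grads /= group.size
    assert np.allclose(mean, np.mean(per_rank, axis=0))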
@@ -0,0 +1,26 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import threading
+
+
+class Future:
+    def __init__(self, ack=True):
+        self.ready = threading.Event()
+        self.ack = threading.Event() if ack else None
+
+    def set(self, value):
+        self.value = value
+        self.ready.set()
+        if self.ack:
+            self.ack.wait()
+
+    def get(self):
+        self.ready.wait()
+        if self.ack:
+            self.ack.set()
+        return self.value
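
The new file above holds the Future moved out of distributed/util.py. A short usage sketch of its handshake semantics, assuming the module lands at megengine.utils.future as the updated imports indicate: with ack=True (the default), set() publishes the value and then blocks until a consumer calls get(), so producer and consumer rendezvous.

    import threading

    from megengine.utils.future import Future   # path implied by the new imports

    fut = Future(ack=True)

    def producer():
        fut.set(42)          # blocks here until get() acknowledges

    t = threading.Thread(target=producer)
    t.start()
    print(fut.get())         # prints 42 and releases the producer
    t.join()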