GitOrigin-RevId: a890c206a5
release-1.1
@@ -127,7 +127,7 @@ class GradManager:
         self._after_backward_callback.append(callback)
         return self

-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
         r"""
         Performs back-propagation and computes gradients.

@@ -146,6 +146,8 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        if ys is None:
+            ys = []
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
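
Making `ys` optional matters for pipeline-style workers: a rank whose output has already been handed to the next rank via `remote_send` has no local loss to pass, yet it still has to run the backward pass so that the gradient arriving from downstream reaches its parameters. A minimal sketch of the intended call pattern, assuming the same `m`, `x`, `rank`, and `opt` as in the test added at the bottom of this diff:

    # hypothetical intermediate pipeline stage (neither first nor last rank)
    gm = GradManager().attach(m.parameters())
    with gm:
        y = m(x)
        # the activation leaves this process; its gradient will come back remotely
        dist.functional.remote_send(y, dest_rank=rank + 1)
        # no local loss to pass: ys defaults to [] and the remote tracer drives the pass
        gm.backward()
    opt.step().clear_grad()
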
@@ -14,6 +14,8 @@ import weakref

 import numpy as np

+import megengine as mge
+
 from ..ops.builtin import Elemwise, OpDef
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
             for i in dys:
                 if isinstance(i, TensorWrapperBase):
                     return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor

         Wrapper = check_wrapper()
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
     def backward(*args):
         return [
             remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
             )
         ]
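
The backward rule of a send is a receive: the destination rank computes the gradient of the forwarded activation and ships it back, so `remote_recv` is invoked here with the same shape, dtype, and device as the tensor that was originally sent. Passing `inp=inputs[0]` additionally hands `remote_recv` a reference tensor, so the incoming gradient is materialized with the same wrapper type as the forwarded tensor instead of being derived from a fresh dummy.
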
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
     :param shape: the shape of the tensor to receive.
     :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
+    :param inp: dummy input to determine the type of the received tensor.
     """
     key = "{}->{}".format(src_rank, get_rank())
     if device is None:
         device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp is None:
+        inp = tensor([0])
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
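
With the new parameter, a caller that already holds a tensor of the appropriate type can pass it as the reference, and `remote_recv` only falls back to the internal `tensor([0])` dummy when nothing is supplied. A hedged sketch of the two call styles (the ranks, `y`, and the shape are placeholders; a real run needs the usual multi-process setup):

    # old/default style: the internal tensor([0]) dummy determines the received tensor's type
    x = remote_recv(src_rank, (1, 4), np.float32, device="gpu0")

    # new style: reuse an existing tensor (e.g. the one that was sent) as the reference,
    # so the received gradient comes back with a matching wrapper type
    gx = remote_recv(dst_rank, y.shape, y.dtype, device=str(y.device), inp=y)
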
@@ -5,12 +5,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+
 import numpy as np
 import pytest

 import megengine as mge
+import megengine.distributed as dist
 import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
 from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork


 def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
         c = b + 1
     gm.backward(c)
     assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+            opt.step().clear_grad()
+        # sync because send is the last job
+        sync()
+
+    worker()