GitOrigin-RevId: a890c206a5
release-1.1
@@ -127,7 +127,7 @@ class GradManager:
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
    def backward(self, ys=None, dys=None):
        r"""
        Performs back-propagation and computes gradients.
@@ -146,6 +146,8 @@ class GradManager:
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if ys is None:
            ys = []
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
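With ys now optional, a pipeline stage whose only output has already been handed off through remote_send can call gm.backward() with no arguments and let the gradient arrive via the send's grad rule (see the test below). A minimal single-process sketch of the relaxed signature; the parameter w and the constants are illustrative, not taken from the library:

```python
import megengine as mge
from megengine.autodiff import GradManager

w = mge.Parameter([2.0], dtype="float32")

# Usual form: pass the output(s) to differentiate.
gm = GradManager().attach([w])
with gm:
    y = (w * 3.0).sum()
    gm.backward(y)      # ys given explicitly
print(w.grad)           # gradient of sum(w * 3) w.r.t. w is 3

# New form: no ys at all. Locally this records nothing, but on an
# intermediate pipeline stage the RemoteSend grad rule supplies the gradient.
gm2 = GradManager().attach([w])
with gm2:
    _ = w * 3.0
gm2.backward()          # now valid; ys defaults to []
```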
@@ -14,6 +14,8 @@ import weakref
import numpy as np

import megengine as mge

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)
            # use Tensor as default wrapper
            return mge.Tensor

        Wrapper = check_wrapper()
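This fallback matters precisely when backward() runs with empty ys/dys (the intermediate pipeline stages in the test below): there is then no tensor to infer a wrapper type from. A small illustration of the lookup pattern, not library code; pick_wrapper is a hypothetical stand-in for the nested check_wrapper above:

```python
import megengine as mge

def pick_wrapper(candidates):
    # Scan candidate tensors for a usable wrapper type ...
    for c in candidates:
        if isinstance(c, mge.Tensor):  # stand-in for the TensorWrapperBase check
            return type(c)
    # ... and fall back to the plain Tensor wrapper when nothing is found,
    # e.g. when backward() received no ys/dys at all.
    return mge.Tensor

assert pick_wrapper([]) is mge.Tensor
assert pick_wrapper([mge.tensor([1.0])]) is mge.Tensor
```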
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
                op.rank_to,
                inputs[0].shape,
                inputs[0].dtype,
                device=str(inputs[0].device),
                inp=inputs[0],
            )
        ]
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
    src_rank: int,
    shape: Tuple[int],
    dtype: type,
    device: Optional[str] = None,
    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
    :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor.
    :param inp: dummy input used to determine the type of the received tensor.
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    # dummpy input
    inp = tensor([0])
    # dummy input
    if inp is None:
        inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
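A hedged skeleton of how a receiving stage might call the extended signature; it is not executed here and assumes a launched process group, as in the test below. The helper name recv_activation and the shape are illustrative, not part of the library:

```python
from typing import Optional

import numpy as np

import megengine.distributed as dist
from megengine import Tensor


def recv_activation(src_rank: int, inp: Optional[Tensor] = None) -> Tensor:
    # inp, when given (e.g. the forward input inside a grad rule), decides the
    # wrapper type of the received tensor; None keeps the old behaviour of
    # building a dummy tensor([0]) internally.
    return dist.functional.remote_recv(
        src_rank,
        shape=(1, 4),    # illustrative shape
        dtype=np.float32,
        device=None,     # None falls back to the default device
        inp=inp,
    )
```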
@@ -5,12 +5,19 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync
from megengine.distributed.helper import get_device_count_by_fork


def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
        c = b + 1
        gm.backward(c)
    assert int(b.grad.numpy()) == 1


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="GPU mode is not supported on macOS for now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="Windows builds disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need at least 2 GPU devices")
@pytest.mark.isolated_distributed
def test_remote_grad():
    @dist.launcher
    def worker():
        rank = dist.get_rank()
        size = dist.get_world_size()
        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
        gm = GradManager().attach(m.parameters())
        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

        def train_func(x):
            if rank != 0:
                x = dist.functional.remote_recv(
                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
                )
            print(rank, "x", x)
            y = m(x)
            print(rank, "y", y)
            if rank != size - 1:
                y = dist.functional.remote_send(y, dest_rank=rank + 1)
            return y

        with gm:
            y = train_func(x)
            if rank == size - 1:
                y = y.mean()
                gm.backward(y)
            else:
                gm.backward()
        opt.step().clear_grad()
        # sync because send is the last job
        sync()

    worker()