Browse Source

feat(mge/distributed): allow remote grad by using grad manager

GitOrigin-RevId: a890c206a5
release-1.1
Megvii Engine Team 4 years ago
parent
commit
094601e834
4 changed files with 72 additions and 5 deletions
  1. +3
    -1
      imperative/python/megengine/autodiff/grad_manager.py
  2. +4
    -0
      imperative/python/megengine/core/autodiff/grad.py
  3. +14
    -4
      imperative/python/megengine/distributed/functional.py
  4. +51
    -0
      imperative/python/test/unit/autodiff/test_grad_manger.py

+ 3
- 1
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -127,7 +127,7 @@ class GradManager:
self._after_backward_callback.append(callback)
return self

def backward(self, ys, dys=None):
def backward(self, ys=None, dys=None):
r"""
Performs back-propagation and computes gradients.

@@ -146,6 +146,8 @@ class GradManager:
"call a method that clears the history?"
)
assert self._grad is not None
if ys is None:
ys = []
if not isinstance(ys, (tuple, list)):
ys = [ys]
if dys is None:


+ 4
- 0
imperative/python/megengine/core/autodiff/grad.py View File

@@ -14,6 +14,8 @@ import weakref

import numpy as np

import megengine as mge

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
for i in dys:
if isinstance(i, TensorWrapperBase):
return type(i)
# use Tensor as default wrapper
return mge.Tensor

Wrapper = check_wrapper()



+ 14
- 4
imperative/python/megengine/distributed/functional.py View File

@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
def backward(*args):
return [
remote_recv(
op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
op.rank_to,
inputs[0].shape,
inputs[0].dtype,
device=str(inputs[0].device),
inp=inputs[0],
)
]

@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:


def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
src_rank: int,
shape: Tuple[int],
dtype: type,
device: Optional[str] = None,
inp=None,
) -> Tensor:
"""
Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
:param shape: the shape of the tensor to receive.
:param dtype: the data type of the tensor to receive.
:param device: the device to place the received tensor.
:param inp: dummy input to determine received tensor type
"""
key = "{}->{}".format(src_rank, get_rank())

if device is None:
device = get_default_device()
# dummpy input
inp = tensor([0])
# dummy input
if inp == None:
inp = tensor([0])
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:


+ 51
- 0
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -5,12 +5,19 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync
from megengine.distributed.helper import get_device_count_by_fork


def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
c = b + 1
gm.backward(c)
assert int(b.grad.numpy()) == 1


@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_remote_grad():
    """Chain one ``Linear`` layer per rank through remote send/recv and backprop.

    Every rank builds a ``Linear(rank * 2 + 2, rank * 2 + 4)`` layer. Rank 0
    feeds its own random input; every other rank receives its input from
    ``rank - 1``. Every rank except the last forwards its output to
    ``rank + 1``. The last rank calls ``gm.backward`` on the mean of its
    output, the others call ``gm.backward()`` with no arguments.

    The test makes no assertions; it passes when the worker completes.
    """

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        is_first = rank == 0
        is_last = rank == world_size - 1

        # Layer widths grow with the rank so that the output of rank r
        # has the same shape as the input expected by rank r + 1.
        in_features = rank * 2 + 2
        out_features = rank * 2 + 4

        x = mge.tensor(np.random.randn(1, in_features), dtype=np.float32)
        layer = M.Linear(in_features, out_features)
        gm = GradManager().attach(layer.parameters())
        opt = optim.SGD(layer.parameters(), 1e-3, momentum=0.9)

        def train_func(inp):
            # Only rank 0 consumes the locally generated input.
            if not is_first:
                inp = dist.functional.remote_recv(
                    rank - 1, shape=(1, in_features), dtype=np.float32
                )
            print(rank, "x", inp)
            out = layer(inp)
            print(rank, "y", out)
            if is_last:
                return out
            return dist.functional.remote_send(out, dest_rank=rank + 1)

        # NOTE(review): nesting of step()/sync() under ``with gm`` is assumed
        # -- TODO confirm against the repository copy of this test.
        with gm:
            y = train_func(x)
            if is_last:
                gm.backward(y.mean())
            else:
                # No local loss on this rank: backward is called without ys.
                gm.backward()
            opt.step().clear_grad()
            # sync because send is the last job
            sync()

    worker()

Loading…
Cancel
Save