Browse Source

feat(mge): restore remote send/recv

GitOrigin-RevId: 8b78fd5591
release-1.2
Megvii Engine Team 4 years ago
parent
commit
4d75f691a0
9 changed files with 150 additions and 27 deletions
  1. +1
    -1
      imperative/python/megengine/__init__.py
  2. +29
    -3
      imperative/python/megengine/core/autodiff/grad.py
  3. +64
    -9
      imperative/python/megengine/distributed/functional.py
  4. +40
    -4
      imperative/python/src/grad.cpp
  5. +3
    -0
      imperative/python/src/grad.h
  6. +5
    -2
      imperative/python/src/tensor.cpp
  7. +1
    -0
      imperative/python/test/unit/autodiff/test_grad_manger.py
  8. +4
    -6
      imperative/python/test/unit/core/test_autodiff.py
  9. +3
    -2
      imperative/python/test/unit/functional/test_functional_distributed.py

+ 1
- 1
imperative/python/megengine/__init__.py View File

@@ -71,7 +71,7 @@ if sys.platform == "win32":

kernel32.SetErrorMode(old_error_mode)

from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.core2 import release_trace_apply_func, sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level


+ 29
- 3
imperative/python/megengine/core/autodiff/grad.py View File

@@ -46,9 +46,31 @@ def get_grad_managers():
return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
def __init__(self, name=None):
if name:
self.name = name

def backward(self, ys, dys):
return core2.backward(self, ys, dys)


class Grad:
def __init__(self):
self._impl = core2.GradKey()
def __init__(self, name=None):
global _grad_count
if name is None:
name = "grad_%d" % _grad_count
_grad_count += 1
self._refkeeper = []
self._impl = GradKey(name)
_grad_manager_dict[self._name] = self

@property
def _name(self):
return self._impl.name

def _is_attached_to(self, tensor):
return self._impl.is_attached_to(tensor)

def wrt(self, *tensors, callback=None):
for x in tensors:
@@ -62,12 +84,16 @@ class Grad:
ys = [ys]
if not isinstance(dys, Sequence):
dys = [dys]
core2.backward(self._impl, ys, dys)

self._impl.backward(ys, dys)

self._refkeeper = None

def __enter__(self):
return self

def __exit__(self, _1, _2, _3):
self._refkeeper = None
del self._impl




+ 64
- 9
imperative/python/megengine/distributed/functional.py View File

@@ -9,8 +9,8 @@
from typing import Optional, Tuple

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import get_grad_managers
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
@@ -193,6 +193,48 @@ def all_to_all(
return collective_comm(inp, mode, group, device)


class _RemoteSend(PyOpBase):
def __init__(self, op: RemoteSend):
self.op = op

def _default_rule(self, data):
return apply(self.op, data)

def _grad_rule(self, data):
self.dtype = data.dtype
self.shape = data.shape
self.device = data.device
(self.dummy,) = self._default_rule(data)
return self.dummy, self.backward

def backward(self, grad):
assert grad is None
if get_client().check_is_grad(self.op.key):
return remote_recv(
self.op.rank_to,
self.shape,
self.dtype,
device=str(self.device),
inp=self.dummy,
)


class _RemoteRecv(PyOpBase):
def __init__(self, op: RemoteRecv):
self.op = op

def _default_rule(self, dummy):
return apply(self.op, dummy)

def _grad_rule(self, dummy):
return self._default_rule(dummy), self.backward

def backward(self, grad):
get_client().set_is_grad(self.op.key, grad is not None)
if grad is not None:
remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
"""
Send a Tensor to a remote process.
@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
:param inp: tensor to send.
:param dest_rank: destination process rank.
"""
key = "{}->{}".format(get_rank(), dest_rank)
grad_keys = {}
for n, g in _grad_manager_dict.items():
if g._is_attached_to(inp):
grad_keys[n] = g
get_client().set_remote_tracer(key, grad_keys)

op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank)
op.key = key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
return apply(op, inp)[0]
(dummy,) = apply(_RemoteSend(op), inp)

for g in grad_keys.values():
g._refkeeper.append(dummy)


def remote_recv(
@@ -228,12 +280,14 @@ def remote_recv(
if device is None:
device = get_default_device()
# dummy input
if inp == None:
if inp is None:
inp = Tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
grad_manager.wrt(inp)
for n in tracer_set:
g = _grad_manager_dict.get(n)
if g is not None:
g.wrt(inp)
g._refkeeper.append(inp)

op = RemoteRecv()
op.key = key
@@ -243,4 +297,5 @@ def remote_recv(
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank

return apply(op, inp)[0]
(ret,) = apply(_RemoteRecv(op), inp)
return ret

+ 40
- 4
imperative/python/src/grad.cpp View File

@@ -193,11 +193,15 @@ struct PythonBackward {
args[i] = g ? ctx.wrap_tensor(g) : py::none();
}
auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (!input_grads) throw py::error_already_set();
if (input_grads.is_none()) return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(input_grads.ptr());
}
receiver(0, tw->m_tensor);
return;
}
@@ -210,6 +214,9 @@ struct PythonBackward {
if (!tw) {
throw py::type_error("custom grad rule returned non-tensor");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(g.ptr());
}
receiver(i, tw->m_tensor);
}
}
@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
if (!pyret) throw py::error_already_set();
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
~CleanupGuard() {owner->cleanup();}
} _cleanup_guard(this);

if (tape.empty() || grads.empty()) return;
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
if (tape.empty()) return;

BackwardContext bctx;
if (!grads.empty()) {
bctx.pytype = Py_TYPE(grads[0]->self().ptr());
}

for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info;
@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
}

BackwardContext bctx{pytype};
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback
dst->callback(TensorWrapper::make(pytype, dst->grad));
dst->callback(bctx.wrap_tensor(dst->grad));
}
}
grad_fn->clear();
@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T
m_key->backward(std::move(tensors), std::move(grads));
}

PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name).release().ptr();
}

void GradKeyWrapper::set_name(py::handle name) {
m_key->name = py::cast<std::string>(name);
}

PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "expect 1 argument");
return nullptr;
}
auto* tw = TensorWrapper::try_cast(args[0]);
if (!tw) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
if (grad_fn && grad_fn->key.lock() == m_key) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}

GradKey::~GradKey() {
cleanup();
}


+ 3
- 0
imperative/python/src/grad.h View File

@@ -41,8 +41,11 @@ struct GradKeyWrapper {

inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}

PyObject* get_name();
void set_name(pybind11::handle name);
void attach(PyObject*const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
PyObject* is_attached_to(PyObject*const* args, size_t nargs);
};

struct BackwardContext {


+ 5
- 2
imperative/python/src/tensor.cpp View File

@@ -733,15 +733,18 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish();
},
py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func);

py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to")
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
.finalize();
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
m.def("backward", &GradKeyWrapper::backward);

m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);


+ 1
- 0
imperative/python/test/unit/autodiff/test_grad_manger.py View File

@@ -141,6 +141,7 @@ def test_regression_1762():
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad():
@dist.launcher
def worker():


+ 4
- 6
imperative/python/test/unit/core/test_autodiff.py View File

@@ -16,9 +16,8 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply
from megengine.core._imperative_rt.imperative import sync
from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise
from megengine.distributed.helper import get_device_count_by_fork
@@ -73,7 +72,7 @@ def test_dist_grad():
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
send_x = remote_send(x, 1)
remote_send(x, 1)
recv_x = remote_recv(1, x_np.shape, x_np.dtype)
y = recv_x * recv_x

@@ -83,13 +82,12 @@ def test_dist_grad():
grad = Grad()

recv_x = remote_recv(0, x_np.shape, x_np.dtype)
send_x = remote_send(recv_x, 0)
remote_send(recv_x, 0)

grad([], [])

worker()


def test_grad():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)


+ 3
- 2
imperative/python/test/unit/functional/test_functional_distributed.py View File

@@ -14,6 +14,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import (
@@ -333,8 +334,8 @@ def test_io_remote():
rank = dist.get_rank()
if rank == 0: # remote send
x = Tensor(val, device="gpu0")
y = remote_send(x, 1)
assert y.numpy()[0] == 0
remote_send(x, 1)
sync()
else: # remote recv
y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1"


Loading…
Cancel
Save