@@ -71,7 +71,7 @@ if sys.platform == "win32": | |||
kernel32.SetErrorMode(old_error_mode) | |||
from .core._imperative_rt.core2 import sync, release_trace_apply_func | |||
from .core._imperative_rt.core2 import release_trace_apply_func, sync | |||
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func | |||
from .device import * | |||
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level | |||
@@ -46,9 +46,31 @@ def get_grad_managers(): | |||
return [_grad_manager_dict[key] for key in _grad_manager_dict] | |||
class GradKey(core2.GradKey): | |||
def __init__(self, name=None): | |||
if name: | |||
self.name = name | |||
def backward(self, ys, dys): | |||
return core2.backward(self, ys, dys) | |||
class Grad: | |||
def __init__(self): | |||
self._impl = core2.GradKey() | |||
def __init__(self, name=None): | |||
global _grad_count | |||
if name is None: | |||
name = "grad_%d" % _grad_count | |||
_grad_count += 1 | |||
self._refkeeper = [] | |||
self._impl = GradKey(name) | |||
_grad_manager_dict[self._name] = self | |||
@property | |||
def _name(self): | |||
return self._impl.name | |||
def _is_attached_to(self, tensor): | |||
return self._impl.is_attached_to(tensor) | |||
def wrt(self, *tensors, callback=None): | |||
for x in tensors: | |||
@@ -62,12 +84,16 @@ class Grad: | |||
ys = [ys] | |||
if not isinstance(dys, Sequence): | |||
dys = [dys] | |||
core2.backward(self._impl, ys, dys) | |||
self._impl.backward(ys, dys) | |||
self._refkeeper = None | |||
def __enter__(self): | |||
return self | |||
def __exit__(self, _1, _2, _3): | |||
self._refkeeper = None | |||
del self._impl | |||
@@ -9,8 +9,8 @@ | |||
from typing import Optional, Tuple | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.autodiff.grad import get_grad_managers | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.autodiff.grad import _grad_manager_dict | |||
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend | |||
from ..device import get_default_device | |||
from ..tensor import Tensor | |||
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank | |||
@@ -193,6 +193,48 @@ def all_to_all( | |||
return collective_comm(inp, mode, group, device) | |||
class _RemoteSend(PyOpBase): | |||
def __init__(self, op: RemoteSend): | |||
self.op = op | |||
def _default_rule(self, data): | |||
return apply(self.op, data) | |||
def _grad_rule(self, data): | |||
self.dtype = data.dtype | |||
self.shape = data.shape | |||
self.device = data.device | |||
(self.dummy,) = self._default_rule(data) | |||
return self.dummy, self.backward | |||
def backward(self, grad): | |||
assert grad is None | |||
if get_client().check_is_grad(self.op.key): | |||
return remote_recv( | |||
self.op.rank_to, | |||
self.shape, | |||
self.dtype, | |||
device=str(self.device), | |||
inp=self.dummy, | |||
) | |||
class _RemoteRecv(PyOpBase): | |||
def __init__(self, op: RemoteRecv): | |||
self.op = op | |||
def _default_rule(self, dummy): | |||
return apply(self.op, dummy) | |||
def _grad_rule(self, dummy): | |||
return self._default_rule(dummy), self.backward | |||
def backward(self, grad): | |||
get_client().set_is_grad(self.op.key, grad is not None) | |||
if grad is not None: | |||
remote_send(grad, self.op.rank_from) | |||
def remote_send(inp: Tensor, dest_rank: int) -> Tensor: | |||
""" | |||
Send a Tensor to a remote process. | |||
@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor: | |||
:param inp: tensor to send. | |||
:param dest_rank: destination process rank. | |||
""" | |||
key = "{}->{}".format(get_rank(), dest_rank) | |||
grad_keys = {} | |||
for n, g in _grad_manager_dict.items(): | |||
if g._is_attached_to(inp): | |||
grad_keys[n] = g | |||
get_client().set_remote_tracer(key, grad_keys) | |||
op = RemoteSend() | |||
op.key = "{}->{}".format(get_rank(), dest_rank) | |||
op.key = key | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_to = dest_rank | |||
return apply(op, inp)[0] | |||
(dummy,) = apply(_RemoteSend(op), inp) | |||
for g in grad_keys.values(): | |||
g._refkeeper.append(dummy) | |||
def remote_recv( | |||
@@ -228,12 +280,14 @@ def remote_recv( | |||
if device is None: | |||
device = get_default_device() | |||
# dummy input | |||
if inp == None: | |||
if inp is None: | |||
inp = Tensor([0], device=device) | |||
tracer_set = get_client().check_remote_tracer(key) | |||
for grad_manager in get_grad_managers(): | |||
if grad_manager.name in tracer_set: | |||
grad_manager.wrt(inp) | |||
for n in tracer_set: | |||
g = _grad_manager_dict.get(n) | |||
if g is not None: | |||
g.wrt(inp) | |||
g._refkeeper.append(inp) | |||
op = RemoteRecv() | |||
op.key = key | |||
@@ -243,4 +297,5 @@ def remote_recv( | |||
op.addr, op.port = get_mm_server_addr() | |||
op.rank_from = src_rank | |||
return apply(op, inp)[0] | |||
(ret,) = apply(_RemoteRecv(op), inp) | |||
return ret |
@@ -193,11 +193,15 @@ struct PythonBackward { | |||
args[i] = g ? ctx.wrap_tensor(g) : py::none(); | |||
} | |||
auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); | |||
if (!input_grads) throw py::error_already_set(); | |||
if (input_grads.is_none()) return; | |||
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { | |||
if (input_size != 1) { | |||
throw py::value_error("custom grad rule returned wrong number of grads"); | |||
} | |||
if (!ctx.pytype) { | |||
ctx.pytype = Py_TYPE(input_grads.ptr()); | |||
} | |||
receiver(0, tw->m_tensor); | |||
return; | |||
} | |||
@@ -210,6 +214,9 @@ struct PythonBackward { | |||
if (!tw) { | |||
throw py::type_error("custom grad rule returned non-tensor"); | |||
} | |||
if (!ctx.pytype) { | |||
ctx.pytype = Py_TYPE(g.ptr()); | |||
} | |||
receiver(i, tw->m_tensor); | |||
} | |||
} | |||
@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { | |||
} | |||
auto grad_rule = py::getattr(op->obj, "_grad_rule"); | |||
auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); | |||
if (!pyret) throw py::error_already_set(); | |||
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); | |||
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); | |||
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { | |||
@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
~CleanupGuard() {owner->cleanup();} | |||
} _cleanup_guard(this); | |||
if (tape.empty() || grads.empty()) return; | |||
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr()); | |||
if (tape.empty()) return; | |||
BackwardContext bctx; | |||
if (!grads.empty()) { | |||
bctx.pytype = Py_TYPE(grads[0]->self().ptr()); | |||
} | |||
for (size_t i = 0; i < tensors.size(); ++i) { | |||
auto& grad_info = tensors[i]->m_tensor->m_grad_info; | |||
@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
} | |||
} | |||
BackwardContext bctx{pytype}; | |||
std::vector<std::shared_ptr<GradFn>> ref_keeper; | |||
ref_keeper.reserve(tape.size()); | |||
// back-propagation in reverse order | |||
@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
} | |||
if (!dst.producer_record.next && dst->callback && dst->grad) { | |||
// I'm the last grad producer, invoke callback | |||
dst->callback(TensorWrapper::make(pytype, dst->grad)); | |||
dst->callback(bctx.wrap_tensor(dst->grad)); | |||
} | |||
} | |||
grad_fn->clear(); | |||
@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T | |||
m_key->backward(std::move(tensors), std::move(grads)); | |||
} | |||
PyObject* GradKeyWrapper::get_name() { | |||
return py::cast(m_key->name).release().ptr(); | |||
} | |||
void GradKeyWrapper::set_name(py::handle name) { | |||
m_key->name = py::cast<std::string>(name); | |||
} | |||
PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { | |||
if (nargs != 1) { | |||
PyErr_SetString(PyExc_TypeError, "expect 1 argument"); | |||
return nullptr; | |||
} | |||
auto* tw = TensorWrapper::try_cast(args[0]); | |||
if (!tw) { | |||
PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||
return nullptr; | |||
} | |||
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn; | |||
if (grad_fn && grad_fn->key.lock() == m_key) { | |||
Py_RETURN_TRUE; | |||
} | |||
Py_RETURN_FALSE; | |||
} | |||
GradKey::~GradKey() { | |||
cleanup(); | |||
} | |||
@@ -41,8 +41,11 @@ struct GradKeyWrapper { | |||
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} | |||
PyObject* get_name(); | |||
void set_name(pybind11::handle name); | |||
void attach(PyObject*const* args, size_t nargs); | |||
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); | |||
PyObject* is_attached_to(PyObject*const* args, size_t nargs); | |||
}; | |||
struct BackwardContext { | |||
@@ -733,15 +733,18 @@ void init_tensor(py::module m) { | |||
py_task_q.wait_all_task_finish(); | |||
}, | |||
py::call_guard<py::gil_scoped_release>()); | |||
m.def("release_trace_apply_func", &release_trace_apply_func); | |||
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() | |||
.def<&GradKeyWrapper::attach>("attach") | |||
.def<&GradKeyWrapper::is_attached_to>("is_attached_to") | |||
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name") | |||
.finalize(); | |||
if (!grad_key_type) throw py::error_already_set(); | |||
py::setattr(m, "GradKey", grad_key_type); | |||
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | |||
m.def("backward", &GradKeyWrapper::backward); | |||
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); | |||
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); | |||
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode); | |||
@@ -141,6 +141,7 @@ def test_regression_1762(): | |||
) | |||
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") | |||
@pytest.mark.isolated_distributed | |||
@pytest.mark.skip(reason="FIXME: remote_send/recv") | |||
def test_remote_grad(): | |||
@dist.launcher | |||
def worker(): | |||
@@ -16,9 +16,8 @@ import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
import megengine.functional as F | |||
from megengine.core._imperative_rt import TensorAttr, core2, imperative | |||
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply | |||
from megengine.core._imperative_rt.imperative import sync | |||
from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative | |||
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync | |||
from megengine.core.autodiff.grad import Grad | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
@@ -73,7 +72,7 @@ def test_dist_grad(): | |||
x = as_tensor(x_np) | |||
grad.wrt(x, callback=save_to(x)) | |||
# need a placeholder to trace operator | |||
send_x = remote_send(x, 1) | |||
remote_send(x, 1) | |||
recv_x = remote_recv(1, x_np.shape, x_np.dtype) | |||
y = recv_x * recv_x | |||
@@ -83,13 +82,12 @@ def test_dist_grad(): | |||
grad = Grad() | |||
recv_x = remote_recv(0, x_np.shape, x_np.dtype) | |||
send_x = remote_send(recv_x, 0) | |||
remote_send(recv_x, 0) | |||
grad([], []) | |||
worker() | |||
def test_grad(): | |||
x_np = np.random.rand(10).astype("float32") | |||
x = as_tensor(x_np) | |||
@@ -14,6 +14,7 @@ import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
from megengine import Parameter, Tensor, tensor | |||
from megengine.core._imperative_rt.core2 import sync | |||
from megengine.device import get_default_device, set_default_device | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.functional.distributed import ( | |||
@@ -333,8 +334,8 @@ def test_io_remote(): | |||
rank = dist.get_rank() | |||
if rank == 0: # remote send | |||
x = Tensor(val, device="gpu0") | |||
y = remote_send(x, 1) | |||
assert y.numpy()[0] == 0 | |||
remote_send(x, 1) | |||
sync() | |||
else: # remote recv | |||
y = remote_recv(0, val.shape, val.dtype) | |||
assert y.device == "gpu1" | |||