GitOrigin-RevId: 7886efd0c1
release-1.2
@@ -16,7 +16,7 @@ import numpy as np
 import megengine as mge
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
+        # save remote_send output for backward
+        self.remote_send_cache = []
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
@@ -144,6 +147,7 @@ class Grad:
             o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []

     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward

     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
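Note: the new remote_send_cache holds a strong reference to every RemoteSend grad opnode. On the sending rank the output of remote_send has no local consumer, so nothing else would keep that part of the backward graph alive; caching the opnode until the Grad context exits (see the cache reset above) presumably prevents the send's backward entry from being collected before gm.backward() runs. A minimal sketch of that keep-alive pattern, using made-up names rather than MegEngine internals:

    import gc
    import weakref

    class Node:
        # hypothetical stand-in for a grad opnode that is tracked only weakly
        def __init__(self, name):
            self.name = name

    tracked = weakref.WeakSet()

    def make_node(name, cache=None):
        node = Node(name)
        tracked.add(node)          # weak reference only
        if cache is not None:
            cache.append(node)     # strong reference keeps the node alive
        return node

    remote_send_cache = []
    make_node("remote_send", remote_send_cache)  # cached, so it survives
    make_node("ordinary_op")                     # unreferenced, collected
    gc.collect()
    print(sorted(n.name for n in tracked))       # ['remote_send']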
@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)

     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
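Passing *links in addition to *readers when compiling the traced graph appears to be what keeps operators whose outputs are never read, such as the RemoteSend issued on the sending rank, from being pruned out of the compiled graph. The toy pruner below illustrates the underlying idea; the names (prune, deps, readers, links) are illustrative only, not MegEngine API:

    def prune(targets, deps):
        # keep only nodes reachable from the compile targets;
        # deps maps a node name to the names it depends on
        keep, stack = set(), list(targets)
        while stack:
            node = stack.pop()
            if node not in keep:
                keep.add(node)
                stack.extend(deps.get(node, ()))
        return keep

    deps = {
        "loss": ["matmul"],
        "matmul": ["x", "w"],
        "send": ["matmul"],    # side effect only: nothing reads its output
    }
    readers, links = ["loss"], ["send"]

    print(prune(readers, deps))           # 'send' is unreachable and gets dropped
    print(prune(readers + links, deps))   # listing it as a link keeps it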
@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@ def test_remote_grad():
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
             opt.step().clear_grad()

-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]

-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()

     worker()
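The reworked test runs the same train_func eagerly and re-wrapped by trace in both modes, with a sync() after each variant so every rank finishes one mode before the next begins. The wrapping pattern itself looks roughly like this (step is a made-up toy function, not the test's model):

    import numpy as np

    import megengine as mge
    from megengine.jit import trace

    def step(x):
        # hypothetical stand-in for a training step
        return x + 1

    variants = [step, trace(symbolic=False)(step), trace(symbolic=True)(step)]
    x = mge.tensor(np.zeros((4,), dtype=np.float32))
    for fn in variants:
        for _ in range(3):    # traced variants compile on their first call
            y = fn(x)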
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }

 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
                 const OperatorNodeConfig& config, const TensorShape& shape,
                 DType dtype);

+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;