
fix(mge/io_remote): fix remote send/recv gradient at trace

GitOrigin-RevId: 7886efd0c1
release-1.2
Megvii Engine Team, 4 years ago
commit 2ad8c5e1e9

5 changed files with 37 additions and 16 deletions:
  1. imperative/python/megengine/core/autodiff/grad.py (+7, -1)
  2. imperative/python/megengine/jit/tracing.py (+1, -1)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+12, -9)
  4. src/opr-mm/impl/io_remote.cpp (+14, -5)
  5. src/opr-mm/include/megbrain/opr/io_remote.h (+3, -0)

imperative/python/megengine/core/autodiff/grad.py (+7, -1)

@@ -16,7 +16,7 @@ import numpy as np

 import megengine as mge

-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []

+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True

@@ -144,6 +147,7 @@ class Grad:
             o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []

     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return

     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward

     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
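The comment on the new cache hints at the reason for it: on the sending rank nothing consumes the RemoteSend result, so without a strong reference the corresponding node can be collected before backward runs. A generic Python illustration of that hazard (plain weakref mechanics only, not MegEngine internals; OpNode and the cache list are made-up names):

    import weakref

    class OpNode:
        """Stand-in for an autodiff graph node."""

    # Without a strong reference the node disappears immediately (CPython refcounting).
    node = OpNode()
    ref = weakref.ref(node)
    del node
    assert ref() is None  # gone before backward could reach it

    # Holding it in a list (the role of remote_send_cache above) pins it.
    cache = []
    node = OpNode()
    cache.append(node)
    ref = weakref.ref(node)
    del node
    assert ref() is not None  # still alive via the cache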


imperative/python/megengine/jit/tracing.py (+1, -1)

@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)

     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
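Compiling with the links matters for operators whose result nobody reads, such as the RemoteSend nodes above: a compiled graph presumably keeps only what its compile targets depend on, so a send with no reader has to be passed explicitly or it is pruned. A toy sketch of that assumed behaviour (plain Python with hypothetical names, not the MegEngine graph API):

    from dataclasses import dataclass, field

    @dataclass(eq=False)
    class Op:
        name: str
        inputs: list = field(default_factory=list)

    def compile_ops(targets):
        """Keep only the ops reachable from the compile targets."""
        kept, stack = [], list(targets)
        while stack:
            op = stack.pop()
            if op not in kept:
                kept.append(op)
                stack.extend(op.inputs)
        return kept

    x = Op("x")
    y = Op("matmul", [x])
    send = Op("remote_send", [y])   # side effect only, no local reader
    loss = Op("mean", [y])

    print([o.name for o in compile_ops([loss])])        # remote_send is pruned
    print([o.name for o in compile_ops([loss, send])])  # passed as a link, it survives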


imperative/python/test/unit/autodiff/test_grad_manger.py (+12, -9)

@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@ def test_remote_grad():
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
             opt.step().clear_grad()

-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]

-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()

     worker()
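The branching in the updated test mirrors the gradient symmetry of pipelined training: the backward of a send is a receive of the gradient, so only the last rank seeds gm.backward(y) with a loss while earlier ranks call gm.backward() with no argument and obtain their gradient over the wire. A single-process toy of that flow (plain Python, no MegEngine; the channel dicts are made-up stand-ins for send/recv):

    # Two "ranks" computing loss = w1 * (w0 * x).
    fwd_channel = {}
    bwd_channel = {}

    def rank0_forward(x, w0):
        y = w0 * x
        fwd_channel["y"] = y       # plays the role of remote_send(y, dest_rank=1)
        return y

    def rank1_step(w1):
        y = fwd_channel["y"]       # plays the role of the matching remote receive
        grad_w1 = y                # d(loss)/d(w1) = y, since loss = w1 * y
        bwd_channel["dy"] = w1     # d(loss)/d(y), sent back to the sending rank
        return grad_w1

    def rank0_backward(x):
        dy = bwd_channel["dy"]     # the gradient arrives where the send happened
        return dy * x              # d(loss)/d(w0) by the chain rule

    x, w0, w1 = 2.0, 5.0, 3.0
    rank0_forward(x, w0)
    print(rank1_step(w1))     # 10.0 == w0 * x
    print(rank0_backward(x))  # 6.0  == w1 * x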

src/opr-mm/impl/io_remote.cpp (+14, -5)

@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }
 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);




src/opr-mm/include/megbrain/opr/io_remote.h (+3, -0)

@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);

+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;

