
fix(mge/io_remote): fix remote send/recv gradient at trace

GitOrigin-RevId: 7886efd0c1
release-1.2
Megvii Engine Team
commit 2ad8c5e1e9
5 changed files with 37 additions and 16 deletions:

  1. imperative/python/megengine/core/autodiff/grad.py (+7, -1)
  2. imperative/python/megengine/jit/tracing.py (+1, -1)
  3. imperative/python/test/unit/autodiff/test_grad_manger.py (+12, -9)
  4. src/opr-mm/impl/io_remote.cpp (+14, -5)
  5. src/opr-mm/include/megbrain/opr/io_remote.h (+3, -0)
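
Note (context): the user-facing pattern this commit fixes is pipeline-style training where an
intermediate rank calls remote_send and then runs backward without an explicit output gradient,
optionally under trace. Below is a minimal sketch mirroring the updated test in this commit; the
helper name make_train_step, the in_shape argument, and the float32 dtype are illustrative and
not part of the diff:

import numpy as np
import megengine.distributed as dist

def make_train_step(m, gm, opt, rank, size, in_shape):
    # m: a Module, gm: a GradManager attached to m's parameters, opt: an optimizer
    def train_step(x):
        with gm:
            if rank != 0:
                # receive the activation produced by the previous pipeline stage
                x = dist.functional.remote_recv(
                    rank - 1, shape=in_shape, dtype=np.float32
                )
            y = m(x)
            if rank != size - 1:
                # forward the activation; the gradient comes back through the
                # paired send/recv when gm.backward() runs
                dist.functional.remote_send(y, dest_rank=rank + 1)
                gm.backward()
            else:
                gm.backward(y.mean())
            opt.step().clear_grad()
    return train_step

# train_step can be run eagerly or wrapped with megengine.jit.trace, e.g.
# trace(symbolic=True)(train_step); the traced case is what this commit fixes.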

imperative/python/megengine/core/autodiff/grad.py (+7, -1)

@@ -16,7 +16,7 @@ import numpy as np
 
 import megengine as mge
 
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
 
+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
 
@@ -144,6 +147,7 @@ class Grad:
                 o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []
 
     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
        return
 
    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
    opnode.backward = backward
 
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
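
Note: the remote_send_cache added above holds strong references to RemoteSend op nodes. The
output of a remote_send has no local consumer, so without an extra reference the tracer node
(and the backward hook attached to it) could be dropped before gm.backward() runs; that reading
is inferred from the diff. The snippet below is a generic illustration of the reference-keeping
pattern, not MegEngine API:

import gc
import weakref

class OpNode:
    """Stand-in for a tracer op node that carries a backward callback."""
    def __init__(self, name):
        self.name = name
        self.backward = None

tracked = weakref.WeakSet()   # weak bookkeeping, similar to _attached_tensors above
remote_send_cache = []        # strong references, as introduced in this commit

def new_opnode(name, keep_alive=False):
    node = OpNode(name)
    tracked.add(node)
    if keep_alive:
        remote_send_cache.append(node)
    return node

new_opnode("remote_send", keep_alive=True)
new_opnode("elemwise")        # nothing else references it, so it may be collected
gc.collect()

print(sorted(n.name for n in tracked))   # ['remote_send'] survives for backward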


imperative/python/megengine/jit/tracing.py (+1, -1)

@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)
 
     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
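
Note: passing *links to graph.compile in addition to *readers appears to be what keeps
output-less side-effect ops (such as the RemoteSend recorded while tracing) from being pruned,
since no reader depends on them. A toy illustration of that pruning behaviour follows;
compile_endpoints and the op names are made up, this is not the MegEngine graph API:

def compile_endpoints(deps, readers, links=()):
    """Keep only the ops reachable (through input deps) from readers + links."""
    keep, stack = set(), list(readers) + list(links)
    while stack:
        op = stack.pop()
        if op in keep:
            continue
        keep.add(op)
        stack.extend(deps.get(op, ()))
    return keep

# op -> the ops it takes input from (a hypothetical pipeline stage)
deps = {
    "loss": ["linear"],
    "linear": ["recv"],
    "recv": [],
    "send": ["linear"],   # side effect only: nothing reads its output
}

print(sorted(compile_endpoints(deps, readers=["loss"])))
# ['linear', 'loss', 'recv']  -> 'send' would be pruned
print(sorted(compile_endpoints(deps, readers=["loss"], links=["send"])))
# ['linear', 'loss', 'recv', 'send']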


imperative/python/test/unit/autodiff/test_grad_manger.py (+12, -9)

@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
 
-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
                 opt.step().clear_grad()
 
-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]
 
-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()
 
     worker()

src/opr-mm/impl/io_remote.cpp (+14, -5)

@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }
 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);



src/opr-mm/include/megbrain/opr/io_remote.h (+3, -0)

@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);
 
+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;

