
fix(autodiff): fix inplace operation on autodiff.Function

GitOrigin-RevId: a658680f35
release-1.10
Megvii Engine Team, 3 years ago
commit a49d5bf973
2 changed files with 41 additions and 3 deletions
  1. imperative/python/megengine/core/autodiff/grad.py  +7 -3
  2. imperative/python/test/unit/core/test_function.py  +34 -0

imperative/python/megengine/core/autodiff/grad.py  +7 -3

@@ -178,8 +178,10 @@ class Function:
         return self._default_rule(*args), self.backward
 
     def __call__(self, *args):
+        from ...tensor import Tensor
+
         for arg in args:
-            if not isinstance(arg, core2.Tensor):
+            if not isinstance(arg, Tensor):
                 raise TypeError(
                     "op Function expect type Tensor as inputs, got {}".format(type(arg))
                 )
@@ -191,6 +193,8 @@ class Function:
         grad = Grad.key2grad[grad_key]
         group = [ref() for ref in grad._group]
 
+        origin_args = [Tensor(arg) for arg in args]
+
         for grad in group:
             grad.suppress()
         outputs, backward = self._grad_rule(*args)
@@ -199,13 +203,13 @@ class Function:
 
         def normalized_backward(*output_grads):
            input_grads = backward(*output_grads)
-            if isinstance(input_grads, core2.Tensor) or input_grads is None:
+            if isinstance(input_grads, Tensor) or input_grads is None:
                 input_grads = (input_grads,)
             return input_grads
 
         if self.__single_output:
             outputs = (outputs,)
-        outputs = core2.set_grad(normalized_backward, args, outputs)
+        outputs = core2.set_grad(normalized_backward, origin_args, outputs)
         if self.__single_output:
             (outputs,) = outputs
         return outputs

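Why the snapshot matters: a user-defined forward may mutate its input inplace (as the test below does with a *= 10), so by the time the backward hook is registered the original args no longer hold their pre-forward values. Copying them first (origin_args = [Tensor(arg) for arg in args]) and passing the copies to core2.set_grad keeps the gradient bookkeeping tied to the unmodified inputs. A minimal, framework-agnostic sketch of the failure mode (numpy only; the names here are illustrative, not MegEngine API):

    import numpy as np

    def forward_inplace(a):
        # mimics a Function.forward that updates its argument inplace
        a *= 10
        return a

    x = np.ones(3, dtype=np.float32)
    snapshot = x.copy()  # analogue of origin_args = [Tensor(arg) for arg in args]
    y = forward_inplace(x)

    assert np.allclose(y, 10.0)
    assert np.allclose(x, 10.0)        # the original handle was mutated by forward
    assert np.allclose(snapshot, 1.0)  # only the snapshot still holds the pre-forward value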

imperative/python/test/unit/core/test_function.py  +34 -0

@@ -347,3 +347,37 @@ def test_multiple_grad():
 
     np.testing.assert_almost_equal(loss.numpy(), (av * 10))
     np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
+
+
+def test_inplace_forward():
+    data_shape = (9, 2, 6)
+    av = np.random.random(data_shape).astype(np.float32)
+
+    class MulFunc(Function):
+        def forward(self, a):
+            self.a = a
+            a *= 10
+            return a
+
+        def backward(self, grad_o):
+            return grad_o * 10
+
+    class Simple(Module):
+        def __init__(self, a):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.layer1 = MulFunc()
+
+        def forward(self):
+            x = self.layer1(self.a)
+            return x
+
+    net = Simple(av)
+    gm = ad.GradManager().attach(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+
+    opt.clear_grad()
+    with gm:
+        loss = net()
+        gm.backward(loss.sum())
+    opt.step()

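A hedged aside on the added test (not part of the commit): MulFunc.forward deliberately multiplies its argument inplace to exercise the fix above, and the test passes once forward, backward, and the optimizer step complete without error. The conventional way to write the same Function avoids the inplace update altogether, which sidesteps the issue; a sketch, reusing the test file's existing Function import:

    class MulFuncOutOfPlace(Function):
        def forward(self, a):
            return a * 10  # out-of-place: the input tensor keeps its original value

        def backward(self, grad_o):
            return grad_o * 10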