From a49d5bf973ee0e9552f6d75bbb8433216b8fe6ed Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 20 Jun 2022 16:46:03 +0800
Subject: [PATCH] fix(autodiff): fix inplace operation on autodiff.Function

GitOrigin-RevId: a658680f35dead712faab6f0b094cda9caa56e76
---
 imperative/python/megengine/core/autodiff/grad.py | 10 +++++++---
 imperative/python/test/unit/core/test_function.py | 34 ++++++++++++++++++++++
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index f9e182ef..06a40159 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -178,8 +178,10 @@ class Function:
         return self._default_rule(*args), self.backward
 
     def __call__(self, *args):
+        from ...tensor import Tensor
+
         for arg in args:
-            if not isinstance(arg, core2.Tensor):
+            if not isinstance(arg, Tensor):
                 raise TypeError(
                     "op Function expect type Tensor as inputs, got {}".format(type(arg))
                 )
@@ -191,6 +193,8 @@ class Function:
             grad = Grad.key2grad[grad_key]
             group = [ref() for ref in grad._group]
 
+        origin_args = [Tensor(arg) for arg in args]
+
         for grad in group:
             grad.suppress()
         outputs, backward = self._grad_rule(*args)
@@ -199,13 +203,13 @@ class Function:
 
         def normalized_backward(*output_grads):
             input_grads = backward(*output_grads)
-            if isinstance(input_grads, core2.Tensor) or input_grads is None:
+            if isinstance(input_grads, Tensor) or input_grads is None:
                 input_grads = (input_grads,)
             return input_grads
 
         if self.__single_output:
             outputs = (outputs,)
-        outputs = core2.set_grad(normalized_backward, args, outputs)
+        outputs = core2.set_grad(normalized_backward, origin_args, outputs)
         if self.__single_output:
             (outputs,) = outputs
         return outputs
diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py
index 8a5e9e8e..dd02efee 100644
--- a/imperative/python/test/unit/core/test_function.py
+++ b/imperative/python/test/unit/core/test_function.py
@@ -347,3 +347,37 @@ def test_multiple_grad():
 
     np.testing.assert_almost_equal(loss.numpy(), (av * 10))
     np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
+
+
+def test_inplace_forward():
+    data_shape = (9, 2, 6)
+    av = np.random.random(data_shape).astype(np.float32)
+
+    class MulFunc(Function):
+        def forward(self, a):
+            self.a = a
+            a *= 10
+            return a
+
+        def backward(self, grad_o):
+            return grad_o * 10
+
+    class Simple(Module):
+        def __init__(self, a):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.layer1 = MulFunc()
+
+        def forward(self):
+            x = self.layer1(self.a)
+            return x
+
+    net = Simple(av)
+    gm = ad.GradManager().attach(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+
+    opt.clear_grad()
+    with gm:
+        loss = net()
+        gm.backward(loss.sum())
+    opt.step()
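
Note on the fix: the essential change is the snapshot origin_args = [Tensor(arg) for arg in args], taken before the forward rule runs, so that core2.set_grad records the pre-mutation inputs even when forward mutates an input in place (as MulFunc does in the new test with a *= 10). The sketch below is a minimal, self-contained NumPy illustration of why that snapshot matters; it is not MegEngine's actual machinery, and Tape, mul10_inplace, and mul10_backward are hypothetical names invented for the example.

    # Toy gradient tape (hypothetical, NumPy only). It records the inputs of
    # each applied function so a backward rule could later use them. Without
    # copying the inputs first, an in-place forward corrupts the record.
    import numpy as np


    class Tape:
        """Records (saved_inputs, backward_fn) pairs for applied functions."""

        def __init__(self, snapshot_inputs):
            self.snapshot_inputs = snapshot_inputs
            self.records = []

        def apply(self, fn, backward, *args):
            # Mirrors the patch: copy the inputs *before* running forward,
            # so an in-place forward cannot alter what the tape saw.
            saved = [a.copy() for a in args] if self.snapshot_inputs else list(args)
            out = fn(*args)
            self.records.append((saved, backward))
            return out


    def mul10_inplace(a):
        a *= 10  # in-place forward, like MulFunc in the new test
        return a


    def mul10_backward(saved, grad_out):
        return grad_out * 10


    for snapshot in (False, True):
        tape = Tape(snapshot_inputs=snapshot)
        x = np.ones(3, dtype=np.float32)
        tape.apply(mul10_inplace, mul10_backward, x)
        (saved_inputs, _), = tape.records
        # snapshot=False -> saved input is [10. 10. 10.] (the mutated buffer)
        # snapshot=True  -> saved input is [ 1.  1.  1.] (pre-mutation value)
        print(f"snapshot={snapshot}: saved input = {saved_inputs[0]}")

The original code had the same failure mode as snapshot=False above: core2.set_grad was handed args after the forward rule had already run, so an in-place forward made the recorded inputs alias the outputs; passing origin_args instead preserves the original values for the backward pass.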