From 65432d3b935b616833978babd8de129f9bb21027 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 25 Mar 2020 20:07:05 +0800 Subject: [PATCH] fix(mge/module): fix torch subgraph under jit.trace with symbolic=False GitOrigin-RevId: a208ba79d964baf78bdd9d10264dcb9166bb8506 --- python_module/megengine/module/pytorch/pytorch.py | 2 + python_module/test/unit/module/test_pytorch.py | 68 +++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/python_module/megengine/module/pytorch/pytorch.py b/python_module/megengine/module/pytorch/pytorch.py index 35b103e0..81548a50 100644 --- a/python_module/megengine/module/pytorch/pytorch.py +++ b/python_module/megengine/module/pytorch/pytorch.py @@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase): ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs") ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs") + ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params") + ret.__dict__["_func"] = d0.pop("_func") d0.pop("_grad_opr") later_copy = self._grad_opr in _copy_dict diff --git a/python_module/test/unit/module/test_pytorch.py b/python_module/test/unit/module/test_pytorch.py index 9b9456dc..105bb62f 100644 --- a/python_module/test/unit/module/test_pytorch.py +++ b/python_module/test/unit/module/test_pytorch.py @@ -13,8 +13,11 @@ from helpers import randomTorch import megengine as mge import megengine._internal as mgb import megengine.functional +import megengine.optimizer as optimizer from megengine import get_default_device, set_default_device from megengine.core import Parameter, tensor +from megengine.jit import trace +from megengine.module import Module as MGEModule from megengine.module.pytorch import PyTorchModule from megengine.test import assertTensorClose @@ -72,3 +75,68 @@ def test_pytorch_backward(): return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False) assertTensorClose(get_pytorch_backward().numpy(), 
get_mge_backward().numpy())
+
+
+def test_pytorch_mixed():
+    """Check SGD updates of a mixed torch/MegEngine module, eager and traced."""
+
+    init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
+    lr = 1.0
+
+    class Mixed(MGEModule):
+        class SubModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))
+
+            def forward(self, inp):
+                return inp * self.multiplier
+
+        def __init__(self):
+            super().__init__()
+            self.torch_module = PyTorchModule(self.SubModule())
+            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)
+
+        def forward(self, inp):
+            return self.torch_module(inp) * self.multiplier
+
+    def run(step, enable_trace, use_symbolic):
+        def train_func(data, net=None, opt=None):
+            pred = net(data)
+            opt.backward(pred)
+            return pred
+
+        if enable_trace:
+            train_func = trace(train_func, symbolic=use_symbolic)
+
+        net = Mixed()
+        data = tensor()
+        opt = optimizer.SGD(net.parameters(), lr=lr)
+
+        saved_param = init_param
+        for i in range(step):
+            opt.zero_grad()
+            data.set_value([i + 1.0])
+            output = train_func(data, net=net, opt=opt)
+            opt.step()
+
+            # Plain SGD: each multiplier's grad is data times the other multiplier.
+            expect_param = (
+                saved_param[0] - lr * saved_param[1] * data.numpy(),
+                saved_param[1] - lr * saved_param[0] * data.numpy(),
+            )
+            assertTensorClose(
+                output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
+            )
+            torch_param = net.torch_module._torch_params[0].detach().cpu()
+            assertTensorClose(torch_param.numpy(), expect_param[0])
+            assertTensorClose(net.multiplier.numpy(), expect_param[1])
+            saved_param = expect_param
+
+    run(1, False, False)
+    run(1, True, True)
+    run(1, True, False)
+
+    run(2, False, False)
+    run(2, True, True)
+    run(2, True, False)