Browse Source

fix(mge/module): fix torch subgraph under jit.trace with symbolic=False

GitOrigin-RevId: a208ba79d9
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
65432d3b93
2 changed files with 70 additions and 0 deletions
  1. +2
    -0
      python_module/megengine/module/pytorch/pytorch.py
  2. +68
    -0
      python_module/test/unit/module/test_pytorch.py

+ 2
- 0
python_module/megengine/module/pytorch/pytorch.py View File

@@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):


ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs")
ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs")
ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params")
ret.__dict__["_func"] = d0.pop("_func")


d0.pop("_grad_opr")
later_copy = self._grad_opr in _copy_dict


+ 68
- 0
python_module/test/unit/module/test_pytorch.py View File

@@ -13,8 +13,11 @@ from helpers import randomTorch
import megengine as mge
import megengine._internal as mgb
import megengine.functional
import megengine.optimizer as optimizer
from megengine import get_default_device, set_default_device
from megengine.core import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module as MGEModule
from megengine.module.pytorch import PyTorchModule
from megengine.test import assertTensorClose


@@ -72,3 +75,68 @@ def test_pytorch_backward():
return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False)


assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy())


def test_pytorch_mixed():
    """Train a module that mixes a wrapped PyTorch submodule with a native
    MegEngine parameter, and check that SGD updates both sides correctly.

    The forward computes ``output = inp * torch_param * mge_param``, so after
    one SGD step with learning rate ``lr``::

        torch_param' = torch_param - lr * mge_param  * data
        mge_param'   = mge_param   - lr * torch_param * data

    The scenario is exercised eagerly and under ``jit.trace`` in both
    symbolic and imperative (``symbolic=False``) modes — the latter is the
    case fixed by this commit.
    """
    # (initial torch-side multiplier, initial mge-side multiplier)
    init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
    lr = 1.0

    class Mixed(MGEModule):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))

            def forward(self, inp):
                return inp * self.multiplier

        def __init__(self):
            super().__init__()
            # Wrap the torch module so its parameters join net.parameters().
            self.torch_module = PyTorchModule(self.SubModule())
            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)

        def forward(self, inp):
            return self.torch_module(inp) * self.multiplier

    def run(step, enable_trace, use_symbolic):
        """Run `step` optimizer iterations, optionally under jit.trace."""

        def train_func(data, net=None, opt=None):
            pred = net(data)
            opt.backward(pred)
            return pred

        if enable_trace:
            train_func = trace(train_func, symbolic=use_symbolic)

        net = Mixed()
        data = tensor()
        opt = optimizer.SGD(net.parameters(), lr=lr)

        saved_param = init_param
        for i in range(step):
            opt.zero_grad()
            data.set_value([i + 1.0])
            output = train_func(data, net=net, opt=opt)
            opt.step()

            # Closed-form SGD update for output = p0 * p1 * data:
            # d(output)/d(p0) = p1 * data and vice versa.
            expect_param = (
                saved_param[0] - lr * saved_param[1] * data.numpy(),
                saved_param[1] - lr * saved_param[0] * data.numpy(),
            )
            assertTensorClose(
                output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
            )
            torch_param = net.torch_module._torch_params[0].detach().cpu()
            assertTensorClose(torch_param.numpy(), expect_param[0])
            assertTensorClose(net.multiplier.numpy(), expect_param[1])
            saved_param = expect_param

    run(1, False, False)
    run(1, True, True)
    run(1, True, False)

    run(2, False, False)
    run(2, True, True)
    run(2, True, False)

Loading…
Cancel
Save