@@ -13,8 +13,11 @@ from helpers import randomTorch
import megengine as mge
import megengine._internal as mgb
import megengine.functional
import megengine.optimizer as optimizer
from megengine import get_default_device, set_default_device
from megengine.core import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module as MGEModule
from megengine.module.pytorch import PyTorchModule
from megengine.test import assertTensorClose

@@ -72,3 +75,68 @@ def test_pytorch_backward():
        return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False)
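
    # Gradients of the same computation produced by PyTorch autograd and by
    # MegEngine should agree elementwise.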
    assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy())


def test_pytorch_mixed():
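    # Train a module that mixes a PyTorch-owned parameter with a
    # MegEngine-owned parameter and verify that a single optimizer step
    # updates both by the analytic gradient.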
    init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
    lr = 1.0
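
    # Mixed runs a torch.nn.Module inside a MegEngine forward pass via
    # PyTorchModule; the wrapped parameters surface through _torch_params
    # (checked at the end of each training step below).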
    class Mixed(MGEModule):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))

            def forward(self, inp):
                return inp * self.multiplier
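
        # Wrap the PyTorch submodule with PyTorchModule and add a
        # MegEngine-native Parameter; net.parameters() then yields both, so
        # a single SGD optimizer updates them together.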
        def __init__(self):
            super().__init__()
            self.torch_module = PyTorchModule(self.SubModule())
            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)

        def forward(self, inp):
            return self.torch_module(inp) * self.multiplier
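
    # run() builds the network, optionally wraps the training step in
    # megengine.jit.trace, and takes `step` SGD steps on input [i + 1].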
    def run(step, enable_trace, use_symbolic):
        def train_func(data, net=None, opt=None):
            pred = net(data)
            opt.backward(pred)
            return pred
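
        # With symbolic=True, jit.trace compiles train_func into a static
        # graph; with symbolic=False it traces imperatively. Both modes are
        # exercised by the run(...) calls at the bottom.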
        if enable_trace:
            train_func = trace(train_func, symbolic=use_symbolic)

        net = Mixed()
        data = tensor()
        opt = optimizer.SGD(net.parameters(), lr=lr)

        saved_param = init_param
        for i in range(step):
            opt.zero_grad()
            data.set_value([i + 1.0])
            output = train_func(data, net=net, opt=opt)
            opt.step()
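
            # output = x * p0 * p1, so d(output)/d(p0) = x * p1 and
            # d(output)/d(p1) = x * p0; one SGD step therefore moves each
            # parameter by -lr * (the other parameter) * x.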
            expect_param = (
                saved_param[0] - lr * saved_param[1] * data.numpy(),
                saved_param[1] - lr * saved_param[0] * data.numpy(),
            )
            assertTensorClose(
                output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
            )
            torch_param = net.torch_module._torch_params[0].detach().cpu()
            assertTensorClose(torch_param.numpy(), expect_param[0])
            assertTensorClose(net.multiplier.numpy(), expect_param[1])
            saved_param = expect_param
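
    # Exercise eager execution plus both trace modes, for one step and for
    # two steps so that parameter state must carry across iterations.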
    run(1, False, False)
    run(1, True, True)
    run(1, True, False)

    run(2, False, False)
    run(2, True, True)
    run(2, True, False)