import numpy as np

import megengine.module as M
from megengine import Tensor
from megengine.traced_module import TracedModule, trace_module


class MyModule1(M.Module):
    """Module whose forward builds a new Tensor directly from its input."""

    def forward(self, x):
        """Return ``(x + 2, y)`` where ``y`` is ``Tensor(x)`` incremented by 1 in place."""
        y = Tensor(x)
        y += 1
        x = x + 2
        return x, y


class MyModule2(M.Module):
    """Module whose forward builds a new Tensor from a list containing its input."""

    def forward(self, x):
        """Return ``(x + 2, y)`` where ``y`` is ``Tensor([1, x, 1])`` incremented by 1 in place."""
        y = Tensor([1, x, 1])
        y += 1
        x = x + 2
        return x, y


class MyModule3(M.Module):
    """Module that keeps its sub-modules in a plain list holding a nested dict."""

    def __init__(self):
        super().__init__()
        # Two Elemwise modules followed by a dict of two more; the test below
        # inspects how this container looks after tracing.
        self.modules = [
            M.Elemwise("ADD"),
            M.Elemwise("ADD"),
            {"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")},
        ]

    def forward(self, a, b):
        """Apply the list entries, then the dict entries ``"a"`` and ``"b"``, in order."""
        x = self.modules[0](a, b)
        y = self.modules[1](a, b)
        y = self.modules[2]["a"](x, y)
        y = self.modules[2]["b"](x, y)
        return y


def _assert_outputs_match(actual_outputs, expected_outputs):
    """Compare two sequences of tensors pairwise through their numpy values."""
    for actual, expected in zip(actual_outputs, expected_outputs):
        np.testing.assert_equal(actual.numpy(), expected.numpy())


def test_trace_module():
    """Check that traced modules reproduce the outputs of the modules they trace.

    Also checks how the list/dict container of ``MyModule3`` is exposed on the
    traced module.
    """
    trace_input = Tensor(1)

    module1 = MyModule1()
    traced1 = trace_module(module1, trace_input)

    module2 = MyModule2()
    traced2 = trace_module(module2, trace_input)

    # Evaluate with a value different from the one used for tracing.
    eval_input = Tensor(2)

    expected1 = module1(eval_input)
    actual1 = traced1(eval_input)
    _assert_outputs_match(actual1, expected1)

    expected2 = module2(eval_input)
    actual2 = traced2(eval_input)
    _assert_outputs_match(actual2, expected2)

    lhs, rhs = Tensor(1), Tensor(2)
    module3 = MyModule3()
    expected3 = module3(lhs, rhs)
    traced3 = trace_module(module3, lhs, rhs)
    actual3 = traced3(lhs, rhs)
    np.testing.assert_equal(actual3.numpy(), expected3.numpy())

    # The traced container is inspected through its __dict__: list positions
    # are looked up by string index, and the nested dict is expected to be a
    # TracedModule exposing its keys as attributes.
    traced_children = traced3.modules.__dict__
    assert isinstance(traced_children["0"], M.Elemwise)
    nested_container = traced_children["2"]
    assert isinstance(nested_container, TracedModule)
    assert isinstance(nested_container.a, M.Elemwise)