|
|
@@ -1,3 +1,5 @@ |
|
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import megengine.functional as F |
|
|
@@ -29,14 +31,19 @@ class MyModule3(M.Module): |
|
|
|
self.modules = [ |
|
|
|
M.Elemwise("ADD"), |
|
|
|
M.Elemwise("ADD"), |
|
|
|
{"a": M.Elemwise("ADD"), "b": M.Elemwise("ADD")}, |
|
|
|
OrderedDict([("a", M.Elemwise("ADD")), ("b", M.Elemwise("ADD"))]), |
|
|
|
M.Elemwise("RELU"), |
|
|
|
M.Elemwise("RELU"), |
|
|
|
] |
|
|
|
|
|
|
|
def forward(self, a, b): |
|
|
|
x = self.modules[0](a, b) |
|
|
|
y = self.modules[1](a, b) |
|
|
|
y = self.modules[2]["a"](x, y) |
|
|
|
y = self.modules[2]["b"](x, y) |
|
|
|
assert list(self.modules[2].keys()) == ["a", "b"] |
|
|
|
for _, m in self.modules[2].items(): |
|
|
|
y = m(x, y) |
|
|
|
for m in self.modules[3:]: |
|
|
|
y = m(y) |
|
|
|
return y |
|
|
|
|
|
|
|
|
|
|
@@ -78,6 +85,7 @@ def test_trace_module(): |
|
|
|
assert isinstance(tm3.modules.__dict__["0"], M.Elemwise) |
|
|
|
assert isinstance(tm3.modules.__dict__["2"], TracedModule) |
|
|
|
assert isinstance(tm3.modules.__dict__["2"].a, M.Elemwise) |
|
|
|
assert isinstance(tm3.modules.__dict__["3"], M.Elemwise) |
|
|
|
|
|
|
|
m4 = MyModule4() |
|
|
|
tm4 = trace_module(m4, a, b) |
|
|
|