|
@@ -16,7 +16,7 @@ import megengine.module as M |
|
|
from megengine.module.identity import Identity |
|
|
from megengine.module.identity import Identity |
|
|
from megengine.traced_module import trace_module |
|
|
from megengine.traced_module import trace_module |
|
|
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input |
|
|
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input |
|
|
from megengine.traced_module.node import ModuleNode, Node |
|
|
|
|
|
|
|
|
from megengine.traced_module.node import ModuleNode, Node, TensorNode |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IdentityMod(M.Module): |
|
|
class IdentityMod(M.Module): |
|
@@ -159,21 +159,44 @@ def test_insert(): |
|
|
|
|
|
|
|
|
def test_insert_module(): |
|
|
def test_insert_module(): |
|
|
class Neg(M.Module): |
|
|
class Neg(M.Module): |
|
|
|
|
|
def __init__(self, name): |
|
|
|
|
|
super().__init__(name) |
|
|
|
|
|
self.identity = M.Identity() |
|
|
|
|
|
self.identity_list = [M.Identity(), M.Identity()] |
|
|
|
|
|
self.identity_dict = {"0": M.Identity(), "1": M.Identity()} |
|
|
|
|
|
self.param = F.zeros((1,)) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
return F.neg(x) |
|
|
|
|
|
|
|
|
x = self.identity(x) |
|
|
|
|
|
for m in self.identity_dict: |
|
|
|
|
|
x = self.identity_dict[m](x) |
|
|
|
|
|
for m in self.identity_list: |
|
|
|
|
|
x = m(x) |
|
|
|
|
|
return F.neg(x) + self.param |
|
|
|
|
|
|
|
|
traced_module, x, expect = _init_block() |
|
|
traced_module, x, expect = _init_block() |
|
|
graph = traced_module.graph |
|
|
graph = traced_module.graph |
|
|
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] |
|
|
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] |
|
|
self = graph.inputs[0] |
|
|
self = graph.inputs[0] |
|
|
setattr(traced_module, "neg", Neg()) |
|
|
|
|
|
|
|
|
setattr(traced_module, "neg", Neg(name="neg")) |
|
|
|
|
|
setattr(traced_module, "neg2", Neg(name="neg")) |
|
|
|
|
|
setattr(traced_module, "param", F.zeros((1,))) |
|
|
|
|
|
|
|
|
with graph.insert_exprs(): |
|
|
with graph.insert_exprs(): |
|
|
neg_out = self.neg(relu_out) |
|
|
neg_out = self.neg(relu_out) |
|
|
|
|
|
neg_out = self.neg2(relu_out) |
|
|
|
|
|
neg_out = neg_out + self.param |
|
|
graph.replace_node({relu_out: neg_out}) |
|
|
graph.replace_node({relu_out: neg_out}) |
|
|
graph.compile() |
|
|
graph.compile() |
|
|
|
|
|
|
|
|
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) |
|
|
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) |
|
|
assert traced_module.neg.graph is not None |
|
|
assert traced_module.neg.graph is not None |
|
|
assert len(traced_module.neg.graph._exprs) == 1 |
|
|
|
|
|
|
|
|
assert traced_module.neg2.graph is not None |
|
|
|
|
|
assert traced_module.neg2.param is not None |
|
|
|
|
|
assert len(traced_module.neg.graph._exprs) == 13 |
|
|
|
|
|
for n in traced_module.graph.nodes(): |
|
|
|
|
|
if isinstance(n, TensorNode): |
|
|
|
|
|
assert n.value is None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_add_input_and_output(): |
|
|
def test_add_input_and_output(): |
|
|