diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index df0b1616..5a465fd8 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -36,7 +36,7 @@ class Expr: r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, ``GetAttr``, ``Input``, ``Constant``) on ``Node``. """ - + inputs = None # type: List[Node] r"""The input Nodes of this Expr.""" outputs = None # type: List[Node] @@ -229,6 +229,7 @@ class GetAttr(Expr): name = None r"""name: the qualified name of the attribute to be retrieved.""" + def __init__(self, module, name, type=None, orig_name=None): super().__init__() assert isinstance(module, ModuleNode) @@ -276,6 +277,7 @@ class CallMethod(Expr): method: the method name. Default: "__call__" """ + def __init__(self, node, method="__call__"): super().__init__() if isinstance(node, type): @@ -351,6 +353,7 @@ class Apply(Expr): opdef: the applied :class:`OpDef`. """ opdef = None + def __init__(self, opdef): super().__init__() assert isinstance(opdef, OpDef) @@ -422,6 +425,7 @@ class CallFunction(Expr): Args: func: a built-in function. """ + def __init__(self, func): super().__init__() assert isinstance(func, Callable) diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index 056043c7..cc084a49 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -115,7 +115,7 @@ class Node: class ModuleNode(Node): r"""``ModuleNode`` represents the Module objects.""" - + module_type = Module # type: Type[Module] r"""The type of the Module correspending to the ModuleNode.""" _owner = None # type: weakref.ReferenceType diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 5b331285..edf48348 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -5,12 +5,21 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import pickle + import numpy as np import megengine.functional as F import megengine.module as M +from megengine.module.identity import Identity from megengine.traced_module import trace_module -from megengine.traced_module.expr import CallFunction, GetAttr +from megengine.traced_module.expr import CallFunction, Expr, GetAttr +from megengine.traced_module.node import Node + + +class IdentityMod(M.Module): + def forward(self, x): + return x class MyBlock(M.Module): @@ -18,11 +27,13 @@ class MyBlock(M.Module): super(MyBlock, self).__init__() self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False) self.bn1 = M.BatchNorm2d(channels) + self.nothing = IdentityMod() def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = F.relu(x) + 1 + x = self.nothing(x) return x @@ -31,10 +42,24 @@ class MyModule(M.Module): super(MyModule, self).__init__() self.block0 = MyBlock() self.block1 = MyBlock() + self.nothing = IdentityMod() def forward(self, x): x = self.block0(x) x = self.block1(x) + x = self.nothing(x) + return x + + +class NewModule(M.Module): + def __init__(self, traced_module): + super(NewModule, self).__init__() + self.module = traced_module + + def forward(self, x): + x = x - 1 + x = self.module(x) + x = x + 1 return x @@ -82,6 +107,12 @@ def test_delete(): graph.compile() np.testing.assert_allclose(expect - 1, F.relu(traced_module(x) - 1), atol=1e-6) + # clear graph + graph.replace_node({graph.outputs[0]: graph.inputs[1]}) + graph.compile() + np.testing.assert_equal(len(list(graph._exprs)), 0) + np.testing.assert_equal(traced_module(x).numpy(), x.numpy()) + def test_flatten(): traced_module, x, expect = _init_module() @@ -89,6 +120,74 @@ def test_flatten(): traced_module.graph.compile() assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs) assert len(traced_module.graph._exprs) == 12 + np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) + + +def test_id_and_name(): + def _check_id(traced_module): + _total_ids = traced_module.graph._total_ids + node_ids = [n._id for n in traced_module.graph.nodes().as_list()] + assert len(set(node_ids)) == len(node_ids) + assert max(node_ids) + 1 == len(node_ids) + + expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] + assert len(set(expr_ids)) == len(expr_ids) + assert max(expr_ids) + 1 == _total_ids[1] + + def _check_name(flatened_module): + node_names = [n._name for n in flatened_module.graph.nodes().as_list()] + assert len(set(node_names)) == len(node_names) + + traced_module, x, expect = _init_module() + _check_id(traced_module) + + flattened_module = traced_module.flatten() + _check_id(flattened_module) + _check_name(flattened_module) + + # pickle check + obj = pickle.dumps(traced_module) + traced_module = pickle.loads(obj) + Node._set_next_id(159) + Expr._set_next_id(1024) + + graph = traced_module.graph + for expr in graph.get_function_by_type(F.relu).as_list(): + relu_out = expr.outputs[0] + cur_graph = expr.top_graph + with cur_graph.insert_exprs(): + neg_out = F.neg(relu_out) + cur_graph.replace_node({relu_out: neg_out}) + cur_graph.compile() + _check_id(traced_module) + + flattened_module = traced_module.flatten() + _check_id(flattened_module) + _check_name(flattened_module) + + # check trace TracedModule + obj = pickle.dumps(traced_module) + traced_module = pickle.loads(obj) + module = NewModule(traced_module) + traced_module = trace_module(module, x) + _check_id(traced_module) + + flattened_module = traced_module.flatten() + _check_id(flattened_module) + _check_name(flattened_module) + + +def test_set_name(): + traced_module, x, expect = _init_module() + graph = traced_module.graph + output_node = graph.outputs[0] + + def rename(name): + output_node.name = name + + np.testing.assert_raises(AssertionError, rename, "block1_out") + rename("output") + np.testing.assert_equal(str(graph.outputs[0]), "output") def test_extra_block():