diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 836f765c..0247b861 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -167,6 +167,15 @@ class Expr: state.pop("_top_graph") return state + @classmethod + def get_total_id(cls): + return cls.__total_id + + @classmethod + def set_total_id(cls, id: int = 0): + assert isinstance(id, int) + cls.__total_id = id + # expr: None (i.e. fake expression which is used to mark input) class Input(Expr): diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index db1326dd..2ae32e70 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -42,10 +42,6 @@ class Node: self._orig_name = orig_name self.actual_node = [] # type: List[Node] - def __setstate__(self, d): - self.__dict__ = d - Node.__total_id = max(Node.__total_id, self._id) + 1 - def __repr__(self): format_spec = Node._format_spec return self.__format__(format_spec) @@ -89,6 +85,15 @@ class Node: cls._format_spec = str return old_format_spec + @classmethod + def get_total_id(cls): + return cls.__total_id + + @classmethod + def set_total_id(cls, id: int = 0): + assert isinstance(id, int) + cls.__total_id = id + class ModuleNode(Node): r"""``ModuleNode`` represents the Module objects.""" diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 83911593..56b88e87 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""): class _InsertExprs: def __init__(self, graph, expr: Optional[Expr] = None): self.graph = graph + while graph.top_graph is not None: + graph = graph.top_graph + assert graph.inputs[0].owner._is_top + self.root_graph = graph self.global_scope = InternalGraph( graph._name, graph._prefix_name, graph._module_name ) @@ -256,6 +260,9 @@ class _InsertExprs: def __enter__(self): self.use_sym_shape = set_symbolic_shape(True) + node_id, expr_id = self.root_graph._total_ids + Node.set_total_id(node_id) + Expr.set_total_id(expr_id) set_module_tracing() _set_convert_node_flag(True) assert active_module_tracer() is None @@ -334,10 +341,8 @@ class _InsertExprs: insert_index += 1 self.graph._used_names.update(self.global_scope._used_names) - graph = self.graph - while graph.top_graph is not None: - graph = graph.top_graph - graph.inputs[0].owner._update_ref() + self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id()) + self.root_graph.inputs[0].owner._update_ref() return True @@ -353,7 +358,8 @@ class InternalGraph: _exprs = None # type: List[Expr] _inputs = None # type: List[Node] _outputs = None # type: List[Node] - _top_graph = None + _top_graph = None # type: InternalGraph + _total_ids = None # type: List[int] def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): self._exprs = [] @@ -704,8 +710,12 @@ class InternalGraph: def replace_node(self, repl_dict: Dict[Node, Node]): while repl_dict: node, repl_node = repl_dict.popitem() + assert type(node) == type( + repl_node + ), "The type of {}({}) and {}({}) are not the same".format( + node, type(node).__name__, repl_node, type(repl_node).__name__ + ) # check graph inputs and outputs - # assert node not in self.inputs, "Cannot replace inputs" for i, n in enumerate(self.outputs): if n is node: self.outputs[i] = repl_node @@ -713,7 +723,10 @@ class InternalGraph: # update inputs of expr in node.users graph = repl_node.top_graph assert graph is not None - index = graph._exprs.index(repl_node.expr) + assert graph is self + index = -1 + if not isinstance(repl_node.expr, Input): + index = graph._exprs.index(repl_node.expr) dep_exprs = self.get_dep_exprs(repl_node) i = 0 while i < len(node.users): @@ -745,6 +758,13 @@ class InternalGraph: n.users.remove(expr) self._exprs.remove(expr) + def _reset_ids(self): + for total_expr_id, expr in enumerate(self.exprs()): + expr._id = total_expr_id + for total_node_id, node in enumerate(self.nodes()): + node._id = total_node_id + self._total_ids = (total_node_id + 1, total_expr_id + 1) + def interpret(self, *inputs): node2value = {} end_nodes_set = set(self._end_point) @@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin): ) for _, g in self._argdef_graph_map.items(): g.compile() + if self._is_top: + g._total_ids = (Node.get_total_id(), Expr.get_total_id()) for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: @@ -1247,6 +1269,8 @@ class _expr_iter: self.recursive = recursive def __iter__(self): + for inp_node in self.graph.inputs: + yield inp_node.expr for expr in self.graph._exprs: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr @@ -1262,10 +1286,10 @@ class _node_iter: node_ids = set() for expr in graph.exprs(recursive): for n in expr.inputs + expr.outputs: - if n._id in node_ids: + if id(n) in node_ids: continue nodes.append(n) - node_ids.add(n._id) + node_ids.add(id(n)) self.nodes = list(sorted(nodes, key=lambda x: x._id)) def __iter__(self): @@ -1546,6 +1570,7 @@ class TracedModule(Module): active_module_tracer().push_scope(new_module.graph) def _flatten_subgraph( + parent_graph: InternalGraph, graph: InternalGraph, module: Module, call=None, @@ -1590,7 +1615,10 @@ class TracedModule(Module): if inp is call_out: expr.inputs[index] = repl_dict[out] repl_dict[out].users.append(expr) - + if parent_graph is not None: + for index, parent_out in enumerate(parent_graph._outputs): + if parent_out is call_out: + parent_graph._outputs[index] = repl_dict[out] continue repl_dict[out] = call.outputs[ind] @@ -1622,6 +1650,7 @@ class TracedModule(Module): ) exprs.extend( _flatten_subgraph( + graph, expr_graph, obj, expr, @@ -1643,19 +1672,10 @@ class TracedModule(Module): i.users.remove(call) return exprs - new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) + new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module) new_module.graph.compile() set_active_module_tracer(None) - for _id, expr in enumerate(new_module.graph._exprs): - expr._id = _id - total_node_id = 0 - for i in new_module.graph._inputs: - i._id = total_node_id - total_node_id += 1 - for expr in new_module.graph._exprs: - for o in expr.outputs: - o._id = total_node_id - total_node_id += 1 + new_module.graph._reset_ids() return new_module def __getstate__(self): @@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: set_active_module_tracer( module_tracer(_wrapped_function, _init_id2name(mod, "self")) ) + for cls in [Expr, Node]: + cls.set_total_id(0) with active_module_tracer().patcher: global_scope = InternalGraph(name="") active_module_tracer().push_scope(global_scope) @@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: ) builder(*args, **kwargs) active_module_tracer().pop_scope() - return builder.build() + traced_mod = builder.build() + traced_mod.graph._reset_ids() + return traced_mod finally: set_symbolic_shape(use_sym_shape) set_active_module_tracer(None)