diff --git a/imperative/python/megengine/traced_module/node.py b/imperative/python/megengine/traced_module/node.py index d3c9fcca..e6786406 100644 --- a/imperative/python/megengine/traced_module/node.py +++ b/imperative/python/megengine/traced_module/node.py @@ -74,7 +74,7 @@ class Node: r"""Set a new name to this Node.""" graph = self.top_graph assert graph is not None, "The parent graph of this Node cannot be None." - assert new_name not in graph._namespace.used_names, ( + assert graph._namespace.used_names.get(new_name, None) is None, ( "The name(%s) is already in use. Please try a different one again." % (new_name) ) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 7f6bcb26..b4a7f9b1 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -544,11 +544,11 @@ class InternalGraph: graph = self.top_graph assert graph is not None or mod._is_top, "The parent graph cannot be None." if graph is not None: - assert new_name not in self._namespace.used_names, ( + assert graph._namespace.used_names.get(new_name, None) is None, ( "The name(%s) is already in use. Please try a different one again." % (new_name) ) - new_name = self._namespace.create_unique_name(new_name, self) + new_name = graph._namespace.create_unique_name(new_name, self) self._name = new_name @property @@ -1032,21 +1032,33 @@ class InternalGraph: n.inputs[idx] = repl_node def _merge_getattr_expr(self): - getattr_nodes_map = dict() - for expr in self._exprs: - if not isinstance(expr, GetAttr): - continue - attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname) - assert attr_name, '"{}" is not a prefix of "{}"'.format( - self.qualname, expr.outputs[0].qualname - ) - if attr_name in getattr_nodes_map: - base_node = getattr_nodes_map[attr_name] + getattr_nodes_map = dict() # Dcit[(Node, str), Node] + node_to_attrname = dict() # Dict[Node, (Node, Str)] + for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs): + base_node, attr_name = expr.inputs[0], expr.name + if expr.inputs[0] in node_to_attrname: + base_node, base_name = node_to_attrname[expr.inputs[0]] + attr_name = "{}.{}".format(base_name, expr.name) + + if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name: + expected_qualname = base_node.qualname + "." + attr_name + logger.warning( + "{}.qualname expects {}, got {} actually. You can re-trace this " + "TracedModel to make the name correct.".format( + expr.outputs[0], expected_qualname, expr.outputs[0].qualname + ) + ) + expr.outputs[0]._qualname = expected_qualname + + key = (base_node, attr_name) + node_to_attrname[expr.outputs[0]] = key + if key in getattr_nodes_map: + existed_node = getattr_nodes_map[key] repl_node = expr.outputs[0] for expr in repl_node.users: - base_node.users.append(expr) + existed_node.users.append(expr) idx = expr.inputs.index(repl_node) - expr.inputs[idx] = base_node + expr.inputs[idx] = existed_node repl_node.users = [] else: if attr_name != expr.name: @@ -1054,7 +1066,7 @@ class InternalGraph: expr.inputs[0].users.remove(expr) self.inputs[0].users.append(expr) expr.inputs[0] = self.inputs[0] - getattr_nodes_map[attr_name] = expr.outputs[0] + getattr_nodes_map[key] = expr.outputs[0] def compile(self): r"""Delete unused expr."""