Browse Source

fix(mge/traced_module): fix merge GetAttr failure when qualname is incorrect

GitOrigin-RevId: 46241e5b4f
release-1.7
Megvii Engine Team 3 years ago
parent
commit
2d20f93777
2 changed files with 28 additions and 16 deletions
  1. +1
    -1
      imperative/python/megengine/traced_module/node.py
  2. +27
    -15
      imperative/python/megengine/traced_module/traced_module.py

+ 1
- 1
imperative/python/megengine/traced_module/node.py View File

@@ -74,7 +74,7 @@ class Node:
r"""Set a new name to this Node.""" r"""Set a new name to this Node."""
graph = self.top_graph graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None." assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._namespace.used_names, (
assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )


+ 27
- 15
imperative/python/megengine/traced_module/traced_module.py View File

@@ -544,11 +544,11 @@ class InternalGraph:
graph = self.top_graph graph = self.top_graph
assert graph is not None or mod._is_top, "The parent graph cannot be None." assert graph is not None or mod._is_top, "The parent graph cannot be None."
if graph is not None: if graph is not None:
assert new_name not in self._namespace.used_names, (
assert graph._namespace.used_names.get(new_name, None) is None, (
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
new_name = self._namespace.create_unique_name(new_name, self)
new_name = graph._namespace.create_unique_name(new_name, self)
self._name = new_name self._name = new_name


@property @property
@@ -1032,21 +1032,33 @@ class InternalGraph:
n.inputs[idx] = repl_node n.inputs[idx] = repl_node


def _merge_getattr_expr(self): def _merge_getattr_expr(self):
getattr_nodes_map = dict()
for expr in self._exprs:
if not isinstance(expr, GetAttr):
continue
attr_name = get_suffix_name(self.qualname, expr.outputs[0].qualname)
assert attr_name, '"{}" is not a prefix of "{}"'.format(
self.qualname, expr.outputs[0].qualname
)
if attr_name in getattr_nodes_map:
base_node = getattr_nodes_map[attr_name]
getattr_nodes_map = dict() # Dcit[(Node, str), Node]
node_to_attrname = dict() # Dict[Node, (Node, Str)]
for expr in filter(lambda x: isinstance(x, GetAttr), self._exprs):
base_node, attr_name = expr.inputs[0], expr.name
if expr.inputs[0] in node_to_attrname:
base_node, base_name = node_to_attrname[expr.inputs[0]]
attr_name = "{}.{}".format(base_name, expr.name)

if get_suffix_name(self.qualname, expr.outputs[0].qualname) != attr_name:
expected_qualname = base_node.qualname + "." + attr_name
logger.warning(
"{}.qualname expects {}, got {} actually. You can re-trace this "
"TracedModel to make the name correct.".format(
expr.outputs[0], expected_qualname, expr.outputs[0].qualname
)
)
expr.outputs[0]._qualname = expected_qualname

key = (base_node, attr_name)
node_to_attrname[expr.outputs[0]] = key
if key in getattr_nodes_map:
existed_node = getattr_nodes_map[key]
repl_node = expr.outputs[0] repl_node = expr.outputs[0]
for expr in repl_node.users: for expr in repl_node.users:
base_node.users.append(expr)
existed_node.users.append(expr)
idx = expr.inputs.index(repl_node) idx = expr.inputs.index(repl_node)
expr.inputs[idx] = base_node
expr.inputs[idx] = existed_node
repl_node.users = [] repl_node.users = []
else: else:
if attr_name != expr.name: if attr_name != expr.name:
@@ -1054,7 +1066,7 @@ class InternalGraph:
expr.inputs[0].users.remove(expr) expr.inputs[0].users.remove(expr)
self.inputs[0].users.append(expr) self.inputs[0].users.append(expr)
expr.inputs[0] = self.inputs[0] expr.inputs[0] = self.inputs[0]
getattr_nodes_map[attr_name] = expr.outputs[0]
getattr_nodes_map[key] = expr.outputs[0]


def compile(self): def compile(self):
r"""Delete unused expr.""" r"""Delete unused expr."""


Loading…
Cancel
Save