Browse Source

fix(mge/traced_module): fix node naming in the flattened graph

GitOrigin-RevId: aa7c516725
release-1.7
Megvii Engine Team 3 years ago
parent
commit
edfd38befd
1 changed files with 19 additions and 11 deletions
  1. +19
    -11
      imperative/python/megengine/traced_module/traced_module.py

+ 19
- 11
imperative/python/megengine/traced_module/traced_module.py View File

@@ -1122,16 +1122,25 @@ class InternalGraph:
self.__dict__.update(state)

if old_version:
self.inputs[0]._qualname = self._qualname
for e in self.exprs(False):
if isinstance(e, GetAttr):
e.outputs[0]._qualname = "{}.{}".format(
e.inputs[0]._qualname, e.name
)

for n in self.nodes(False):
qualname = self._qualname
if isinstance(n.expr, CallMethod) and isinstance(
n.expr.inputs[0], ModuleNode
):
n._qualname = n.expr.inputs[0]._qualname + ".[out]"
continue
if n._qualname:
qualname = "{}.{}".format(qualname, n._qualname)
n._qualname = qualname
if (
not isinstance(n.expr, GetAttr)
and isinstance(n, TensorNode)
and n._qualname
):
n._qualname = "{}.{}".format(self._qualname, n._qualname)
self._namespace = NameSpace(self._name, self._qualname)
self._re_associate_name()

@@ -2080,8 +2089,10 @@ class TracedModule(Module):

node2obj[graph._inputs[0]] = module
prefix_name = call.inputs[0]._name if call else ""
exprs = []
flattened_exprs = []

for expr in graph._exprs:
exprs = [expr]

if call is not None:
_replace_inputs_and_outputs(expr, repl_dict)
@@ -2102,10 +2113,7 @@ class TracedModule(Module):
else None
)
if expr_graph is not None:
exprs.extend(
_flatten_subgraph(graph, expr_graph, expr, obj)
)
continue
exprs = _flatten_subgraph(graph, expr_graph, expr, obj)

if parent_graph is not None:
for node in expr.outputs:
@@ -2116,13 +2124,13 @@ class TracedModule(Module):
name, node
)

exprs.append(expr)
flattened_exprs.extend(exprs)

if call is not None:
for i in call.inputs:
i.users.remove(call)

return exprs
return flattened_exprs

new_module.graph._exprs = _flatten_subgraph(
None, new_module.graph, None, new_module


Loading…
Cancel
Save