Browse Source

fix(mge/traced_module): let graph record total_id

GitOrigin-RevId: f99178f3ac
release-1.6
Megvii Engine Team 3 years ago
parent
commit
c7a8d945c7
3 changed files with 64 additions and 26 deletions
  1. +9
    -0
      imperative/python/megengine/traced_module/expr.py
  2. +9
    -4
      imperative/python/megengine/traced_module/node.py
  3. +46
    -22
      imperative/python/megengine/traced_module/traced_module.py

+ 9
- 0
imperative/python/megengine/traced_module/expr.py View File

@@ -167,6 +167,15 @@ class Expr:
state.pop("_top_graph")
return state

@classmethod
def get_total_id(cls):
return cls.__total_id

@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id


# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):


+ 9
- 4
imperative/python/megengine/traced_module/node.py View File

@@ -42,10 +42,6 @@ class Node:
self._orig_name = orig_name
self.actual_node = [] # type: List[Node]

def __setstate__(self, d):
self.__dict__ = d
Node.__total_id = max(Node.__total_id, self._id) + 1

def __repr__(self):
format_spec = Node._format_spec
return self.__format__(format_spec)
@@ -89,6 +85,15 @@ class Node:
cls._format_spec = str
return old_format_spec

@classmethod
def get_total_id(cls):
return cls.__total_id

@classmethod
def set_total_id(cls, id: int = 0):
assert isinstance(id, int)
cls.__total_id = id


class ModuleNode(Node):
r"""``ModuleNode`` represents the Module objects."""


+ 46
- 22
imperative/python/megengine/traced_module/traced_module.py View File

@@ -247,6 +247,10 @@ def _init_id2name(mod: Module, prefix: str = ""):
class _InsertExprs:
def __init__(self, graph, expr: Optional[Expr] = None):
self.graph = graph
while graph.top_graph is not None:
graph = graph.top_graph
assert graph.inputs[0].owner._is_top
self.root_graph = graph
self.global_scope = InternalGraph(
graph._name, graph._prefix_name, graph._module_name
)
@@ -256,6 +260,9 @@ class _InsertExprs:

def __enter__(self):
self.use_sym_shape = set_symbolic_shape(True)
node_id, expr_id = self.root_graph._total_ids
Node.set_total_id(node_id)
Expr.set_total_id(expr_id)
set_module_tracing()
_set_convert_node_flag(True)
assert active_module_tracer() is None
@@ -334,10 +341,8 @@ class _InsertExprs:
insert_index += 1

self.graph._used_names.update(self.global_scope._used_names)
graph = self.graph
while graph.top_graph is not None:
graph = graph.top_graph
graph.inputs[0].owner._update_ref()
self.root_graph._total_ids = (Node.get_total_id(), Expr.get_total_id())
self.root_graph.inputs[0].owner._update_ref()
return True


@@ -353,7 +358,8 @@ class InternalGraph:
_exprs = None # type: List[Expr]
_inputs = None # type: List[Node]
_outputs = None # type: List[Node]
_top_graph = None
_top_graph = None # type: InternalGraph
_total_ids = None # type: List[int]

def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
self._exprs = []
@@ -704,8 +710,12 @@ class InternalGraph:
def replace_node(self, repl_dict: Dict[Node, Node]):
while repl_dict:
node, repl_node = repl_dict.popitem()
assert type(node) == type(
repl_node
), "The type of {}({}) and {}({}) are not the same".format(
node, type(node).__name__, repl_node, type(repl_node).__name__
)
# check graph inputs and outputs
# assert node not in self.inputs, "Cannot replace inputs"
for i, n in enumerate(self.outputs):
if n is node:
self.outputs[i] = repl_node
@@ -713,7 +723,10 @@ class InternalGraph:
# update inputs of expr in node.users
graph = repl_node.top_graph
assert graph is not None
index = graph._exprs.index(repl_node.expr)
assert graph is self
index = -1
if not isinstance(repl_node.expr, Input):
index = graph._exprs.index(repl_node.expr)
dep_exprs = self.get_dep_exprs(repl_node)
i = 0
while i < len(node.users):
@@ -745,6 +758,13 @@ class InternalGraph:
n.users.remove(expr)
self._exprs.remove(expr)

def _reset_ids(self):
for total_expr_id, expr in enumerate(self.exprs()):
expr._id = total_expr_id
for total_node_id, node in enumerate(self.nodes()):
node._id = total_node_id
self._total_ids = (total_node_id + 1, total_expr_id + 1)

def interpret(self, *inputs):
node2value = {}
end_nodes_set = set(self._end_point)
@@ -989,6 +1009,8 @@ class TracedModuleBuilder(NodeMixin):
)
for _, g in self._argdef_graph_map.items():
g.compile()
if self._is_top:
g._total_ids = (Node.get_total_id(), Expr.get_total_id())

for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
@@ -1247,6 +1269,8 @@ class _expr_iter:
self.recursive = recursive

def __iter__(self):
for inp_node in self.graph.inputs:
yield inp_node.expr
for expr in self.graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
@@ -1262,10 +1286,10 @@ class _node_iter:
node_ids = set()
for expr in graph.exprs(recursive):
for n in expr.inputs + expr.outputs:
if n._id in node_ids:
if id(n) in node_ids:
continue
nodes.append(n)
node_ids.add(n._id)
node_ids.add(id(n))
self.nodes = list(sorted(nodes, key=lambda x: x._id))

def __iter__(self):
@@ -1546,6 +1570,7 @@ class TracedModule(Module):
active_module_tracer().push_scope(new_module.graph)

def _flatten_subgraph(
parent_graph: InternalGraph,
graph: InternalGraph,
module: Module,
call=None,
@@ -1590,7 +1615,10 @@ class TracedModule(Module):
if inp is call_out:
expr.inputs[index] = repl_dict[out]
repl_dict[out].users.append(expr)

if parent_graph is not None:
for index, parent_out in enumerate(parent_graph._outputs):
if parent_out is call_out:
parent_graph._outputs[index] = repl_dict[out]
continue
repl_dict[out] = call.outputs[ind]

@@ -1622,6 +1650,7 @@ class TracedModule(Module):
)
exprs.extend(
_flatten_subgraph(
graph,
expr_graph,
obj,
expr,
@@ -1643,19 +1672,10 @@ class TracedModule(Module):
i.users.remove(call)
return exprs

new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
new_module.graph._exprs = _flatten_subgraph(None, new_module.graph, new_module)
new_module.graph.compile()
set_active_module_tracer(None)
for _id, expr in enumerate(new_module.graph._exprs):
expr._id = _id
total_node_id = 0
for i in new_module.graph._inputs:
i._id = total_node_id
total_node_id += 1
for expr in new_module.graph._exprs:
for o in expr.outputs:
o._id = total_node_id
total_node_id += 1
new_module.graph._reset_ids()
return new_module

def __getstate__(self):
@@ -1735,6 +1755,8 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
set_active_module_tracer(
module_tracer(_wrapped_function, _init_id2name(mod, "self"))
)
for cls in [Expr, Node]:
cls.set_total_id(0)
with active_module_tracer().patcher:
global_scope = InternalGraph(name="")
active_module_tracer().push_scope(global_scope)
@@ -1750,7 +1772,9 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
)
builder(*args, **kwargs)
active_module_tracer().pop_scope()
return builder.build()
traced_mod = builder.build()
traced_mod.graph._reset_ids()
return traced_mod
finally:
set_symbolic_shape(use_sym_shape)
set_active_module_tracer(None)


Loading…
Cancel
Save