Browse Source

feat(traced_module): add visit method

GitOrigin-RevId: 251ecebf87
release-1.6
Megvii Engine Team 4 years ago
parent
commit
9a6a379346
1 changed files with 40 additions and 6 deletions
  1. +40
    -6
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 40
- 6
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -10,7 +10,7 @@ import collections
import copy
import functools
from inspect import getmembers, isclass, ismethod
from typing import List, Type
from typing import Dict, List, Type

from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -64,6 +64,14 @@ class InternalGraph:
def insert(self, expr):
self._exprs.append(expr)

@property
def inputs(self):
return self._inputs

@property
def outputs(self):
return self._outputs

def add_input(self, i):
self._inputs.append(i)

@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
return wrapped


class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module

def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
yield expr


class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
@@ -291,14 +315,21 @@ class TracedModule(Module):
return rst

@property
def all_exprs(self):
def exprs(self):
"""
Visit all ``Expr``s in the graph recursively.
Get all ``Expr`` s recursively.

:return: List[Expr]
:return: Iterator[Expr]
"""
return _expr_list(self)

in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]
def flatten(self):
"""
Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

:return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)

def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule):
@@ -328,6 +359,7 @@ class TracedModule(Module):
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr))
else:
@@ -337,7 +369,9 @@ class TracedModule(Module):

return exprs

return in_nodes + _flatten_submodule(self)
new_module.m_node.graph._exprs = _flatten_submodule(new_module)

return new_module

def __getstate__(self):
d = self.__dict__


Loading…
Cancel
Save