Browse Source

feat(traced_module): add visit method

GitOrigin-RevId: 251ecebf87
release-1.6
Megvii Engine Team 4 years ago
parent
commit
9a6a379346
1 changed files with 40 additions and 6 deletions
  1. +40
    -6
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 40
- 6
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -10,7 +10,7 @@ import collections
import copy import copy
import functools import functools
from inspect import getmembers, isclass, ismethod from inspect import getmembers, isclass, ismethod
from typing import List, Type
from typing import Dict, List, Type


from ... import module as M from ... import module as M
from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import Tensor as RawTensor
@@ -64,6 +64,14 @@ class InternalGraph:
def insert(self, expr): def insert(self, expr):
self._exprs.append(expr) self._exprs.append(expr)


@property
def inputs(self):
return self._inputs

@property
def outputs(self):
return self._outputs

def add_input(self, i): def add_input(self, i):
self._inputs.append(i) self._inputs.append(i)


@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin):
return wrapped return wrapped




class _expr_list:
def __init__(self, module: "TracedModule"):
self.module = module

def __iter__(self):
graph = self.module.m_node.graph
for expr in graph._exprs:
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
yield expr
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(self.module)
if isinstance(obj, TracedModule):
yield from obj.exprs
yield expr


class TracedModule(Module): class TracedModule(Module):
""" """
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. `TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called.
@@ -291,14 +315,21 @@ class TracedModule(Module):
return rst return rst


@property @property
def all_exprs(self):
def exprs(self):
""" """
Visit all ``Expr``s in the graph recursively.
Get all ``Expr`` s recursively.


:return: List[Expr]
:return: Iterator[Expr]
""" """
return _expr_list(self)


in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]
def flatten(self):
"""
Get a new module, which eliminates ``GetAttr`` and has no hierarchy.

:return: :class:`TracedModule`
"""
new_module = copy.deepcopy(self)


def _flatten_submodule(module, call=None): def _flatten_submodule(module, call=None):
if not isinstance(module, TracedModule): if not isinstance(module, TracedModule):
@@ -328,6 +359,7 @@ class TracedModule(Module):
elif isinstance(expr, CallMethod): elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0] obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode): if isinstance(obj_node, ModuleNode):
assert isinstance(expr.inputs[0].expr, GetAttr)
(obj,) = expr.inputs[0].expr.interpret(module) (obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_submodule(obj, expr)) exprs.extend(_flatten_submodule(obj, expr))
else: else:
@@ -337,7 +369,9 @@ class TracedModule(Module):


return exprs return exprs


return in_nodes + _flatten_submodule(self)
new_module.m_node.graph._exprs = _flatten_submodule(new_module)

return new_module


def __getstate__(self): def __getstate__(self):
d = self.__dict__ d = self.__dict__


Loading…
Cancel
Save