|
@@ -10,7 +10,7 @@ import collections |
|
|
import copy |
|
|
import copy |
|
|
import functools |
|
|
import functools |
|
|
from inspect import getmembers, isclass, ismethod |
|
|
from inspect import getmembers, isclass, ismethod |
|
|
from typing import List, Type |
|
|
|
|
|
|
|
|
from typing import Dict, List, Type |
|
|
|
|
|
|
|
|
from ... import module as M |
|
|
from ... import module as M |
|
|
from ...core._imperative_rt.core2 import Tensor as RawTensor |
|
|
from ...core._imperative_rt.core2 import Tensor as RawTensor |
|
@@ -64,6 +64,14 @@ class InternalGraph: |
|
|
def insert(self, expr): |
|
|
def insert(self, expr): |
|
|
self._exprs.append(expr) |
|
|
self._exprs.append(expr) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def inputs(self): |
|
|
|
|
|
return self._inputs |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def outputs(self): |
|
|
|
|
|
return self._outputs |
|
|
|
|
|
|
|
|
def add_input(self, i): |
|
|
def add_input(self, i): |
|
|
self._inputs.append(i) |
|
|
self._inputs.append(i) |
|
|
|
|
|
|
|
@@ -271,6 +279,22 @@ class TracedModuleBuilder(NodeMixin): |
|
|
return wrapped |
|
|
return wrapped |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _expr_list: |
|
|
|
|
|
def __init__(self, module: "TracedModule"): |
|
|
|
|
|
self.module = module |
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
|
|
graph = self.module.m_node.graph |
|
|
|
|
|
for expr in graph._exprs: |
|
|
|
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
|
|
|
yield expr |
|
|
|
|
|
assert isinstance(expr.inputs[0].expr, GetAttr) |
|
|
|
|
|
(obj,) = expr.inputs[0].expr.interpret(self.module) |
|
|
|
|
|
if isinstance(obj, TracedModule): |
|
|
|
|
|
yield from obj.exprs |
|
|
|
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TracedModule(Module): |
|
|
class TracedModule(Module): |
|
|
""" |
|
|
""" |
|
|
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. |
|
|
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. |
|
@@ -291,14 +315,21 @@ class TracedModule(Module): |
|
|
return rst |
|
|
return rst |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def all_exprs(self): |
|
|
|
|
|
|
|
|
def exprs(self): |
|
|
""" |
|
|
""" |
|
|
Visit all ``Expr``s in the graph recursively. |
|
|
|
|
|
|
|
|
Get all ``Expr`` s recursively. |
|
|
|
|
|
|
|
|
:return: List[Expr] |
|
|
|
|
|
|
|
|
:return: Iterator[Expr] |
|
|
""" |
|
|
""" |
|
|
|
|
|
return _expr_list(self) |
|
|
|
|
|
|
|
|
in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] |
|
|
|
|
|
|
|
|
def flatten(self): |
|
|
|
|
|
""" |
|
|
|
|
|
Get a new module, which eliminates ``GetAttr`` and has no hierarchy. |
|
|
|
|
|
|
|
|
|
|
|
:return: :class:`TracedModule` |
|
|
|
|
|
""" |
|
|
|
|
|
new_module = copy.deepcopy(self) |
|
|
|
|
|
|
|
|
def _flatten_submodule(module, call=None): |
|
|
def _flatten_submodule(module, call=None): |
|
|
if not isinstance(module, TracedModule): |
|
|
if not isinstance(module, TracedModule): |
|
@@ -328,6 +359,7 @@ class TracedModule(Module): |
|
|
elif isinstance(expr, CallMethod): |
|
|
elif isinstance(expr, CallMethod): |
|
|
obj_node = expr.inputs[0] |
|
|
obj_node = expr.inputs[0] |
|
|
if isinstance(obj_node, ModuleNode): |
|
|
if isinstance(obj_node, ModuleNode): |
|
|
|
|
|
assert isinstance(expr.inputs[0].expr, GetAttr) |
|
|
(obj,) = expr.inputs[0].expr.interpret(module) |
|
|
(obj,) = expr.inputs[0].expr.interpret(module) |
|
|
exprs.extend(_flatten_submodule(obj, expr)) |
|
|
exprs.extend(_flatten_submodule(obj, expr)) |
|
|
else: |
|
|
else: |
|
@@ -337,7 +369,9 @@ class TracedModule(Module): |
|
|
|
|
|
|
|
|
return exprs |
|
|
return exprs |
|
|
|
|
|
|
|
|
return in_nodes + _flatten_submodule(self) |
|
|
|
|
|
|
|
|
new_module.m_node.graph._exprs = _flatten_submodule(new_module) |
|
|
|
|
|
|
|
|
|
|
|
return new_module |
|
|
|
|
|
|
|
|
def __getstate__(self): |
|
|
def __getstate__(self): |
|
|
d = self.__dict__ |
|
|
d = self.__dict__ |
|
|