|
|
@@ -42,6 +42,12 @@ def _leaf_type(node): |
|
|
|
return type(node) |
|
|
|
|
|
|
|
|
|
|
|
def _is_const_leaf(node): |
|
|
|
if isinstance(node, (RawTensor, NodeMixin, Module)): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
class InternalGraph: |
|
|
|
""" |
|
|
|
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. |
|
|
@@ -72,6 +78,10 @@ class InternalGraph: |
|
|
|
def outputs(self): |
|
|
|
return self._outputs |
|
|
|
|
|
|
|
@property |
|
|
|
def exprs(self): |
|
|
|
return _expr_list(self) |
|
|
|
|
|
|
|
def add_input(self, i): |
|
|
|
self._inputs.append(i) |
|
|
|
|
|
|
@@ -111,7 +121,9 @@ def _wrapped_function(orig_func): |
|
|
|
def wrapped_fn(*args, **kwargs): |
|
|
|
if is_tracing_module(): |
|
|
|
unset_module_tracing() |
|
|
|
inputs, tree_def = tree_flatten((args, kwargs), leaf_type=_leaf_type) |
|
|
|
inputs, tree_def = tree_flatten( |
|
|
|
(args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
) |
|
|
|
for i in inputs: |
|
|
|
if not NodeMixin.get(i, None): |
|
|
|
if isinstance(i, (RawTensor, NodeMixin)): |
|
|
@@ -140,21 +152,18 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
_mod = None # type: Module |
|
|
|
_body = None # type: InternalGraph |
|
|
|
_is_builtin = None # type: bool |
|
|
|
_arg_def = None # type: TreeDef |
|
|
|
__builder_attributes__ = [ |
|
|
|
"_mod", |
|
|
|
"_body", |
|
|
|
"_NodeMixin__node", |
|
|
|
"_is_builtin", |
|
|
|
"_is_traced", |
|
|
|
"_arg_def" "build", |
|
|
|
"build", |
|
|
|
] |
|
|
|
|
|
|
|
def __init__(self, mod): |
|
|
|
def __init__(self, mod, is_top_module=False): |
|
|
|
super(TracedModuleBuilder, self).__init__() |
|
|
|
self._mod = mod |
|
|
|
self._body = InternalGraph() |
|
|
|
self._is_traced = False |
|
|
|
self._body = None |
|
|
|
self._is_builtin = module_tracer.is_builtin(mod) |
|
|
|
|
|
|
|
def build(self): |
|
|
@@ -164,9 +173,6 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
return self._mod |
|
|
|
else: |
|
|
|
node = NodeMixin.get(self) |
|
|
|
node.graph = self._body |
|
|
|
node.attr_type_map = {} |
|
|
|
node.arg_def = self._arg_def |
|
|
|
traced_module = TracedModule(node) |
|
|
|
for k, v in self.__dict__.items(): |
|
|
|
if k not in TracedModuleBuilder.__builder_attributes__: |
|
|
@@ -178,21 +184,15 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
|
assert isinstance(self._mod, Module) |
|
|
|
for arg in args: |
|
|
|
assert isinstance(arg, RawTensor) |
|
|
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
assert isinstance(v, RawTensor) |
|
|
|
# prepare args and kwargs for inner graph |
|
|
|
def mark_constant(x): |
|
|
|
node = NodeMixin.get(x, None) |
|
|
|
if node is None: # capture as constant |
|
|
|
NodeMixin.wrap(x, lambda: Constant.make(x)) |
|
|
|
|
|
|
|
inputs, tree_def = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) |
|
|
|
if self._arg_def is None: |
|
|
|
self._arg_def = tree_def |
|
|
|
assert self._arg_def == tree_def |
|
|
|
inputs, tree_def = tree_flatten( |
|
|
|
((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
) |
|
|
|
for i in inputs: |
|
|
|
mark_constant(i) |
|
|
|
callnode = CallMethod.make(NodeMixin.get(self)) |
|
|
@@ -201,13 +201,14 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
|
|
|
|
callnode.arg_def = tree_def |
|
|
|
|
|
|
|
if self._is_builtin or self._is_traced: |
|
|
|
if self._is_builtin: |
|
|
|
unset_module_tracing() |
|
|
|
outputs = self._mod(*args, **kwargs) |
|
|
|
set_module_tracing() |
|
|
|
if self._is_builtin: |
|
|
|
self._body = None |
|
|
|
else: |
|
|
|
self._body = InternalGraph() |
|
|
|
active_module_tracer().push_scope(self._body) |
|
|
|
# rebind self to new input node |
|
|
|
orig_self = NodeMixin.get(self) |
|
|
@@ -238,11 +239,12 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) |
|
|
|
|
|
|
|
NodeMixin.wrap_safe(self, orig_self) |
|
|
|
self._is_traced = True |
|
|
|
active_module_tracer().pop_scope() |
|
|
|
|
|
|
|
# rebind output to outer graph |
|
|
|
callnode.add_outputs(outputs) |
|
|
|
self_node = NodeMixin.get(self) |
|
|
|
self_node.argdef_graph_map[callnode.arg_def] = self._body |
|
|
|
return outputs |
|
|
|
|
|
|
|
def __getattr__(self, name): |
|
|
@@ -280,24 +282,23 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
|
|
|
|
|
|
|
|
class _expr_list: |
|
|
|
def __init__(self, module: "TracedModule"): |
|
|
|
self.module = module |
|
|
|
def __init__(self, graph: InternalGraph): |
|
|
|
self.graph = graph |
|
|
|
|
|
|
|
def __iter__(self): |
|
|
|
graph = self.module.m_node.graph |
|
|
|
for expr in graph._exprs: |
|
|
|
for expr in self.graph._exprs: |
|
|
|
if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): |
|
|
|
yield expr |
|
|
|
assert isinstance(expr.inputs[0].expr, GetAttr) |
|
|
|
(obj,) = expr.inputs[0].expr.interpret(self.module) |
|
|
|
if isinstance(obj, TracedModule): |
|
|
|
yield from obj.exprs |
|
|
|
yield expr |
|
|
|
if expr.graph is not None: |
|
|
|
yield from expr.graph.exprs |
|
|
|
else: |
|
|
|
yield expr |
|
|
|
|
|
|
|
|
|
|
|
class TracedModule(Module): |
|
|
|
""" |
|
|
|
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. |
|
|
|
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be |
|
|
|
interpreted by CallMethod Expr. |
|
|
|
""" |
|
|
|
|
|
|
|
m_node = None # type: ModuleNode |
|
|
@@ -307,21 +308,24 @@ class TracedModule(Module): |
|
|
|
self.m_node = node |
|
|
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
inputs, treedef = tree_flatten(((self, *args), kwargs), leaf_type=_leaf_type) |
|
|
|
assert treedef == self.m_node.arg_def |
|
|
|
rst = self.m_node.graph.interpret(*inputs) |
|
|
|
if len(rst) == 1: |
|
|
|
rst = rst[0] |
|
|
|
return rst |
|
|
|
inputs, treedef = tree_flatten( |
|
|
|
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf |
|
|
|
) |
|
|
|
assert treedef in self.m_node.argdef_graph_map |
|
|
|
inputs = [i for i in inputs if isinstance(i, (Module, RawTensor))] |
|
|
|
outputs = self.m_node.argdef_graph_map[treedef].interpret(*inputs) |
|
|
|
if len(outputs) == 1: |
|
|
|
return outputs[0] |
|
|
|
return outputs |
|
|
|
|
|
|
|
@property |
|
|
|
def exprs(self): |
|
|
|
""" |
|
|
|
Get all ``Expr`` s recursively. |
|
|
|
def graph(self): |
|
|
|
assert len(self.m_node.argdef_graph_map) == 1 |
|
|
|
return list(self.m_node.argdef_graph_map.values())[0] |
|
|
|
|
|
|
|
:return: Iterator[Expr] |
|
|
|
""" |
|
|
|
return _expr_list(self) |
|
|
|
@property |
|
|
|
def exprs(self): |
|
|
|
return self.graph.exprs |
|
|
|
|
|
|
|
def flatten(self): |
|
|
|
""" |
|
|
@@ -331,24 +335,26 @@ class TracedModule(Module): |
|
|
|
""" |
|
|
|
new_module = copy.deepcopy(self) |
|
|
|
|
|
|
|
def _flatten_submodule(module, call=None): |
|
|
|
if not isinstance(module, TracedModule): |
|
|
|
call.inputs[0] = module |
|
|
|
return (call,) |
|
|
|
|
|
|
|
def _flatten_subgraph(graph, module, call=None): |
|
|
|
if graph is None: |
|
|
|
assert not isinstance(module, TracedModule) |
|
|
|
const = Constant(module) |
|
|
|
modulenode = const.outputs[0] |
|
|
|
modulenode.module_type = type(module) |
|
|
|
call.inputs[0] = modulenode |
|
|
|
return [const, call] |
|
|
|
exprs = [] |
|
|
|
|
|
|
|
graph = module.m_node.graph |
|
|
|
for expr in graph._exprs: |
|
|
|
|
|
|
|
# replace inputs for submodule's expr |
|
|
|
for idx, inp in enumerate(expr.inputs): |
|
|
|
if call and inp in graph._inputs: |
|
|
|
expr.inputs[idx] = call.inputs[idx] |
|
|
|
inp_idx = graph._inputs.index(inp) |
|
|
|
expr.inputs[idx] = call.inputs[inp_idx] |
|
|
|
# replace outputs for submodule's expr |
|
|
|
for idx, outp in enumerate(expr.outputs): |
|
|
|
if call and outp in graph._outputs: |
|
|
|
expr.outputs[idx] = call.outputs[idx] |
|
|
|
oup_idx = graph._outputs.index(outp) |
|
|
|
expr.outputs[idx] = call.outputs[oup_idx] |
|
|
|
|
|
|
|
if isinstance(expr, GetAttr): |
|
|
|
# replace GetAttr with Constant |
|
|
@@ -356,12 +362,13 @@ class TracedModule(Module): |
|
|
|
const = Constant(getattr(module, expr.name)) |
|
|
|
const.outputs = expr.outputs |
|
|
|
exprs.append(const) |
|
|
|
|
|
|
|
elif isinstance(expr, CallMethod): |
|
|
|
obj_node = expr.inputs[0] |
|
|
|
if isinstance(obj_node, ModuleNode): |
|
|
|
assert isinstance(expr.inputs[0].expr, GetAttr) |
|
|
|
(obj,) = expr.inputs[0].expr.interpret(module) |
|
|
|
exprs.extend(_flatten_submodule(obj, expr)) |
|
|
|
exprs.extend(_flatten_subgraph(expr.graph, obj, expr)) |
|
|
|
else: |
|
|
|
exprs.append(expr) |
|
|
|
else: |
|
|
@@ -369,7 +376,7 @@ class TracedModule(Module): |
|
|
|
|
|
|
|
return exprs |
|
|
|
|
|
|
|
new_module.m_node.graph._exprs = _flatten_submodule(new_module) |
|
|
|
new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) |
|
|
|
|
|
|
|
return new_module |
|
|
|
|
|
|
@@ -421,7 +428,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: |
|
|
|
global_scope = InternalGraph() |
|
|
|
active_module_tracer().push_scope(global_scope) |
|
|
|
|
|
|
|
builder = TracedModuleBuilder(mod) |
|
|
|
builder = TracedModuleBuilder(mod, True) |
|
|
|
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) |
|
|
|
inputs, _ = tree_flatten((args, kwargs)) |
|
|
|
for _, i in enumerate(inputs): |
|
|
|