@@ -9,6 +9,7 @@
import collections
import copy
import functools
import weakref
from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type
@@ -51,7 +52,9 @@ def _leaf_type(node):
def _is_leaf(node):
assert isinstance(node, RawTensor), type(node)
assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
type(node)
)
return isinstance(node, RawTensor)
@@ -107,6 +110,32 @@ class InternalGraph:
def add_output(self, o):
self._outputs.append(o)
def _replace_inputs_outputs(self, repl_dict):
for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs
for i in node.users:
if i not in repl_node.users:
repl_node.users.append(i)
for idx, i in enumerate(self._inputs):
if i in repl_dict:
self._inputs[idx] = repl_dict[i]
for idx, o in enumerate(self._outputs):
if o in repl_dict:
self._outputs[idx] = repl_dict[o]
self._outputs[idx].expr = node.expr
for expr in self._exprs:
for idx, i in enumerate(expr.inputs):
if i in repl_dict:
expr.inputs[idx] = repl_dict[i]
for idx, o in enumerate(expr.outputs):
if o in repl_dict:
expr.outputs[idx] = repl_dict[o]
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
nodes = (nodes,)
@@ -117,6 +146,7 @@ class InternalGraph:
expr = node.expr
if expr not in ret:
ret.append(expr)
for i in expr.inputs:
if i not in queue:
queue.append(i)
@@ -287,10 +317,7 @@ def _wrapped_function(orig_func):
call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
call_node.add_outputs(outputs)
set_module_tracing()
return outputs
return orig_func(*args, **kwargs)
@@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
nodes = None
__builder_attributes__ = [
"_mod",
"_body",
"_NodeMixin__node",
"_is_builtin",
"build",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
]
def __init__(self, mod, is_top_module=False):
@@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()
def build(self):
if self._is_builtin:
node = NodeMixin.get(self)
node.module_type = type(self._mod)
for node in self.nodes:
node.module_type = type(self._mod)
# node._owner = weakref.ref(self._mod)
return self._mod
else:
node = NodeMixin.get(self)
traced_module = TracedModule(node)
traced_module = TracedModule(
self._argdef_graph_map, self._argdef_outdef_map
)
for _, g in self._argdef_graph_map.items():
g.compile()
# for node in self.nodes:
# node._owner = weakref.ref(traced_module)
for k, v in self.__dict__.items():
if k not in TracedModuleBuilder.__builder_attributes__:
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(traced_module, k, v)
traced_module.m_node.attr_type_map[k] = type(v)
return traced_module
def _record_wrapped_nodes(self, node):
self.nodes.add(node)
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph
@@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin):
if self._is_builtin:
self._body = None
else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
# rebind self to new input node
orig_self = NodeMixin.get(self)
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
def wrap(x):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return x
args = [self]
@@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin):
# rebind output to outer graph
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
self_node.argdef_outdef_map[callnode.arg_def] = out_def
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst
def __getattr__(self, name):
@@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin):
else:
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
wrapped,
lambda: GetAttr.make(
@@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped),
),
)
"""
else:
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped
@@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter):
class TracedModule(Module):
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
`TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it.
"""
m_node = None # type: ModuleNode
# m_node = None # type: ModuleNode
argdef_graph_map = None
argdef_outdef_map = None
def __init__(self, node):
def __init__(self, argdef_graph_map, argdef_outdef_map ):
super(TracedModule, self).__init__()
self.m_node = node
self.argdef_graph_map = argdef_graph_map
self.argdef_outdef_map = argdef_outdef_map
def forward(self, *args, **kwargs):
inputs, treedef = tree_flatten(
((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
)
assert treedef in self.m_node. argdef_graph_map
assert treedef in self.argdef_graph_map
inputs = filter(
lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
) # allow TracedModuleBuilder for retrace.
outputs = self.m_node. argdef_graph_map[treedef].interpret(*inputs)
out_def = self.m_node. argdef_outdef_map[treedef]
outputs = self.argdef_graph_map[treedef].interpret(*inputs)
out_def = self.argdef_outdef_map[treedef]
outputs = out_def.unflatten(outputs)
return outputs
@property
def graph(self):
assert len(self.m_node.argdef_graph_map) == 1
return list(self.m_node.argdef_graph_map.values())[0]
self._update_modulenode_ref()
assert len(self.argdef_graph_map) == 1
return list(self.argdef_graph_map.values())[0]
def _update_modulenode_ref(self):
for _, graph in self.argdef_graph_map.items():
graph._inputs[0]._owner = weakref.ref(self)
node2obj = {}
node2obj[graph._inputs[0]] = self
for expr in graph._exprs:
if isinstance(expr, GetAttr) and isinstance(
expr.outputs[0], ModuleNode
):
obj = getattr(node2obj[expr.inputs[0]], expr.name)
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
if isinstance(obj, TracedModule):
obj._update_modulenode_ref()
@property
def exprs(self):
@@ -561,39 +637,49 @@ class TracedModule(Module):
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
return [const, call]
if call is not None:
graph = copy.deepcopy(graph)
exprs = []
node2obj = {}
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
# replace inputs for submodule's exprx
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module , expr.name))
const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
const.outputs = expr.outputs
const.outputs[0].expr = const
exprs.append(const)
elif isinstance(expr.outputs[0], ModuleNode):
node2obj[expr.outputs[0]] = getattr(
node2obj[expr.inputs[0]], expr.name
)
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
(obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
expr_graph = (
obj.argdef_graph_map[expr.arg_def]
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
else:
# module has been replaced.
assert isinstance(pre_expr, Constant)
exprs.append(expr)
else:
exprs.append(expr)
else: