@@ -9,6 +9,7 @@
import collections
import collections
import copy
import copy
import functools
import functools
import weakref
from inspect import getmembers, isclass, ismethod
from inspect import getmembers, isclass, ismethod
from typing import Callable, Dict, Iterable, List, Sequence, Type
from typing import Callable, Dict, Iterable, List, Sequence, Type
@@ -51,7 +52,9 @@ def _leaf_type(node):
def _is_leaf(node):
    """Report whether ``node`` is a ``RawTensor`` leaf.

    The assertion rejects anything that is not a ``RawTensor`` with a message
    naming the offending type. The explicit ``isinstance`` return keeps the
    answer meaningful when assertions are disabled (``python -O``).

    Raises:
        AssertionError: if ``node`` is not a ``RawTensor``.
    """
    assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
        type(node)
    )
    return isinstance(node, RawTensor)
@@ -107,6 +110,32 @@ class InternalGraph:
def add_output(self, o):
    """Append ``o`` to this graph's list of output nodes."""
    self._outputs.append(o)
def _replace_inputs_outputs(self, repl_dict):
for node, repl_node in repl_dict.items():
assert node in self._inputs or node in self._outputs
for i in node.users:
if i not in repl_node.users:
repl_node.users.append(i)
for idx, i in enumerate(self._inputs):
if i in repl_dict:
self._inputs[idx] = repl_dict[i]
for idx, o in enumerate(self._outputs):
if o in repl_dict:
self._outputs[idx] = repl_dict[o]
self._outputs[idx].expr = node.expr
for expr in self._exprs:
for idx, i in enumerate(expr.inputs):
if i in repl_dict:
expr.inputs[idx] = repl_dict[i]
for idx, o in enumerate(expr.outputs):
if o in repl_dict:
expr.outputs[idx] = repl_dict[o]
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
if not isinstance(nodes, Sequence):
if not isinstance(nodes, Sequence):
nodes = (nodes,)
nodes = (nodes,)
@@ -117,6 +146,7 @@ class InternalGraph:
expr = node.expr
expr = node.expr
if expr not in ret:
if expr not in ret:
ret.append(expr)
ret.append(expr)
for i in expr.inputs:
for i in expr.inputs:
if i not in queue:
if i not in queue:
queue.append(i)
queue.append(i)
@@ -287,10 +317,7 @@ def _wrapped_function(orig_func):
call_node.arg_def = tree_def
call_node.arg_def = tree_def
outputs = orig_func(*args, **kwargs)
outputs = orig_func(*args, **kwargs)
if meth_name == "__new__":
call_node.add_outputs(outputs, False)
else:
call_node.add_outputs(outputs)
call_node.add_outputs(outputs)
set_module_tracing()
set_module_tracing()
return outputs
return outputs
return orig_func(*args, **kwargs)
return orig_func(*args, **kwargs)
@@ -303,12 +330,19 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module
_mod = None # type: Module
_body = None # type: InternalGraph
_body = None # type: InternalGraph
_is_builtin = None # type: bool
_is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
nodes = None
__builder_attributes__ = [
__builder_attributes__ = [
"_mod",
"_mod",
"_body",
"_body",
"_NodeMixin__node",
"_NodeMixin__node",
"_is_builtin",
"_is_builtin",
"build",
"build",
"_argdef_graph_map",
"_argdef_outdef_map",
"nodes",
]
]
def __init__(self, mod, is_top_module=False):
def __init__(self, mod, is_top_module=False):
@@ -316,23 +350,36 @@ class TracedModuleBuilder(NodeMixin):
self._mod = mod
self._mod = mod
self._body = None
self._body = None
self._is_builtin = module_tracer.is_builtin(mod)
self._is_builtin = module_tracer.is_builtin(mod)
self._argdef_graph_map = {}
self._argdef_outdef_map = {}
self.nodes = set()
def build(self):
    """Turn this builder into the module it stands for.

    For a builtin module the wrapped module itself is returned, after every
    node recorded in ``self.nodes`` has been tagged with the module's type.

    Otherwise a ``TracedModule`` is created from the recorded argdef->graph
    and argdef->outdef maps, each recorded graph is compiled, and every
    attribute of the builder that is not listed in ``__builder_attributes__``
    is copied onto the result. Attributes that are themselves builders are
    built first.
    """
    if self._is_builtin:
        mod_type = type(self._mod)
        for node in self.nodes:
            node.module_type = mod_type
        return self._mod

    traced_module = TracedModule(self._argdef_graph_map, self._argdef_outdef_map)
    for graph in self._argdef_graph_map.values():
        graph.compile()
    for k, v in self.__dict__.items():
        if k not in TracedModuleBuilder.__builder_attributes__:
            if isinstance(v, TracedModuleBuilder):
                v = v.build()
            setattr(traced_module, k, v)
    return traced_module
def _record_wrapped_nodes(self, node):
self.nodes.add(node)
def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs):
assert isinstance(self._mod, Module)
assert isinstance(self._mod, Module)
# prepare args and kwargs for inner graph
# prepare args and kwargs for inner graph
@@ -360,19 +407,30 @@ class TracedModuleBuilder(NodeMixin):
if self._is_builtin:
if self._is_builtin:
self._body = None
self._body = None
else:
else:
self_node = None
if self._body:
self_node = self._body.inputs[0]
self._body = InternalGraph()
self._body = InternalGraph()
active_module_tracer().push_scope(self._body)
active_module_tracer().push_scope(self._body)
# rebind self to new input node
# rebind self to new input node
orig_self = NodeMixin.get(self)
orig_self = NodeMixin.get(self)
NodeMixin.wrap_safe(
self, Input.make("self", NodeMixin.get_wrapped_type(self))
)
if self_node:
NodeMixin.wrap_safe(self, self_node)
active_module_tracer().current_scope().add_input(self_node)
else:
NodeMixin.wrap_safe(
self,
self_node
if self_node
else Input.make("self", NodeMixin.get_wrapped_type(self)),
)
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
# prepare args and kwargs for inner graph
# prepare args and kwargs for inner graph
def wrap(x):
def wrap(x):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
if isinstance(x, (RawTensor, NodeMixin)):
NodeMixin.wrap(
x, lambda: Input.make(type=NodeMixin.get_wrapped_type(x)),
)
return x
return x
args = [self]
args = [self]
@@ -397,9 +455,8 @@ class TracedModuleBuilder(NodeMixin):
# rebind output to outer graph
# rebind output to outer graph
callnode.add_outputs(outputs)
callnode.add_outputs(outputs)
self_node = NodeMixin.get(self)
self_node.argdef_graph_map[callnode.arg_def] = self._body
self_node.argdef_outdef_map[callnode.arg_def] = out_def
self._argdef_graph_map[callnode.arg_def] = self._body
self._argdef_outdef_map[callnode.arg_def] = out_def
return rst
return rst
def __getattr__(self, name):
def __getattr__(self, name):
@@ -424,8 +481,8 @@ class TracedModuleBuilder(NodeMixin):
else:
else:
wrapped = super().__getattribute__(name)
wrapped = super().__getattribute__(name)
if name in self._mod.__dict__:
if name in self._mod.__dict__:
if not NodeMixin.get(wrapped, None):
assert not self._is_builtin
assert not self._is_builtin
if isinstance(wrapped, (NodeMixin, RawTensor)):
NodeMixin.wrap(
NodeMixin.wrap(
wrapped,
wrapped,
lambda: GetAttr.make(
lambda: GetAttr.make(
@@ -434,14 +491,15 @@ class TracedModuleBuilder(NodeMixin):
type=NodeMixin.get_wrapped_type(wrapped),
type=NodeMixin.get_wrapped_type(wrapped),
),
),
)
)
"""
else:
else:
node = NodeMixin.get(wrapped)
node = NodeMixin.get(wrapped)
expr = GetAttr.make(
NodeMixin.get(self),
name,
type=NodeMixin.get_wrapped_type(wrapped),
).expr
expr.outputs[0] = node
expr = node.expr
assert isinstance(expr, GetAttr)
if expr not in active_module_tracer().current_scope()._exprs:
active_module_tracer().current_scope().insert(expr)
"""
return wrapped
return wrapped
@@ -514,33 +572,51 @@ class ExprFilterCallMethod(ExprFilter):
class TracedModule(Module):
class TracedModule(Module):
"""
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
interpreted by CallMethod Expr.
`TracedModule` is the Module created by tracing a normal module. It owns a map from argdef to graph (`InternalGraph`). The forward method of `TracedModule` looks up a graph in `argdef_graph_map` according to the argdef of the input args/kwargs and interprets it.
"""
"""
m_node = None # type: ModuleNode
# m_node = None # type: ModuleNode
argdef_graph_map = None
argdef_outdef_map = None
def __init__(self, argdef_graph_map, argdef_outdef_map):
    """Store the lookup tables that ``forward`` and ``graph`` read.

    Args:
        argdef_graph_map: mapping from an argument tree definition to the
            graph that ``forward`` interprets for it.
        argdef_outdef_map: mapping from the same argument tree definition to
            the definition used to unflatten the graph's outputs.
    """
    super(TracedModule, self).__init__()
    self.argdef_graph_map = argdef_graph_map
    self.argdef_outdef_map = argdef_outdef_map
def forward(self, *args, **kwargs):
    """Run the graph recorded for the structure of the given arguments.

    ``(self, *args)`` and ``kwargs`` are flattened together. The resulting
    tree definition selects the graph in ``self.argdef_graph_map`` and the
    output definition in ``self.argdef_outdef_map``. Only flattened leaves
    that are ``Module``, ``TracedModuleBuilder`` or ``RawTensor`` instances
    are passed to the graph.

    Returns:
        The graph's outputs, unflattened with the stored output definition.

    Raises:
        AssertionError: if no graph is recorded for this argument structure.
    """
    inputs, treedef = tree_flatten(
        ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
    )
    assert treedef in self.argdef_graph_map
    inputs = filter(
        lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
    )  # allow TracedModuleBuilder for retrace.
    outputs = self.argdef_graph_map[treedef].interpret(*inputs)
    out_def = self.argdef_outdef_map[treedef]
    outputs = out_def.unflatten(outputs)
    return outputs
@property
def graph(self):
    """Return the single graph recorded for this module.

    The module-node owner references are refreshed first through
    ``_update_modulenode_ref``.

    Raises:
        AssertionError: if ``argdef_graph_map`` does not hold exactly one graph.
    """
    self._update_modulenode_ref()
    assert len(self.argdef_graph_map) == 1
    return list(self.argdef_graph_map.values())[0]
def _update_modulenode_ref(self):
for _, graph in self.argdef_graph_map.items():
graph._inputs[0]._owner = weakref.ref(self)
node2obj = {}
node2obj[graph._inputs[0]] = self
for expr in graph._exprs:
if isinstance(expr, GetAttr) and isinstance(
expr.outputs[0], ModuleNode
):
obj = getattr(node2obj[expr.inputs[0]], expr.name)
expr.outputs[0]._owner = weakref.ref(obj)
node2obj[expr.outputs[0]] = obj
if isinstance(obj, TracedModule):
obj._update_modulenode_ref()
@property
@property
def exprs(self):
def exprs(self):
@@ -561,39 +637,49 @@ class TracedModule(Module):
const.outputs[0] = call.inputs[0]
const.outputs[0] = call.inputs[0]
const.outputs[0].expr = const
const.outputs[0].expr = const
return [const, call]
return [const, call]
if call is not None:
graph = copy.deepcopy(graph)
exprs = []
exprs = []
node2obj = {}
node2obj[graph._inputs[0]] = module
if call:
node2obj[call.inputs[0]] = module
for expr in graph._exprs:
for expr in graph._exprs:
# replace inputs for submodule's expr
for idx, inp in enumerate(expr.inputs):
if call and inp in graph._inputs:
inp_idx = graph._inputs.index(inp)
expr.inputs[idx] = call.inputs[inp_idx]
call.inputs[inp_idx].users.append(expr)
# replace outputs for submodule's expr
for idx, outp in enumerate(expr.outputs):
if call and outp in graph._outputs:
oup_idx = graph._outputs.index(outp)
expr.outputs[idx] = call.outputs[oup_idx]
call.outputs[oup_idx].expr = expr
# replace inputs for submodule's exprx
if call:
repl_dict = dict(
zip(graph._inputs + graph._outputs, call.inputs + call.outputs)
)
graph._replace_inputs_outputs(repl_dict)
if isinstance(expr, GetAttr):
if isinstance(expr, GetAttr):
# replace GetAttr with Constant
# replace GetAttr with Constant
if isinstance(expr.outputs[0], TensorNode):
if isinstance(expr.outputs[0], TensorNode):
const = Constant(getattr(module , expr.name))
const = Constant(getattr(node2obj[expr.inputs[0]], expr.name))
const.outputs = expr.outputs
const.outputs = expr.outputs
const.outputs[0].expr = const
const.outputs[0].expr = const
exprs.append(const)
exprs.append(const)
elif isinstance(expr.outputs[0], ModuleNode):
node2obj[expr.outputs[0]] = getattr(
node2obj[expr.inputs[0]], expr.name
)
elif isinstance(expr, CallMethod):
elif isinstance(expr, CallMethod):
obj_node = expr.inputs[0]
obj_node = expr.inputs[0]
if isinstance(obj_node, ModuleNode):
if isinstance(obj_node, ModuleNode):
pre_expr = expr.inputs[0].expr
pre_expr = expr.inputs[0].expr
if isinstance(pre_expr, GetAttr):
if isinstance(pre_expr, GetAttr):
(obj,) = expr.inputs[0].expr.interpret(module)
exprs.extend(_flatten_subgraph(expr.graph, obj, expr))
(obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]])
expr_graph = (
obj.argdef_graph_map[expr.arg_def]
if hasattr(obj, "argdef_graph_map")
else None
)
exprs.extend(_flatten_subgraph(expr_graph, obj, expr))
else:
else:
# module has been replaced.
# module has been replaced.
assert isinstance(pre_expr, Constant)
assert isinstance(pre_expr, Constant)
exprs.append(expr)
else:
else:
exprs.append(expr)
exprs.append(expr)
else:
else: