@@ -26,6 +26,12 @@ from ...core._imperative_rt.core2 import (
 from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
+from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
+from ...quantization.observer import (
+    ExponentialMovingAverageObserver,
+    MinMaxObserver,
+    SyncMinMaxObserver,
+)
 from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
 from .module_tracer import (
@@ -40,15 +46,6 @@ from .pytree import tree_flatten
 logger = get_logger(__name__)
 
 
-def _leaf_type(node):
-    if isinstance(node, (RawTensor, TensorNode)):
-        return (Tensor, TensorNode)
-    elif isinstance(node, (NodeMixin, Module, ModuleNode)):
-        return (Module, ModuleNode, NodeMixin)
-    else:
-        return type(node)
-
-
 def _is_leaf(node):
     assert isinstance(node, RawTensor), "doesn't support {} in return values".format(
         type(node)
@@ -56,20 +53,10 @@ def _is_leaf(node):
     return isinstance(node, RawTensor)
 
 
-def _is_const_leaf(node):
-    if isinstance(node, (RawTensor, NodeMixin, Module)):
-        return False
-    return True
-
-
 def wrap_tensors(tensors: Tensor, nodes: TensorNode):
     inp_tensors = copy.deepcopy(tensors)
-    inp_tensors, inp_def_v = tree_flatten(
-        inp_tensors, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-    )
-    inp_nodes, inp_def_n = tree_flatten(
-        nodes, leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-    )
+    inp_tensors, inp_def_v = tree_flatten(inp_tensors)
+    inp_nodes, inp_def_n = tree_flatten(nodes)
     for v, n in zip(inp_tensors, inp_nodes):
         if isinstance(n, TensorNode) and isinstance(v, Tensor):
             NodeMixin.wrap_safe(v, n)
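
Note: this hunk, and several below, drop the explicit `leaf_type`/`is_const_leaf` arguments, so `tree_flatten` now has to handle tensor/module leaves by default. A minimal sketch of the flatten/unflatten contract these call sites rely on; the toy below is illustrative only, not MegEngine's `pytree`:

    # Toy pytree: flatten nested lists/tuples/dicts into (leaves, spec),
    # then rebuild the original structure from the spec.
    def flatten(obj):
        if isinstance(obj, (list, tuple)):
            leaves, children = [], []
            for item in obj:
                sub, spec = flatten(item)
                leaves += sub
                children.append(spec)
            return leaves, (type(obj), children)
        if isinstance(obj, dict):
            leaves, children = [], []
            for k in sorted(obj):
                sub, spec = flatten(obj[k])
                leaves += sub
                children.append((k, spec))
            return leaves, (dict, children)
        return [obj], None  # anything else is a leaf

    def unflatten(leaves, spec):
        it = iter(leaves)

        def build(s):
            if s is None:
                return next(it)
            ctor, children = s
            if ctor is dict:
                return {k: build(c) for k, c in children}
            return ctor(build(c) for c in children)

        return build(spec)

    leaves, spec = flatten(((1, 2), {"a": 3}))
    assert leaves == [1, 2, 3]
    assert unflatten(leaves, spec) == ((1, 2), {"a": 3})
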
@@ -124,6 +111,9 @@ class InternalGraph:
         self._exprs = []
         self._inputs = []
         self._outputs = []
+        self._watch_point = []
+        self._end_point = []
+        self._rst = collections.defaultdict(list)
 
     def insert(self, expr):
         self._exprs.append(expr)
@@ -177,6 +167,7 @@ class InternalGraph:
         for idx, i in enumerate(self._inputs):
             if i in repl_dict:
                 self._inputs[idx] = repl_dict[i]
         for idx, o in enumerate(self._outputs):
             if o in repl_dict:
                 self._outputs[idx] = repl_dict[o]
@@ -224,11 +215,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "reset_inputs only support the top-level graph"
-        inputs, tree_def = tree_flatten(
-            ((moudle, *args), kwargs),
-            leaf_type=_leaf_type,
-            is_const_leaf=_is_const_leaf,
-        )
+        inputs, tree_def = tree_flatten(((moudle, *args), kwargs))
 
         def create_node(val: Tensor):
             node = Input(type=TensorNode).outputs[0]
@@ -302,7 +289,6 @@ class InternalGraph:
             formal_inp_node = create_node(True)
             inputs, tree_def = tree_flatten(
                 ((*args, formal_inp_node), kwargs),
-                leaf_type=_leaf_type,
                 is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
             )
             self._inputs[:] = inputs[:]
@@ -313,7 +299,6 @@ class InternalGraph:
                 args = args + (create_node(False),)
                 inputs, tree_def = tree_flatten(
                     (args, kwargs),
-                    leaf_type=_leaf_type,
                     is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)),
                 )
                 e.inputs[:] = inputs[:]
@@ -328,7 +313,7 @@ class InternalGraph:
     def reset_outputs(self, outputs):
         outputs, out_def = tree_flatten(
-            outputs, leaf_type=_leaf_type, is_leaf=lambda x: isinstance(x, TensorNode),
+            outputs, is_leaf=lambda x: isinstance(x, TensorNode),
         )
         forma_mnode = self.inputs[0]
@@ -393,9 +378,7 @@ class InternalGraph:
             org_out_def = moudle.argdef_outdef_map[tree_def]
             org_outs = org_out_def.unflatten(self._outputs)
             outputs, out_def = tree_flatten(
-                (org_outs, node),
-                leaf_type=_leaf_type,
-                is_leaf=lambda x: isinstance(x, TensorNode),
+                (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode),
             )
             self._outputs[:] = outputs
@@ -404,9 +387,7 @@ class InternalGraph:
                 actual_node = create_node(node, e)
                 org_outs = org_out_def.unflatten(e.outputs)
                 outputs, out_def = tree_flatten(
-                    (org_outs, actual_node),
-                    leaf_type=_leaf_type,
-                    is_leaf=lambda x: isinstance(x, TensorNode),
+                    (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode),
                 )
                 e.outputs[:] = outputs
                 e.out_def = out_def
@@ -419,9 +400,7 @@ class InternalGraph:
     def insert_function(self, func: Callable, *args, **kwargs):
         assert isinstance(func, Callable)
-        inp_nodes, inp_def = tree_flatten(
-            (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inp_nodes, inp_def = tree_flatten((args, kwargs))
 
         insert_idx = -1
         for i in inp_nodes:
@@ -449,7 +428,7 @@ class InternalGraph:
         if rst is None:
             return None
-        outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
         node_outputs = []
         for out in outputs:
             assert isinstance(out, RawTensor)
@@ -510,15 +489,40 @@ class InternalGraph:
     def interpret(self, *inputs):
         node2value = {}
+        end_nodes_set = set(self._end_point)
+        endnode2value = {}
+
+        def get_all_endnode_val(n, v):
+            if n in end_nodes_set:
+                endnode2value[n] = v
+                end_nodes_set.remove(n)
+                return not end_nodes_set
+            return False
+
         for n, v in zip(self._inputs, inputs):
             node2value[n] = v
+            if n in self._watch_point:
+                self._rst[n].append(v)
+            if n in self._end_point and get_all_endnode_val(n, v):
+                return list(endnode2value[i] for i in self._end_point)
+
         for expr in self._exprs:
             values = expr.interpret(*list(node2value[i] for i in expr.inputs))
             if values is not None:
                 for n, v in zip(expr.outputs, values):
                     node2value[n] = v
+                    if n in self._watch_point:
+                        self._rst[n] = v
+                    if self._end_point and get_all_endnode_val(n, v):
+                        return list(endnode2value[i] for i in self._end_point)
         return list(node2value[i] for i in self._outputs)
 
+    def eval(self, *inputs):
+        assert len(inputs) == len(self._inputs) - 1
+        inp = [self._inputs[0].owner] + list(inputs)
+        return self.interpret(*inp)
+
     def __repr__(self):
         return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
             ", ".join(str(i) for i in self._inputs),
@@ -541,9 +545,7 @@ def _wrapped_function(orig_func):
     def wrapped_fn(*args, **kwargs):
         if is_tracing_module():
             unset_module_tracing()
-            inputs, tree_def = tree_flatten(
-                (args, kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-            )
+            inputs, tree_def = tree_flatten((args, kwargs))
             for i in inputs:
                 if not NodeMixin.get(i, None):
                     if isinstance(i, (RawTensor, NodeMixin)):
@@ -575,9 +577,7 @@ def _wrapped_function(orig_func):
if meth_name == "__setitem__":
if meth_name == "__setitem__":
rst = self
rst = self
if rst is not None:
if rst is not None:
outputs, out_def = tree_flatten(
rst, leaf_type=_leaf_type, is_leaf=_is_leaf
)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
call_node.out_def = out_def
call_node.out_def = out_def
else:
else:
outputs = None
outputs = None
@@ -604,13 +604,17 @@ class TracedModuleBuilder(NodeMixin):
"_NodeMixin__node",
"_NodeMixin__node",
"_is_builtin",
"_is_builtin",
"build",
"build",
"_record_wrapped_nodes",
"_argdef_graph_map",
"_argdef_graph_map",
"_argdef_outdef_map",
"_argdef_outdef_map",
"nodes",
"nodes",
"__class__",
"__dict__",
]
]
def __init__(self, mod, is_top_module=False):
def __init__(self, mod, is_top_module=False):
super(TracedModuleBuilder, self).__init__()
super(TracedModuleBuilder, self).__init__()
assert isinstance(mod, Module)
self._mod = mod
self._mod = mod
self._body = None
self._body = None
self._is_top = is_top_module
self._is_top = is_top_module
@@ -618,6 +622,13 @@ class TracedModuleBuilder(NodeMixin):
         self._argdef_graph_map = {}
         self._argdef_outdef_map = {}
         self.nodes = set()
+        # The builder will be passed to self._mod.forward as the 'self' argument. If 'forward' uses super().xxx to call a method of its base classes, the trace procedure will throw an exception, because the builder doesn't inherit from self._mod.__bases__.
+        # Modify self.__class__ and let the builder inherit from both TracedModuleBuilder and mod.__class__.
+        self.__class__ = type(
+            "TracedModuleBuilder",
+            (TracedModuleBuilder, mod.__class__),
+            dict(TracedModuleBuilder.__dict__),
+        )
 
     def build(self):
         if self._is_builtin:
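
Note: reassigning `self.__class__` to a class built with `type(name, bases, namespace)` is what makes zero-argument `super()` inside the user's `forward` resolve correctly when the builder is passed as `self`. A standalone sketch of the same pattern, with hypothetical `Proxy`/`UserModule` names:

    # The proxy re-points its class at a fresh subclass of
    # (Proxy, type(mod)), so methods invoked with the proxy as `self`
    # can still reach mod's base classes via super().
    class Base:
        def greet(self):
            return "base"

    class UserModule(Base):
        def forward(self):
            return super().greet()  # needs Base on self's MRO

    class Proxy:
        def __init__(self, mod):
            self._mod = mod
            self.__class__ = type(
                "Proxy", (Proxy, type(mod)), dict(Proxy.__dict__)
            )

    p = Proxy(UserModule())
    assert UserModule.forward(p) == "base"  # super() finds Base via p's MRO
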
@@ -631,8 +642,6 @@ class TracedModuleBuilder(NodeMixin):
             )
         for _, g in self._argdef_graph_map.items():
             g.compile()
-        # for node in self.nodes:
-        #     node._owner = weakref.ref(traced_module)
         for k, v in self.__dict__.items():
             if k not in TracedModuleBuilder.__builder_attributes__:
@@ -653,9 +662,7 @@ class TracedModuleBuilder(NodeMixin):
             if node is None:  # capture as constant
                 NodeMixin.wrap(x, lambda: Constant.make(x))
 
-        inputs, tree_def = tree_flatten(
-            ((self, *args), kwargs), leaf_type=_leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inputs, tree_def = tree_flatten(((self, *args), kwargs))
         for i in inputs:
             mark_constant(i)
         callnode = CallMethod.make(NodeMixin.get(self))
@@ -667,7 +674,7 @@ class TracedModuleBuilder(NodeMixin):
         if self._is_builtin:
             unset_module_tracing()
             rst = self._mod(*args, **kwargs)
-            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             set_module_tracing()
         if self._is_builtin:
             self._body = None
@@ -706,7 +713,7 @@ class TracedModuleBuilder(NodeMixin):
                 getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
             )
             rst = type(self._mod).forward(*args, **kwargs)
-            outputs, out_def = tree_flatten(rst, leaf_type=_leaf_type, is_leaf=_is_leaf)
+            outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             for i in (
                 outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
             ):
@@ -725,6 +732,12 @@ class TracedModuleBuilder(NodeMixin):
         self._argdef_outdef_map[callnode.arg_def] = out_def
         return rst
 
+    def __setattr__(self, name, value):
+        object.__setattr__(self, name, value)
+
+    def __repr__(self):
+        return repr(self._mod)
+
     def __getattr__(self, name):
         if name not in self._mod.__dict__:
             attr = getattr(type(self._mod), name).__get__(self, type(self))
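
Note: of the two new dunders, `__repr__` simply forwards to the wrapped module, and `__setattr__` writes through `object.__setattr__`. The motive is not stated in the diff; presumably it bypasses any attribute hooks the builder now inherits from `mod.__class__` via the dynamic subclassing above. A minimal illustration of the write-through pattern:

    # A base class with an attribute hook, and a subclass that bypasses
    # the hook by delegating straight to object.__setattr__.
    class Hooked:
        def __setattr__(self, name, value):
            print("hook saw", name)  # stand-in for registration logic
            object.__setattr__(self, name, value)

    class WriteThrough(Hooked):
        def __setattr__(self, name, value):
            object.__setattr__(self, name, value)  # hook skipped

    w = WriteThrough()
    w.x = 1  # prints nothing
    assert w.x == 1
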
@@ -743,11 +756,22 @@ class TracedModuleBuilder(NodeMixin):
     def __getattribute__(self, name):
         if name in TracedModuleBuilder.__builder_attributes__:
-            return super().__getattribute__(name)
+            return object.__getattribute__(self, name)
         else:
-            wrapped = super().__getattribute__(name)
+            wrapped = object.__getattribute__(self, name)
             if name in self._mod.__dict__:
-                # assert not self._is_builtin
+                assert not self._is_builtin
+                mod_attr = getattr(self._mod, name)
+                if not isinstance(mod_attr, Module) and wrapped is not mod_attr:
+                    wrapped = mod_attr
+                    setattr(self, name, wrapped)
+                if isinstance(mod_attr, Module):
+                    assert mod_attr is wrapped._mod
+                else:
+                    assert mod_attr is wrapped
                 if isinstance(wrapped, (NodeMixin, RawTensor)):
                     NodeMixin.wrap(
                         wrapped,
@@ -757,14 +781,6 @@ class TracedModuleBuilder(NodeMixin):
                             type=NodeMixin.get_wrapped_type(wrapped),
                         ),
                     )
-            """
-            else:
-                node = NodeMixin.get(wrapped)
-                expr = node.expr
-                assert isinstance(expr, GetAttr)
-                if expr not in active_module_tracer().current_scope()._exprs:
-                    active_module_tracer().current_scope().insert(expr)
-            """
             return wrapped
@@ -924,20 +940,57 @@ class TracedModule(Module):
         self.argdef_graph_map = argdef_graph_map
         self.argdef_outdef_map = argdef_outdef_map
         self._is_top = is_top
+        self.watch_points = []
+        self.watch_node_value = {}
+        self.end_points = []
 
     def forward(self, *args, **kwargs):
-        inputs, treedef = tree_flatten(
-            ((self, *args), kwargs), _leaf_type, is_const_leaf=_is_const_leaf
-        )
+        inputs, treedef = tree_flatten(((self, *args), kwargs))
         assert treedef in self.argdef_graph_map
         inputs = filter(
             lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs
         )  # allow TracedModuleBuilder for retrace.
         outputs = self.argdef_graph_map[treedef].interpret(*inputs)
+        if self.watch_points:
+            self.watch_node_value = {}
+            for n in self.watch_points:
+                self.watch_node_value[n] = n.top_graph._rst.pop(n)
+        if self.end_points:
+            return outputs
         out_def = self.argdef_outdef_map[treedef]
         outputs = out_def.unflatten(outputs)
         return outputs
 
+    def set_watch_points(self, nodes):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        self.watch_points = nodes
+        for n in nodes:
+            n.top_graph._watch_point.append(n)
+
+    def clear_watch_points(self):
+        for n in self.watch_points:
+            n.top_graph._watch_point = []
+        self.watch_points = []
+        self.watch_node_value = {}
+
+    def set_end_points(self, nodes):
+        if not isinstance(nodes, Sequence):
+            nodes = [nodes]
+        self.end_points = nodes
+        graphs = list(self.argdef_graph_map.values())
+        for n in nodes:
+            assert n.top_graph in graphs
+            n.top_graph._end_point.append(n)
+
+    def clear_end_points(self):
+        for n in self.end_points:
+            n.top_graph._end_point = []
+        self.end_points = []
     @property
     def graph(self) -> InternalGraph:
         if self._is_top:
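
Note: together with the `InternalGraph` changes above, this gives a small debugging API: watch points record intermediate node values during `forward`, end points truncate execution. A hypothetical usage sketch; `tm`, `x`, and `node` are assumed to exist (`tm` from `trace_module`, `node` a `TensorNode` picked from `tm.graph`) and are not defined here:

    tm.set_watch_points([node])       # record node's value on the next forward
    y = tm(x)
    print(tm.watch_node_value[node])  # value captured during interpret
    tm.clear_watch_points()

    tm.set_end_points([node])  # stop interpretation once node is computed
    partial = tm(x)            # raw (un-unflattened) list of end-point values
    tm.clear_end_points()
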
@@ -1014,6 +1067,9 @@ class TracedModule(Module):
         node2obj[graph._inputs[0]] = module
         if call:
             node2obj[call.inputs[0]] = module
 
-        repl_dict = dict(zip(graph._inputs, call.inputs))
-        for ind, out in enumerate(graph.outputs):
-            if isinstance(out.expr, Input):
+        # replace inputs for submodule's exprs
+        if call:
+            repl_dict = dict(zip(graph._inputs, call.inputs))
+            for ind, out in enumerate(graph.outputs):
+                if isinstance(out.expr, Input):
@@ -1028,8 +1084,8 @@ class TracedModule(Module):
-                repl_dict[out] = call.outputs[ind]
-        graph._replace_inputs_outputs(repl_dict)
+                    repl_dict[out] = call.outputs[ind]
+            graph._replace_inputs_outputs(repl_dict)
 
         for expr in graph._exprs:
             if isinstance(expr, GetAttr):
                 # replace GetAttr with Constant
                 if isinstance(expr.outputs[0], TensorNode):
@@ -1129,6 +1185,7 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     :param kwargs: the keyword arguments passed to forward method of ``mod``
     """
     assert active_module_tracer() is None
+    assert isinstance(mod, Module)
     try:
         use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
builder = TracedModuleBuilder(mod, True)
         builder = TracedModuleBuilder(mod, True)
         NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
-        inputs, _ = tree_flatten((args, kwargs), is_const_leaf=_is_const_leaf)
+        inputs, _ = tree_flatten((args, kwargs))
         for _, i in enumerate(inputs):
-            assert isinstance(i, Tensor), "not support "
+            # assert isinstance(i, Tensor), "not support "
             if isinstance(i, RawTensor):
                 NodeMixin.wrap_safe(
                     i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i))