@@ -9,14 +9,19 @@
+import builtins
 import collections
 import copy
 import ctypes
 import fnmatch
 import functools
+import inspect
+import keyword
 import re
+import weakref
 from inspect import getcallargs, getmembers, isclass, ismethod
 from itertools import chain
 from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
 
+from megengine import tensor
+
 from ... import functional as F
 from ... import get_logger
 from ... import module as M
@@ -44,8 +49,10 @@ from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
 from .fake_quant import FakeQuantize as TM_FakeQuant
 from .module_tracer import (
+    PatchedFn,
     Patcher,
     active_module_tracer,
+    get_tensor_wrapable_method,
     module_tracer,
     set_active_module_tracer,
 )
@@ -70,46 +77,267 @@ def _is_leaf(node):
     return isinstance(node, RawTensor)
 
 
 def wrap_tensors(tensors: Tensor, nodes: TensorNode):
     inp_tensors = copy.deepcopy(tensors)
     inp_tensors, inp_def_v = tree_flatten(inp_tensors)
     inp_nodes, inp_def_n = tree_flatten(nodes)
     for v, n in zip(inp_tensors, inp_nodes):
         if isinstance(n, TensorNode) and isinstance(v, Tensor):
             NodeMixin.wrap_safe(v, n)
     return inp_def_v.unflatten(inp_tensors)
 
 
+_enable_node_to_tensor = False
+
+
+def _convert_node_flag():
+    return _enable_node_to_tensor
+
+
+def _set_convert_node_flag(flag: bool = False):
+    global _enable_node_to_tensor
+    pre_flag = _enable_node_to_tensor
+    _enable_node_to_tensor = flag
+    return pre_flag
+
+
+def _node_to_tensor(*args, **kwargs):
+    tensors = []
+    nodes, tree_def = tree_flatten((args, kwargs))
+    for n in nodes:
+        if isinstance(n, TensorNode):
+            if n.top_graph is not None:
+                active_module_tracer().current_scope()._add_input(n)
+            value = n.value
+            if value is None:
+                flag = _set_convert_node_flag(False)
+                unset_module_tracing()
+                value = F.zeros(shape=n._shape, dtype=n._dtype)
+                set_module_tracing()
+                _set_convert_node_flag(flag)
+            orig_n = NodeMixin.get(value, None)
+            if orig_n is None or "setitem" not in orig_n._name:
+                NodeMixin.wrap_safe(value, n)
+            tensors.append(value)
+        else:
+            tensors.append(n)
+    tensors = tree_def.unflatten(tensors)
+    return tensors
+
+
+def _tensor_to_node(tensors):
+    if tensors is None:
+        return None
+    nodes = []
+    tensors, out_def = tree_flatten(tensors)
+    for t in tensors:
+        if isinstance(t, Tensor):
+            n = NodeMixin.get(t, None)
+            if isinstance(n, TensorNode):
+                n.value = t
+                nodes.append(n)
+            else:
+                nodes.append(t)
+        else:
+            nodes.append(t)
+    nodes = out_def.unflatten(nodes)
+    return nodes
+
+
+def _wrap_method_to_tensor_node():
+    def _any_method(name):
+        def _any(*args, **kwargs):
+            args, kwargs = _node_to_tensor(*args, **kwargs)
+            attr = getattr(args[0], name)
+            outs = attr
+            if callable(attr):
+                outs = attr(*(args[1:]), **kwargs)
+            if name == "__setitem__":
+                _node_to_tensor(outs)
+                return None
+            outs = _tensor_to_node(outs)
+            return outs
+
+        return _any
+
+    tensor_method_patch = []
+    for method in get_tensor_wrapable_method():
+        patch = PatchedFn(TensorNode, method)
+        if type(getattr(Tensor, method)) == property:
+            patch.set_func(property(_any_method(method)))
+        else:
+            patch.set_func(_any_method(method))
+        tensor_method_patch.append(patch)
+    return tensor_method_patch
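
# A minimal sketch of what the patches installed above enable (names are
# hypothetical; assumes module tracing is active and `x` is a TensorNode in
# the current scope):
#
#     y = x.reshape(1, -1)
#
# The call goes through _any_method("reshape"): _node_to_tensor first turns
# `x` into a real Tensor (materializing zeros when the node has no cached
# value), the real Tensor method then runs under tracing, and
# _tensor_to_node hands the result back as a TensorNode.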
|
|
|
|
|
|
|
|
|
|
|
+def _convert_node_and_tensor(orig_func):
+    @functools.wraps(orig_func)
+    def _convert(*args, **kwargs):
+        if _convert_node_flag() and is_tracing_module():
+            args, kwargs = _node_to_tensor(*args, **kwargs)
+            rst = orig_func(*args, **kwargs, method_func=_convert)
+            rst = _tensor_to_node(rst)
+            return rst
+        else:
+            rst = orig_func(*args, **kwargs)
+            return rst
+
+    return _convert
+
+
+def _wrap_mnode_getattr(orig_getattr):
+    @functools.wraps(orig_getattr)
+    def wrapped_fn(self, name):
+        obj = self.owner
+        if self.top_graph is not None:
+            active_module_tracer().current_scope()._add_input(self)
+        attr = getattr(obj, name)
+        node = attr
+        full_name = None
+        if id(attr) in active_module_tracer().id2name:
+            full_name = active_module_tracer().id2name[id(attr)]
+
+        if not isinstance(attr, TracedModuleBuilder):
+            if isinstance(attr, Module):
+                attr = TracedModuleBuilder(attr)
+                setattr(obj, name, attr)
+                active_module_tracer().id2name[id(attr)] = full_name
+
+        if isinstance(attr, (NodeMixin, RawTensor)):
+            if full_name:
+                scope_name = active_module_tracer().current_scope()._module_name
+                if scope_name:
+                    full_name = full_name[len(scope_name) + 1 :]
+                else:
+                    full_name = name
+            else:
+                full_name = name
+            NodeMixin.wrap(
+                attr,
+                lambda: GetAttr.make(
+                    self,
+                    name,
+                    type=NodeMixin.get_wrapped_type(attr),
+                    orig_name=full_name,
+                ),
+            )
+        if isinstance(attr, (NodeMixin, RawTensor)):
+            node = NodeMixin.get(attr)
+            if isinstance(node, ModuleNode):
+                node._owner = weakref.ref(attr)
+        return node
+
+    return wrapped_fn
+
+
+def _wrap_mnode_call(orig_call):
+    @functools.wraps(orig_call)
+    def wrapped_fn(self, *args, **kwargs):
+        obj = self.owner
+        if self.top_graph is not None:
+            active_module_tracer().current_scope()._add_input(self)
+        rst = obj(*args, **kwargs)
+        return rst
+
+    return wrapped_fn
+
+
+def _init_id2name(mod: Module, prefix: str = ""):
+    id2name = {
+        id(m): "%s.%s" % (prefix, key)
+        for key, m in chain(
+            mod.named_modules(), mod.named_parameters(), mod.named_buffers()
+        )
+    }
+    return id2name
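
# Sketch of the table built by _init_id2name (hypothetical module layout):
#
#     net = MyNet()                      # has net.conv and net.conv.weight
#     id2name = _init_id2name(net, "self")
#     # => {id(net.conv): "self.conv",
#     #     id(net.conv.weight): "self.conv.weight", ...}
#
# This id -> qualified-name table is what the tracer consults later to
# recover a node's orig_name.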
|
|
|
|
|
|
|
|
|
|
|
 class _InsertExprs:
-    def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True):
+    def __init__(self, graph, expr: Optional[Expr] = None):
         self.graph = graph
-        self.global_scope = InternalGraph()
+        self.global_scope = InternalGraph(
+            graph._name, graph._prefix_name, graph._module_name
+        )
+        self.global_scope._used_names.update(graph._used_names)
         self.expr = expr
-        self.after = after
+        self._tensor_method_patch = None
 
     def __enter__(self):
         self.use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
+        _set_convert_node_flag(True)
         assert active_module_tracer() is None
-        set_active_module_tracer(module_tracer(_wrapped_function))
+        module = self.graph.inputs[0].owner
+        _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x))
+        set_active_module_tracer(
+            module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name))
+        )
         active_module_tracer().patcher.__enter__()
+        for cls, name, func in [
+            [ModuleNode, "__getattr__", _wrap_mnode_getattr],
+            [ModuleNode, "__call__", _wrap_mnode_call],
+            [TracedModuleBuilder, "__call__", _convert_node_and_tensor],
+        ]:
+            active_module_tracer().patcher.patch_function(cls, name, func)
+        self._tensor_method_patch = _wrap_method_to_tensor_node()
         active_module_tracer().push_scope(self.global_scope)
 
     def __exit__(self, ty, va, tr):
+        if va is not None:
+            return False
         set_symbolic_shape(self.use_sym_shape)
         unset_module_tracing()
         active_module_tracer().patcher.__exit__(ty, va, tr)
+        _set_convert_node_flag(False)
+
+        while self._tensor_method_patch:
+            pf = self._tensor_method_patch.pop()
+            pf.set_func(pf.origin_fn)
+
+        module = self.graph.inputs[0].owner
+
+        for mod, parent in module.modules(with_parent=True):
+            name = mod._name
+            if isinstance(mod, TracedModuleBuilder):
+                mod = mod.build()
+                if hasattr(mod, "graph"):
+                    for node in mod.graph.nodes():
+                        node.value = None
+                setattr(parent, name, mod)
         set_active_module_tracer(None)
-        index = len(self.graph._exprs) if self.after else 0
+
+        for node in self.global_scope.nodes():
+            node.value = None
+
+        extra_inp_nodes = set(self.global_scope.inputs)
+        max_inp_expr_idx = -1
+        for node in extra_inp_nodes:
+            assert (
+                node.top_graph == self.graph
+            ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
+            if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
+                max_inp_expr_idx = max(
+                    max_inp_expr_idx, self.graph._exprs.index(node.expr)
+                )
+        max_inp_expr_idx += 1
+
+        insert_index = -1
         if self.expr is not None:
-            index = self.graph._exprs.index(self.expr)
-            if self.after:
-                index += 1
+            insert_index = self.graph._exprs.index(self.expr)
+        insert_index += 1
+
+        if insert_index < max_inp_expr_idx:
+            insert_index = max_inp_expr_idx
+
+        anchor_index = insert_index - 1
+        if anchor_index >= 0:
+            logger.info(
+                "The new expr will be inserted after ( {} )".format(
+                    self.graph._exprs[anchor_index]
+                )
+            )
+
         for expr in self.global_scope._exprs:
-            self.graph._exprs.insert(index, expr)
-            index += 1
+            self.graph._exprs.insert(insert_index, expr)
+            insert_index += 1
+
+        self.graph._used_names.update(self.global_scope._used_names)
+        graph = self.graph
+        while graph.top_graph is not None:
+            graph = graph.top_graph
+        graph.inputs[0].owner._update_ref()
         return True
 
 
 class InternalGraph:
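
# Worked example of the insert-position rule in __exit__ above (indices are
# hypothetical): if the caller asked to splice after graph._exprs[3],
# insert_index starts at 4; but if one of the new exprs reads a node
# produced by graph._exprs[7], max_inp_expr_idx is 8, so insert_index is
# raised to 8 and every inserted expr still lands after the exprs that
# produce its inputs.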
|
|
@@ -125,8 +353,9 @@ class InternalGraph:
     _exprs = None  # type: List[Expr]
     _inputs = None  # type: List[Node]
     _outputs = None  # type: List[Node]
+    _top_graph = None
 
-    def __init__(self, name: str = None, prefix_name: str = ""):
+    def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""):
         self._exprs = []
         self._inputs = []
         self._outputs = []
@@ -136,12 +365,13 @@ class InternalGraph:
         self._rst = collections.defaultdict(list)
         self._name = name
         self._prefix_name = prefix_name
+        self._module_name = module_name
 
-    def insert(self, expr):
+    def _insert(self, expr):
         self._exprs.append(expr)
 
     def _create_unique_name(self, name: str) -> str:
-        assert isinstance(name, str)
+        assert isinstance(name, str), "The name must be a str"
         name = re.sub("[^0-9a-zA-Z_]+", "_", name)
         if name[0].isdigit():
             name = "_{}".format(name)
@@ -166,40 +396,45 @@ class InternalGraph:
         return self._outputs
 
     @property
-    def expr_filter(self):
-        return ExprFilter(_expr_iter(self))
+    def top_graph(self):
+        if self._top_graph:
+            return self._top_graph()
+        return None
 
-    @property
-    def node_filter(self):
-        return NodeFilter(_node_iter(self))
+    def exprs(self, recursive=True):
+        return ExprFilter(_expr_iter(self, recursive))
+
+    def nodes(self, recursive=True):
+        return NodeFilter(_node_iter(self, recursive))
 
-    def get_function_by_type(self, func: Callable = None):
-        return self.expr_filter.call_function(func)
+    def get_function_by_type(self, func: Callable = None, recursive=True):
+        return self.exprs(recursive).call_function(func)
 
-    def get_method_by_type(self, method: str = None):
-        return self.expr_filter.call_method(method)
+    def get_method_by_type(self, method: str = None, recursive=True):
+        return self.exprs(recursive).call_method(method)
 
-    def get_expr_by_id(self, expr_id: List[int] = None):
-        return self.expr_filter.expr_id(expr_id)
+    def get_expr_by_id(self, expr_id: List[int] = None, recursive=True):
+        return self.exprs(recursive).expr_id(expr_id)
 
-    def get_module_by_type(self, module_cls: Module):
+    def get_module_by_type(self, module_cls: Module, recursive=True):
         assert issubclass(module_cls, Module)
-        return self.node_filter.type(module_cls, ModuleNode)
+        return self.nodes(recursive).type(module_cls, ModuleNode)
 
-    def get_node_by_id(self, node_id: List[int] = None):
-        return self.node_filter.node_id(node_id)
+    def get_node_by_id(self, node_id: List[int] = None, recursive=True):
+        return self.nodes(recursive).node_id(node_id)
 
-    def get_node_by_name(self, name: str = None, ignorecase: bool = True):
-        return self.node_filter.name(name, ignorecase)
+    def get_node_by_name(
+        self, name: str = None, ignorecase: bool = True, recursive=True
+    ):
+        return self.nodes(recursive).name(name, ignorecase)
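
# Query sketch for the filter APIs above (a sketch only; assumes `graph` is
# the InternalGraph of a traced net that calls F.relu and owns M.Conv2d
# submodules):
#
#     relus = graph.get_function_by_type(F.relu)                # recursive
#     top_relus = graph.get_function_by_type(F.relu, recursive=False)
#     convs = graph.get_module_by_type(M.Conv2d)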
|
|
|
|
|
|
|
-    def add_input(self, i):
+    def _add_input(self, i):
         self._inputs.append(i)
 
-    def add_output(self, o):
+    def _add_output(self, o):
         self._outputs.append(o)
 
-    def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""):
+    def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""):
         for node, repl_node in repl_dict.items():
             assert node in self._inputs or node in self._outputs
             for i in node.users:
@@ -212,12 +447,15 @@ class InternalGraph:
 
         for idx, o in enumerate(self._outputs):
             if o in repl_dict:
+                repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name)
                 self._outputs[idx] = repl_dict[o]
 
         for expr in self._exprs:
 
             for idx, i in enumerate(expr.inputs):
-                assert i._name is not None
+                assert isinstance(
+                    i._name, str
+                ), "The node ({}) name must be a str".format(i)
                 if i in repl_dict:
                     expr.inputs[idx] = repl_dict[i]
                 elif isinstance(i, TensorNode) and prefix_name not in i._name:
@@ -227,9 +465,12 @@ class InternalGraph:
                     .current_scope()
                     ._create_unique_name(prefix_name + i._name.lstrip("_"))
                 )
+                i._orig_name = "{}{}".format(module_name, i._orig_name)
 
             for idx, o in enumerate(expr.outputs):
-                assert o._name is not None
+                assert isinstance(
+                    o._name, str
+                ), "The node ({}) name must be a str".format(o)
                 if o in repl_dict:
                     expr.outputs[idx] = repl_dict[o]
                     expr.outputs[idx].expr = expr
@@ -240,6 +481,7 @@ class InternalGraph:
                     .current_scope()
                     ._create_unique_name(prefix_name + o._name.lstrip("_"))
                 )
+                o._orig_name = "{}{}".format(module_name, o._orig_name)
 
     def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]:
         if not isinstance(nodes, Sequence):
@@ -263,7 +505,7 @@ class InternalGraph:
 
     def reset_inputs(self, *args, **kwargs):
         forma_mnode = self.inputs[0]
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
         for n in actual_mnodes:
             for c_expr in n.users:
@@ -318,7 +560,7 @@ class InternalGraph:
 
     def add_input_node(self, shape, dtype="float32", name="args"):
         forma_mnode = self.inputs[0]
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
 
         moudle = forma_mnode.owner
         assert moudle._is_top, "add_input_node only support the top-level graph"
@@ -378,7 +620,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "reset_outputs only support the top-level graph"
 
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
         for n in actual_mnodes:
             for c_expr in n.users:
@@ -406,7 +648,6 @@ class InternalGraph:
 
         self._outputs[:] = outputs
         moudle.argdef_outdef_map[tree_def] = out_def
-
         return actual_nodes
 
     def add_output_node(self, node: TensorNode):
@@ -415,7 +656,7 @@ class InternalGraph:
         moudle = forma_mnode.owner
         assert moudle._is_top, "add_output_node only support the top-level graph"
 
-        actual_mnodes = forma_mnode.actual_mnode
+        actual_mnodes = forma_mnode.actual_node
         call_nodes = []
 
         for n in actual_mnodes:
@@ -455,74 +696,35 @@ class InternalGraph:
 
         return actual_out_nodes
 
-    def insert_function(self, func: Callable, *args, **kwargs):
-        assert isinstance(func, Callable)
-
-        inp_nodes, inp_def = tree_flatten((args, kwargs))
-
-        insert_idx = -1
-        for i in inp_nodes:
-            if isinstance(i, TensorNode) and i.expr in self._exprs:
-                insert_idx = max(insert_idx, self._exprs.index(i.expr))
-
-        fake_inp_val = list(
-            F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i
-            for i in inp_nodes
-        )
-
-        for v, n in zip(fake_inp_val, inp_nodes):
-            if isinstance(n, TensorNode):
-                NodeMixin.wrap_safe(v, n)
-
-        fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val)
-
-        insert_point = self.insert_exprs_before()
-        if insert_idx != -1:
-            insert_point = self.insert_exprs_after(self._exprs[insert_idx])
-
-        with insert_point:
-            rst = func(*fake_args, **fake_kwargs)
-
-        if rst is None:
-            return None
-
-        outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
-        node_outputs = []
-        for out in outputs:
-            assert isinstance(out, RawTensor)
-            node_outputs.append(NodeMixin.get(out, None))
-
-        node_outputs = out_def.unflatten(node_outputs)
-        return node_outputs
-
-    def insert_exprs_after(self, expr: Optional[Expr] = None):
+    def insert_exprs(self, expr: Optional[Expr] = None):
         if expr is not None:
             assert expr.top_graph == self, "Expr to insert after is not in graph."
-            return _InsertExprs(self, expr, after=True)
-
-    def insert_exprs_before(self, expr: Optional[Expr] = None):
-        if expr is not None:
-            assert expr.top_graph == self, "Expr to insert before is not in graph."
-            return _InsertExprs(self, expr, after=False)
+        return _InsertExprs(self, expr)
 
     def replace_node(self, repl_dict: Dict[Node, Node]):
         while repl_dict:
             node, repl_node = repl_dict.popitem()
             # check graph inputs and outputs
-            assert node not in self.inputs, "Cannot replace inputs"
+            # assert node not in self.inputs, "Cannot replace inputs"
             for i, n in enumerate(self.outputs):
                 if n is node:
                     self.outputs[i] = repl_node
             # update users of node and repl_node
            # update inputs of expr in node.users
+            graph = repl_node.top_graph
+            assert graph is not None
+            index = graph._exprs.index(repl_node.expr)
             dep_exprs = self.get_dep_exprs(repl_node)
             i = 0
             while i < len(node.users):
                 n = node.users[i]
+                if n in graph._exprs and index >= graph._exprs.index(n):
+                    i += 1
+                    continue
                 if n in dep_exprs:
                     logger.info("Find a loop: ignore this replacement once")
                     logger.info("node: %s" % node.__repr__())
                     logger.info("repl_node: %s" % repl_node.__repr__())
                     logger.info("expr: %s" % n.__repr__())
                     i += 1
                     continue
                 repl_node.users.append(n)
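
# Usage sketch for insert_exprs + replace_node above (hypothetical nodes;
# assumes the ExprFilter returned by get_function_by_type supports
# as_unique()):
#
#     relu_expr = traced.graph.get_function_by_type(F.relu).as_unique()
#     with traced.graph.insert_exprs(relu_expr):
#         new_out = F.leaky_relu(relu_expr.inputs[0], 0.1)
#     traced.graph.replace_node({relu_expr.outputs[0]: new_out})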
|
|
@@ -598,6 +800,12 @@ class InternalGraph:
         Node.set_format_spec(saved_format_spec)
         return res
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        if "_top_graph" in state:
+            state.pop("_top_graph")
+        return state
+
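
# Note on __getstate__ above: _top_graph holds a weakref.ref back to the
# owning graph, and weakref objects cannot be pickled, so the reference is
# dropped here and re-established by TracedModule._update_ref() (see below).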
|
|
|
|
|
|
|
|
|
|
|
 def _get_meth_name(obj, func):
     tp = obj if isinstance(obj, type) else type(obj)
@@ -611,6 +819,9 @@ def _get_meth_name(obj, func):
 def _wrapped_function(orig_func):
     @functools.wraps(orig_func)
     def wrapped_fn(*args, **kwargs):
+        method_func = wrapped_fn
+        if "method_func" in kwargs:
+            method_func = kwargs.pop("method_func")
         if is_tracing_module():
             unset_module_tracing()
             inputs, tree_def = tree_flatten((args, kwargs))
@@ -618,9 +829,11 @@ def _wrapped_function(orig_func):
             if not NodeMixin.get(i, None):
                 if isinstance(i, (RawTensor, NodeMixin)):
                     NodeMixin.wrap_safe(i, Constant.make(i))
-        meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
-        arg_type = args[0] if isinstance(args[0], type) else type(args[0])
-        if meth_name and issubclass(arg_type, RawTensor):
+        meth_name, arg_type = None, None
+        if args:
+            meth_name = _get_meth_name(args[0], method_func)
+            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+        if meth_name and arg_type and issubclass(arg_type, RawTensor):
             self = inputs[0]
             if meth_name == "__new__":
                 if all([not isinstance(i, RawTensor) for i in inputs]):
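
# Note on the method_func plumbing above: when a Tensor method has been
# re-wrapped (e.g. by _convert_node_and_tensor, which forwards itself via
# this kwarg), the function object actually installed on the class is the
# outer wrapper rather than wrapped_fn, and _get_meth_name must match
# against that object to resolve which Tensor method is being traced.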
|
|
@@ -799,6 +1012,9 @@ class TracedModuleBuilder(NodeMixin):
     def __call__(self, *args, **kwargs):
         assert isinstance(self._mod, Module)
         # prepare args and kwargs for inner graph
+        if "method_func" in kwargs:
+            kwargs.pop("method_func")
+
         def mark_constant(x):
             node = NodeMixin.get(x, None)
             if node is None:  # capture as constant
@@ -829,9 +1045,6 @@ class TracedModuleBuilder(NodeMixin):
             else:
                 self._mod._is_top = False
                 self._body = self._mod.graph
-                name = NodeMixin.get(self)._name
-                if name:
-                    self._body._name = name
         else:
             self_node = None
             orig_self = NodeMixin.get(self)
@@ -841,19 +1054,24 @@ class TracedModuleBuilder(NodeMixin):
                 graph_prefix_name = "{}_{}".format(
                     top_graph._prefix_name, graph_prefix_name.lstrip("_")
                 )
-            self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name)
+            module_name = orig_self._orig_name
+            if top_graph._module_name:
+                module_name = "{}.{}".format(top_graph._module_name, module_name)
+            self._body = InternalGraph(
+                orig_self._name, prefix_name=graph_prefix_name, module_name=module_name
+            )
             active_module_tracer().push_scope(self._body)
             # rebind self to new input node
 
             if self_node:
                 NodeMixin.wrap_safe(self, self_node)
-                active_module_tracer().current_scope().add_input(self_node)
+                active_module_tracer().current_scope()._add_input(self_node)
             else:
                 NodeMixin.wrap_safe(
                     self,
                     self_node
                     if self_node
-                    else Input.make("self", NodeMixin.get_wrapped_type(self)),
+                    else Input.make("self", NodeMixin.get_wrapped_type(self), ""),
                 )
             origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]]
             # prepare args and kwargs for inner graph
@@ -893,12 +1111,13 @@ class TracedModuleBuilder(NodeMixin):
                 getattr(getattr(self._mod, "forward", self._mod), "__globals__", {})
             )
             rst = type(self._mod).forward(*args, **kwargs)
+            if _convert_node_flag():
+                rst = _node_to_tensor(rst)[0][0]
             outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
             for i in (
                 outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
             ):
-                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
-            NodeMixin.get(self, None).actual_mnode.append(orig_self)
+                active_module_tracer().current_scope()._add_output(NodeMixin.get(i))
             NodeMixin.wrap_safe(self, orig_self)
             for arg, node in zip(inputs[1:], origin_inp_node):
                 if node:
@@ -923,14 +1142,33 @@ class TracedModuleBuilder(NodeMixin):
             attr = getattr(type(self._mod), name).__get__(self, type(self))
         else:
             attr = getattr(self._mod, name)
+            full_name = None
+
+            if id(attr) in active_module_tracer().id2name:
+                full_name = active_module_tracer().id2name[id(attr)]
+
             if isinstance(attr, Module):
                 attr = TracedModuleBuilder(attr)
+
             if isinstance(attr, (Module, RawTensor)):
                 setattr(self, name, attr)
+                active_module_tracer().id2name[id(attr)] = full_name
+
+            if full_name:
+                scope_name = active_module_tracer().current_scope()._module_name
+                if scope_name:
+                    full_name = full_name[len(scope_name) + 1 :]
+                else:
+                    full_name = name
+            else:
+                full_name = name
             NodeMixin.wrap(
                 attr,
                 lambda: GetAttr.make(
-                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
+                    NodeMixin.get(self),
+                    name,
+                    type=NodeMixin.get_wrapped_type(attr),
+                    orig_name=full_name,
                 ),
             )
         return attr
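
# Name-scoping sketch for the full_name logic above (hypothetical names):
# if id2name maps the attribute to "self.backbone.conv.weight" and the
# current scope's _module_name is "self.backbone", the recorded orig_name
# becomes "conv.weight" (the scope prefix plus the joining dot is stripped
# by full_name[len(scope_name) + 1 :]).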
|
|
@@ -951,7 +1189,16 @@ class TracedModuleBuilder(NodeMixin):
             assert mod_attr is wrapped._mod
         else:
             assert mod_attr is wrapped
 
+        full_name = None
+        if id(mod_attr) in active_module_tracer().id2name:
+            full_name = active_module_tracer().id2name[id(mod_attr)]
+            scope_name = active_module_tracer().current_scope()._module_name
+            if full_name and scope_name:
+                full_name = full_name[len(scope_name) + 1 :]
+            else:
+                full_name = name
+        else:
+            full_name = name
         # assert not self._is_builtin
         if isinstance(wrapped, (NodeMixin, RawTensor)):
             NodeMixin.wrap(
@@ -960,6 +1207,7 @@ class TracedModuleBuilder(NodeMixin):
                     NodeMixin.get(self),
                     name,
                     type=NodeMixin.get_wrapped_type(wrapped),
+                    orig_name=full_name,
                 ),
             )
@@ -967,24 +1215,25 @@ class TracedModuleBuilder(NodeMixin):
 
 
 class _expr_iter:
-    def __init__(self, graph: InternalGraph):
+    def __init__(self, graph: InternalGraph, recursive: bool = True):
         self.graph = graph
+        self.recursive = recursive
 
     def __iter__(self):
         for expr in self.graph._exprs:
             if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode):
                 yield expr
-                if expr.graph is not None:
-                    yield from expr.graph.expr_filter
+                if self.recursive and expr.graph is not None:
+                    yield from expr.graph.exprs(self.recursive)
             else:
                 yield expr
 
 
 class _node_iter:
-    def __init__(self, graph: InternalGraph) -> None:
+    def __init__(self, graph: InternalGraph, recursive: bool = True) -> None:
         nodes = []
         node_ids = set()
-        for expr in graph.expr_filter:
+        for expr in graph.exprs(recursive):
             for n in expr.inputs + expr.outputs:
                 if n._id in node_ids:
                     continue
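
# Iteration sketch for the two iterators above (assumes `graph` comes from
# a traced net containing at least one submodule call):
#
#     list(graph.exprs(recursive=True))   # follows expr.graph into every
#                                         # CallMethod on a ModuleNode
#     list(graph.exprs(recursive=False))  # yields this graph's exprs only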
|
|
@@ -1210,14 +1459,17 @@ class TracedModule(Module):
         assert len(self.argdef_graph_map) == 1
         return list(self.argdef_graph_map.values())[0]
 
-    def _update_ref(self, actual_node_map: Union[Dict] = None):
+    def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None):
         for inp_def, graph in self.argdef_graph_map.items():
+            if top_graph is not None:
+                graph._top_graph = weakref.ref(top_graph)
             for n in graph._inputs + graph.outputs:
                 n._top_graph = weakref.ref(graph)
             graph._inputs[0]._owner = weakref.ref(self)
-            graph._inputs[0].actual_mnode = []
-            if actual_node_map is not None and inp_def in actual_node_map.keys():
-                graph._inputs[0].actual_mnode = actual_node_map[inp_def]
+            for i, n in enumerate(graph._inputs):
+                n.actual_node = []
+                if actual_node_map is not None and inp_def in actual_node_map.keys():
+                    n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i])
             node2obj = {}
             next_actual_node_map = collections.defaultdict(
                 lambda: collections.defaultdict(list)
@@ -1246,7 +1498,7 @@ class TracedModule(Module):
             ):
                 obj = node2obj[expr.inputs[0]]
                 if expr.arg_def is not None:
-                    next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0])
+                    next_actual_node_map[obj][expr.arg_def].append(expr.inputs)
 
         for obj in node2obj.values():
             if obj is self:
@@ -1255,7 +1507,7 @@ class TracedModule(Module):
             if obj in next_actual_node_map.keys():
                 mnode_map = next_actual_node_map[obj]
             if isinstance(obj, TracedModule):
-                obj._update_ref(mnode_map)
+                obj._update_ref(mnode_map, graph)
 
     def flatten(self):
         """
@@ -1264,21 +1516,25 @@ class TracedModule(Module):
         :return: :class:`TracedModule`
         """
         new_module = copy.deepcopy(self)
-        module2name = {}
         assert active_module_tracer() is None
-        set_active_module_tracer(module_tracer(lambda x: x))
+        id2name = _init_id2name(new_module, "self")
+        set_active_module_tracer(module_tracer(lambda x: x, {}))
         active_module_tracer().push_scope(new_module.graph)
-        for n, m in new_module.named_modules():
-            module2name[id(m)] = n
 
         def _flatten_subgraph(
-            graph: InternalGraph, module: Module, call=None, prefix_name=""
+            graph: InternalGraph,
+            module: Module,
+            call=None,
+            prefix_name="",
+            module_name="",
         ):
-            if graph is not None and prefix_name and prefix_name[-1] != "_":
+            if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_":
                 prefix_name += "_"
+            if isinstance(module_name, str) and module_name:
+                module_name += "."
             if graph is None or module.is_qat:
                 assert not isinstance(module, TracedModule) or module.is_qat
-                const = Constant(module, "self.%s" % module2name[id(module)])
+                const = Constant(module, id2name[id(module)])
                 m_node = call.inputs[0]
                 if m_node.top_graph != active_module_tracer().current_scope():
                     m_node._name = (
@@ -1286,6 +1542,7 @@ class TracedModule(Module):
                         .current_scope()
                         ._create_unique_name(prefix_name)
                     )
+                m_node._orig_name = id2name[id(module)][5:]
                 const.outputs[0] = m_node
                 const.outputs[0].expr = const
                 return [const, call]
@@ -1312,7 +1569,7 @@ class TracedModule(Module):
                     continue
                 repl_dict[out] = call.outputs[ind]
 
-            graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name)
+            graph._replace_inputs_outputs(repl_dict, prefix_name, module_name)
 
             for expr in graph._exprs:
                 if isinstance(expr, GetAttr):
@@ -1344,6 +1601,7 @@ class TracedModule(Module):
                             obj,
                             expr,
                             prefix_name + obj_node._name.lstrip("_"),
+                            module_name + obj_node._orig_name,
                         )
                     )
                 else:
@@ -1358,7 +1616,6 @@ class TracedModule(Module):
             if call is not None:
                 for i in call.inputs:
                     i.users.remove(call)
-
             return exprs
 
         new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module)
@@ -1396,7 +1653,22 @@ def register_as_builtin(mod_cls: Type[Module]) -> None:
     module_tracer.register_as_builtin(mod_cls)
 
 
-wrap = _wrapped_function
+def wrap(func: Callable):
+    """
+    Call this function to register func as a builtin function.
+    """
+    assert callable(func), "func must be a callable"
+    assert hasattr(func, "__code__")
+    fn_name = func.__code__.co_name
+    currentframe = inspect.currentframe()
+    assert currentframe is not None
+    f = currentframe.f_back
+    assert f is not None
+    assert (
+        f.f_code.co_name == "<module>"
+    ), "wrap must be called at the top level of a module"
+    Patcher._builtin_functions.append((f.f_globals, fn_name))
+    return func
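
# Usage sketch for the new `wrap` (the decorated function is hypothetical;
# per the assert above, it must be applied at a module's top level):
#
#     @wrap
#     def normalize(x, eps=1e-5):
#         return x / (x.sum() + eps)
#
# Tracing then records a single CallFunction for `normalize` instead of
# tracing through its body.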
|
|
|
|
|
|
|
|
|
|
|
 def _register_all_builtin_module():
@@ -1438,14 +1710,15 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     try:
         use_sym_shape = set_symbolic_shape(True)
         set_module_tracing()
-        set_active_module_tracer(module_tracer(_wrapped_function))
+        set_active_module_tracer(
+            module_tracer(_wrapped_function, _init_id2name(mod, "self"))
+        )
         with active_module_tracer().patcher:
             global_scope = InternalGraph(name="")
             active_module_tracer().push_scope(global_scope)
             builder = TracedModuleBuilder(mod, True)
             name = mod._name if mod._name else mod.__class__.__name__
-            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode))
+            NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self"))
             inputs, _ = tree_flatten((args, kwargs))
             for _, i in enumerate(inputs):
                 # assert isinstance(i, Tensor), "not support "