# -*- coding: utf-8 -*- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") # # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import builtins import collections import copy import ctypes import fnmatch import functools import inspect import keyword import re import weakref from inspect import getcallargs, getmembers, isclass, ismethod from itertools import chain from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union from megengine import tensor from ... import functional as F from ... import get_logger from ... import module as M from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import ( is_tracing_module, set_module_tracing, unset_module_tracing, ) from ...core._trace_option import set_symbolic_shape from ...core.tensor.array_method import ArrayMethodMixin from ...module import Module from ...module.qat import QATModule from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize from ...quantization.observer import ( ExponentialMovingAverageObserver, HistogramObserver, MinMaxObserver, Observer, PassiveObserver, SyncExponentialMovingAverageObserver, SyncMinMaxObserver, ) from ...tensor import Tensor from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input from .fake_quant import FakeQuantize as TM_FakeQuant from .module_tracer import ( PatchedFn, Patcher, active_module_tracer, get_tensor_wrapable_method, module_tracer, set_active_module_tracer, ) from .node import ModuleNode, Node, NodeMixin, TensorNode from .pytree import ArgsIndex, tree_flatten logger = get_logger(__name__) def _is_builtin_name(name: str) -> bool: return ( name in builtins.__dict__ or name in keyword.kwlist or name in {"inf", "nan", "NoneType"} ) def _is_leaf(node): assert isinstance(node, RawTensor), "doesn't support {} in return values".format( type(node) ) return isinstance(node, RawTensor) _enable_node_to_tensor = False def _convert_node_flag(): return _enable_node_to_tensor def _set_convert_node_flag(flag: bool = False): global _enable_node_to_tensor pre_flag = _enable_node_to_tensor _enable_node_to_tensor = flag return pre_flag def _node_to_tensor(*args, **kwargs): tensors = [] nodes, tree_def = tree_flatten((args, kwargs)) for n in nodes: if isinstance(n, TensorNode): if n.top_graph is not None: active_module_tracer().current_scope()._add_input(n) value = n.value if value is None: flag = _set_convert_node_flag(False) unset_module_tracing() value = F.zeros(shape=n._shape, dtype=n._dtype) set_module_tracing() _set_convert_node_flag(flag) orig_n = NodeMixin.get(value, None) if orig_n is None or "setitem" not in orig_n._name: NodeMixin.wrap_safe(value, n) tensors.append(value) else: tensors.append(n) tensors = tree_def.unflatten(tensors) return tensors def _tensor_to_node(tensors): if tensors is None: return None nodes = [] tensors, out_def = tree_flatten(tensors) for t in tensors: if isinstance(t, Tensor): n = NodeMixin.get(t, None) if isinstance(n, TensorNode): n.value = t nodes.append(n) else: nodes.append(t) else: nodes.append(t) nodes = out_def.unflatten(nodes) return nodes def _wrap_method_to_tensor_node(): def _any_method(name): def _any(*args, **kwargs): args, kwargs = _node_to_tensor(*args, **kwargs) attr = getattr(args[0], name) outs = attr if callable(attr): outs = attr(*(args[1:]), **kwargs) if name == "__setitem__": _node_to_tensor(outs) return None outs = _tensor_to_node(outs) return outs return _any tensor_method_patch = [] for method in get_tensor_wrapable_method(): patch = PatchedFn(TensorNode, method) if type(getattr(Tensor, method)) == property: patch.set_func(property(_any_method(method))) else: patch.set_func(_any_method(method)) tensor_method_patch.append(patch) return tensor_method_patch def _convert_node_and_tensor(orig_func): @functools.wraps(orig_func) def _convert(*args, **kwargs): if _convert_node_flag() and is_tracing_module(): args, kwargs = _node_to_tensor(*args, **kwargs) rst = orig_func(*args, **kwargs, method_func=_convert) rst = _tensor_to_node(rst) return rst else: rst = orig_func(*args, **kwargs) return rst return _convert def _wrap_mnode_getattr(orig_getattr): @functools.wraps(orig_getattr) def wraped_fn(self, name): obj = self.owner if self.top_graph is not None: active_module_tracer().current_scope()._add_input(self) attr = getattr(obj, name) node = attr full_name = None if id(attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(attr)] if not isinstance(attr, TracedModuleBuilder): if isinstance(attr, Module): attr = TracedModuleBuilder(attr) setattr(obj, name, attr) active_module_tracer().id2name[id(attr)] = full_name if isinstance(attr, (NodeMixin, RawTensor)): if full_name: scope_name = active_module_tracer().current_scope()._module_name if scope_name: full_name = full_name[len(scope_name) + 1 :] else: full_name = name else: full_name = name NodeMixin.wrap( attr, lambda: GetAttr.make( self, name, type=NodeMixin.get_wrapped_type(attr), orig_name=full_name, ), ) if isinstance(attr, (NodeMixin, RawTensor)): node = NodeMixin.get(attr) if isinstance(node, ModuleNode): node._owner = weakref.ref(attr) return node return wraped_fn def _wrap_mnode_call(orig_call): @functools.wraps(orig_call) def wraped_fn(self, *args, **kwargs): obj = self.owner if self.top_graph is not None: active_module_tracer().current_scope()._add_input(self) rst = obj(*args, **kwargs) return rst return wraped_fn def _init_id2name(mod: Module, prefix: str = ""): id2name = { id(m): "%s.%s" % (prefix, key) for key, m in chain( mod.named_modules(), mod.named_parameters(), mod.named_buffers() ) } return id2name class _InsertExprs: def __init__(self, graph, expr: Optional[Expr] = None): self.graph = graph self.global_scope = InternalGraph( graph._name, graph._prefix_name, graph._module_name ) self.global_scope._used_names.update(graph._used_names) self.expr = expr self._tensor_method_patch = None def __enter__(self): self.use_sym_shape = set_symbolic_shape(True) set_module_tracing() _set_convert_node_flag(True) assert active_module_tracer() is None module = self.graph.inputs[0].owner _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x)) set_active_module_tracer( module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name)) ) active_module_tracer().patcher.__enter__() for cls, name, func in [ [ModuleNode, "__getattr__", _wrap_mnode_getattr], [ModuleNode, "__call__", _wrap_mnode_call], [TracedModuleBuilder, "__call__", _convert_node_and_tensor], ]: active_module_tracer().patcher.patch_function(cls, name, func) self._tensor_method_patch = _wrap_method_to_tensor_node() active_module_tracer().push_scope(self.global_scope) def __exit__(self, ty, va, tr): if va is not None: return False set_symbolic_shape(self.use_sym_shape) unset_module_tracing() active_module_tracer().patcher.__exit__(ty, va, tr) _set_convert_node_flag(False) while self._tensor_method_patch: pf = self._tensor_method_patch.pop() pf.set_func(pf.origin_fn) module = self.graph.inputs[0].owner for mod, parent in module.modules(with_parent=True): name = mod._name if isinstance(mod, TracedModuleBuilder): mod = mod.build() if hasattr(mod, "graph"): for node in mod.graph.nodes(): node.value = None setattr(parent, name, mod) set_active_module_tracer(None) for node in self.global_scope.nodes(): node.value = None extra_inp_nodes = set(self.global_scope.inputs) max_inp_expr_idx = -1 for node in extra_inp_nodes: assert ( node.top_graph == self.graph ), "The input node ({}) is not in the graph ({})".format(node, self.graph) if isinstance(node, TensorNode) and node.expr in self.graph._exprs: max_inp_expr_idx = max( max_inp_expr_idx, self.graph._exprs.index(node.expr) ) max_inp_expr_idx += 1 insert_index = -1 if self.expr is not None: insert_index = self.graph._exprs.index(self.expr) insert_index += 1 if insert_index < max_inp_expr_idx: insert_index = max_inp_expr_idx anchor_index = insert_index - 1 if anchor_index >= 0: logger.info( "The new expr will be inserted after ( {} )".format( self.graph._exprs[anchor_index] ) ) for expr in self.global_scope._exprs: self.graph._exprs.insert(insert_index, expr) insert_index += 1 self.graph._used_names.update(self.global_scope._used_names) graph = self.graph while graph.top_graph is not None: graph = graph.top_graph graph.inputs[0].owner._update_ref() return True class InternalGraph: """ ``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. Attributes: _exprs: List of Exprs in order of execution _inputs: Input Nodes of InternalGraph _outputs: Output Nodes of InternalGraph """ _exprs = None # type: List[Expr] _inputs = None # type: List[Node] _outputs = None # type: List[Node] _top_graph = None def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): self._exprs = [] self._inputs = [] self._outputs = [] self._watch_point = [] self._end_point = [] self._used_names = {} self._rst = collections.defaultdict(list) self._name = name self._prefix_name = prefix_name self._module_name = module_name def _insert(self, expr): self._exprs.append(expr) def _create_unique_name(self, name: str) -> str: assert isinstance(name, str), "The name must be a str" name = re.sub("[^0-9a-zA-Z_]+", "_", name) if name[0].isdigit(): name = "_{}".format(name) while name in self._used_names or _is_builtin_name(name): match = re.match(r"(.*)_(\d+)$", name) if match is None: name = name + "_1" else: base, num = match.group(1, 2) name = "{}_{}".format(base, int(num) + 1) self._used_names.setdefault(name) return name @property def inputs(self): return self._inputs @property def outputs(self): return self._outputs @property def top_graph(self): if self._top_graph: return self._top_graph() return None def exprs(self, recursive=True): return ExprFilter(_expr_iter(self, recursive)) def nodes(self, recursive=True): return NodeFilter(_node_iter(self, recursive)) def get_function_by_type(self, func: Callable = None, recursive=True): return self.exprs(recursive).call_function(func) def get_method_by_type(self, method: str = None, recursive=True): return self.exprs(recursive).call_method(method) def get_expr_by_id(self, expr_id: List[int] = None, recursive=True): return self.exprs(recursive).expr_id(expr_id) def get_module_by_type(self, module_cls: Module, recursive=True): assert issubclass(module_cls, Module) return self.nodes(recursive).type(module_cls, ModuleNode) def get_node_by_id(self, node_id: List[int] = None, recursive=True): return self.nodes(recursive).node_id(node_id) def get_node_by_name( self, name: str = None, ignorecase: bool = True, recursive=True ): return self.nodes(recursive).name(name, ignorecase) def _add_input(self, i): self._inputs.append(i) def _add_output(self, o): self._outputs.append(o) def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""): for node, repl_node in repl_dict.items(): assert node in self._inputs or node in self._outputs for i in node.users: if i not in repl_node.users: repl_node.users.append(i) for idx, i in enumerate(self._inputs): if i in repl_dict: self._inputs[idx] = repl_dict[i] for idx, o in enumerate(self._outputs): if o in repl_dict: repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name) self._outputs[idx] = repl_dict[o] for expr in self._exprs: for idx, i in enumerate(expr.inputs): assert isinstance( i._name, str ), "The node ({}) name must be a str".format(i) if i in repl_dict: expr.inputs[idx] = repl_dict[i] elif isinstance(i, TensorNode) and prefix_name not in i._name: if i.top_graph != active_module_tracer().current_scope(): i._name = ( active_module_tracer() .current_scope() ._create_unique_name(prefix_name + i._name.lstrip("_")) ) i._orig_name = "{}{}".format(module_name, i._orig_name) for idx, o in enumerate(expr.outputs): assert isinstance( o._name, str ), "The node ({}) name must be a str".format(i) if o in repl_dict: expr.outputs[idx] = repl_dict[o] expr.outputs[idx].expr = expr elif isinstance(o, TensorNode) and prefix_name not in i._name: if o.top_graph != active_module_tracer().current_scope(): o._name = ( active_module_tracer() .current_scope() ._create_unique_name(prefix_name + o._name.lstrip("_")) ) o._orig_name = "{}{}".format(module_name, o._orig_name) def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: if not isinstance(nodes, Sequence): nodes = (nodes,) ret = list() queue = list(nodes) visited_queue = list() while queue: node = queue.pop() visited_queue.append(node) expr = node.expr if expr not in ret: ret.append(expr) for i in expr.inputs: if i not in queue and i not in visited_queue: queue.append(i) return ret def reset_inputs(self, *args, **kwargs): forma_mnode = self.inputs[0] actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: for c_expr in n.users: if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": call_nodes.append((c_expr, n)) moudle = forma_mnode.owner assert moudle._is_top, "reset_inputs only support the top-level graph" inputs, tree_def = tree_flatten(((moudle, *args), kwargs)) def create_node(val: Tensor): node = Input(type=TensorNode).outputs[0] node.shape = val.shape node.dtype = val.dtype return node formal_node_inputs = [ forma_mnode, ] org_argdef = list(moudle.argdef_graph_map.keys())[0] if call_nodes: org_argdef = call_nodes[0][0].arg_def for v in inputs[1:]: assert isinstance(v, RawTensor) formal_node_inputs.append(create_node(v)) actual_nodes = [] for e, n in call_nodes: e.arg_def = tree_def actual_node_inputs = [ n, ] for v in inputs[1:]: actual_node_inputs.append(create_node(v)) for org_n in e.inputs: org_n.users.pop(e) e.inputs[:] = actual_node_inputs e.const_val = [] actual_nodes.append(actual_node_inputs[1:]) self._inputs[:] = formal_node_inputs moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) # return formal_node_inputs[1:], actual_nodes return formal_node_inputs[1:] def add_input_node(self, shape, dtype="float32", name="args"): forma_mnode = self.inputs[0] actual_mnodes = forma_mnode.actual_node moudle = forma_mnode.owner assert moudle._is_top, "add_input_node only support the top-level graph" call_nodes = [] for n in actual_mnodes: for c_expr in n.users: if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": call_nodes.append(c_expr) def create_node(name=None, is_input: bool = True): if is_input: node = Input(type=TensorNode, name=name).outputs[0] else: node = TensorNode(expr=None, name=None) node.shape = shape node.dtype = dtype return node org_argdef = list(moudle.argdef_graph_map.keys())[0] if call_nodes: org_argdef = call_nodes[0].arg_def args, kwargs = org_argdef.unflatten(self._inputs) formal_inp_node = create_node(self._create_unique_name(name), True) inputs, tree_def = tree_flatten( ((*args, formal_inp_node), kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), ) self._inputs[:] = inputs[:] actual_inp_nodes = [] for e in call_nodes: args, kwargs = e.unflatten_args(e.inputs) args = args + (create_node(False),) inputs, tree_def = tree_flatten( (args, kwargs), is_const_leaf=lambda x: not isinstance(x, (TensorNode, ModuleNode)), ) e.inputs[:] = inputs[:] e.arg_def = tree_def actual_inp_nodes.append(args[-1]) moudle.argdef_graph_map[tree_def] = moudle.argdef_graph_map.pop(org_argdef) moudle.argdef_outdef_map[tree_def] = moudle.argdef_outdef_map.pop(org_argdef) # return formal_inp_node, actual_inp_nodes return formal_inp_node def reset_outputs(self, outputs): outputs, out_def = tree_flatten( outputs, is_leaf=lambda x: isinstance(x, TensorNode), ) forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "reset_outputs only support the top-level graph" actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: for c_expr in n.users: if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": call_nodes.append((c_expr)) def create_node(val: TensorNode, expr: Expr): node = TensorNode(expr) node.shape = val.shape node.dtype = val.dtype return node tree_def = list(moudle.argdef_graph_map.keys())[0] if call_nodes: tree_def = call_nodes[0].arg_def actual_nodes = [] for e in call_nodes: actual_node_outputs = [] for v in outputs: actual_node_outputs.append(create_node(v, e)) e.outputs[:] = actual_node_outputs e.out_def = out_def actual_nodes.append(actual_node_outputs) self._outputs[:] = outputs moudle.argdef_outdef_map[tree_def] = out_def return actual_nodes def add_output_node(self, node: TensorNode): forma_mnode = self.inputs[0] moudle = forma_mnode.owner assert moudle._is_top, "add_output_node only support the top-level graph" actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: for c_expr in n.users: if isinstance(c_expr, CallMethod) and c_expr.method == "__call__": call_nodes.append((c_expr)) def create_node(val: TensorNode, expr: Expr): node = TensorNode(expr) node.shape = val.shape node.dtype = val.dtype return node tree_def = list(moudle.argdef_graph_map.keys())[0] if call_nodes: tree_def = call_nodes[0].arg_def org_out_def = moudle.argdef_outdef_map[tree_def] org_outs = org_out_def.unflatten(self._outputs) outputs, out_def = tree_flatten( (org_outs, node), is_leaf=lambda x: isinstance(x, TensorNode), ) self._outputs[:] = outputs actual_out_nodes = [] for e in call_nodes: actual_node = create_node(node, e) org_outs = org_out_def.unflatten(e.outputs) outputs, out_def = tree_flatten( (org_outs, actual_node), is_leaf=lambda x: isinstance(x, TensorNode), ) e.outputs[:] = outputs e.out_def = out_def actual_out_nodes.append(actual_node) moudle.argdef_outdef_map[tree_def] = out_def return actual_out_nodes def insert_exprs(self, expr: Optional[Expr] = None): if expr is not None: assert expr.top_graph == self, "Expr to insert after is not in graph." return _InsertExprs(self, expr) def replace_node(self, repl_dict: Dict[Node, Node]): while repl_dict: node, repl_node = repl_dict.popitem() # check graph inputs and outputs # assert node not in self.inputs, "Cannot replace inputs" for i, n in enumerate(self.outputs): if n is node: self.outputs[i] = repl_node # update users of node and repl_node # update inputs of expr in node.users graph = repl_node.top_graph assert graph is not None index = graph._exprs.index(repl_node.expr) dep_exprs = self.get_dep_exprs(repl_node) i = 0 while i < len(node.users): n = node.users[i] if n in graph._exprs and index >= graph._exprs.index(n): i += 1 continue if n in dep_exprs: logger.info("Find a loop: ignore this replacement once") logger.info("node: %s" % node.__repr__()) logger.info("expr: %s" % n.__repr__()) i += 1 continue repl_node.users.append(n) node.users.pop(i) idx = n.inputs.index(node) n.inputs[idx] = repl_node def compile(self): """ Delete unused expr. """ dep_exprs = self.get_dep_exprs(self.outputs) i = 0 while i < len(self._exprs): expr = self._exprs[i] if expr in dep_exprs or expr._disable_remove: i += 1 continue for n in expr.inputs: n.users.remove(expr) self._exprs.remove(expr) def interpret(self, *inputs): node2value = {} end_nodes_set = set(self._end_point) endnode2value = {} def get_all_endnode_val(n, v): if n in end_nodes_set: endnode2value[n] = v end_nodes_set.remove(n) return not end_nodes_set return False for n, v in zip(self._inputs, inputs): node2value[n] = v if n in self._watch_point: self._rst[n].append(v) if n in self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for i in self._end_point) for expr in self._exprs: values = expr.interpret(*list(node2value[i] for i in expr.inputs)) if values is not None: for n, v in zip(expr.outputs, values): node2value[n] = v if n in self._watch_point: self._rst[n] = v if self._end_point and get_all_endnode_val(n, v): return list(endnode2value[i] for i in self._end_point) return list(node2value[i] for i in self._outputs) def eval(self, *inputs): assert len(inputs) == len(self._inputs) - 1 inp = [self._inputs[0].owner] + list(inputs) return self.interpret(*inp) def __repr__(self): return self.__format__() def __format__(self, format_spec: str = "") -> str: saved_format_spec = Node.set_format_spec(format_spec) name = "" if self._name: name = "%s.Graph" % self._name res = "{} ({}) {{\n\t{}\n\treturn {}\n}}".format( name, ", ".join(str(i) for i in self._inputs), "\n\t".join("{}".format(str(i)) for i in self._exprs), ", ".join(str(i) for i in self._outputs), ) Node.set_format_spec(saved_format_spec) return res def __getstate__(self): state = self.__dict__.copy() if "_top_graph" in state: state.pop("_top_graph") return state def _get_meth_name(obj, func): tp = obj if isinstance(obj, type) else type(obj) for cls in tp.mro(): for k, v in cls.__dict__.items(): if v == func: return k return None def _wrapped_function(orig_func): @functools.wraps(orig_func) def wrapped_fn(*args, **kwargs): method_func = wrapped_fn if "method_func" in kwargs: method_func = kwargs.pop("method_func") if is_tracing_module(): unset_module_tracing() inputs, tree_def = tree_flatten((args, kwargs)) for i in inputs: if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) meth_name, arg_type = None, None if args: meth_name = _get_meth_name(args[0], method_func) arg_type = args[0] if isinstance(args[0], type) else type(args[0]) if meth_name and arg_type and issubclass(arg_type, RawTensor): self = inputs[0] if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): # only trace Tensor.__new__() when there are tensors in args set_module_tracing() return orig_func(*args, **kwargs) if isinstance(args[1], RawTensor): node = NodeMixin.get(inputs[1]) inputs[1] = copy.copy(inputs[1]) # copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing. NodeMixin.wrap_safe(inputs[1], node) args, kwargs = tree_def.unflatten(inputs) call_node = CallMethod.make(self, meth_name) else: call_node = CallMethod.make(NodeMixin.get(self), meth_name) call_node.add_inputs(inputs[1:]) else: call_node = CallFunction.make(orig_func) call_node.add_inputs(inputs) call_node.arg_def = tree_def rst = orig_func(*args, **kwargs) if meth_name == "__setitem__": rst = self if rst is not None: outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) call_node.out_def = out_def else: outputs = None call_node.add_outputs(outputs) set_module_tracing() return rst return orig_func(*args, **kwargs) return wrapped_fn class TracedModuleBuilder(NodeMixin): _mod = None # type: Module _body = None # type: InternalGraph _is_builtin = None # type: bool _argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] _argdef_outdef_map = None # type: Dict[Treedef, Treedef] nodes = None __builder_attributes__ = [ "_mod", "_body", "_NodeMixin__node", "_is_builtin", "build", "_record_wrapped_nodes", "_argdef_graph_map", "_argdef_outdef_map", "nodes", "__class__", "__dict__", ] def __init__(self, mod, is_top_module=False): super(TracedModuleBuilder, self).__init__() assert isinstance(mod, Module) self._mod = mod self._body = None self._is_top = is_top_module self._is_builtin = ( True if isinstance(mod, (Observer, _FakeQuantize)) else module_tracer.is_builtin(mod) ) if isinstance(self._mod, QATModule): unset_module_tracing() self._check_qat_module(self._mod) set_module_tracing() self._argdef_graph_map = {} self._argdef_outdef_map = {} self.nodes = set() # The builder will be passed to self._mod.forward as 'self' argument. If the 'forward' uses super().xxx to call method of its base classes, the trace procedure will throw exceprion, because the builder doesn't inherit from self._mod.__bases__. # modify self.__class__ and let the builder inherit from TracedModuleBuilder and mod.__class__. self.__class__ = type( "TracedModuleBuilder", (TracedModuleBuilder, mod.__class__), dict(TracedModuleBuilder.__dict__), ) def _check_qat_module(self, qat_module): def isbuiltin(m): return m is None or module_tracer.is_builtin(m) if qat_module.with_act: act_observer = qat_module.act_observer act_fakequant = qat_module.act_fake_quant if not isbuiltin(act_observer) or not isbuiltin(act_fakequant): qparams = ( act_observer.get_qparams() if hasattr(act_observer, "get_qparams") else act_fakequant.get_qparams() ) dtype = ( act_observer.dtype if hasattr(act_observer, "dtype") else act_fakequant.dtype ) qat_module.act_observer = None qat_module.act_fake_quant = TM_FakeQuant(dtype) qat_module.act_fake_quant.set_qparams(qparams) if qat_module.with_weight: weight_observer = qat_module.weight_observer weight_fakequant = qat_module.weight_fake_quant if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant): qparams = ( weight_observer.get_qparams() if hasattr(weight_observer, "get_qparams") else weight_fakequant.get_qparams() ) dtype = ( weight_observer.dtype if hasattr(weight_observer, "dtype") else weight_fakequant.dtype ) qat_module.weight_observer = None qat_module.weight_fake_quant = TM_FakeQuant(dtype) qat_module.weight_fake_quant.set_qparams(qparams) def build(self): if self._is_builtin or isinstance(self._mod, TracedModule): if module_tracer.is_builtin(self._mod) or isinstance( self._mod, TracedModule ): mod_type = type(self._mod) else: assert isinstance(self._mod, (Observer, _FakeQuantize)) mod_type = ( Observer if isinstance(self._mod, Observer) else _FakeQuantize ) for node in self.nodes: node.module_type = mod_type return self._mod else: is_qat = isinstance(self._mod, QATModule) traced_module = TracedModule( self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat ) for _, g in self._argdef_graph_map.items(): g.compile() for k, v in self.__dict__.items(): if k not in TracedModuleBuilder.__builder_attributes__: if isinstance(v, TracedModuleBuilder): v = v.build() setattr(traced_module, k, v) if isinstance(self._mod, QATModule): unset_module_tracing() traced_module.with_act = self._mod.with_act traced_module.with_weight = self._mod.with_weight if not hasattr(traced_module, "act_fake_quant"): traced_module.act_fakequant = None if not hasattr(traced_module, "act_observer"): traced_module.act_observer = None if not hasattr(traced_module, "weight_fake_quant"): traced_module.weight_fakequant = None if not hasattr(traced_module, "weight_observer"): traced_module.weight_observer = None set_module_tracing() return traced_module def _record_wrapped_nodes(self, node): self.nodes.add(node) def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) # prepare args and kwargs for inner graph if "method_func" in kwargs: kwargs.pop("method_func") def mark_constant(x): node = NodeMixin.get(x, None) if node is None: # capture as constant NodeMixin.wrap(x, lambda: Constant.make(x)) inputs, tree_def = tree_flatten(((self, *args), kwargs)) for i in inputs: mark_constant(i) callnode = CallMethod.make(NodeMixin.get(self)) callnode.add_inputs(inputs[1:]) callnode.arg_def = tree_def if ( self._is_builtin or tree_def in self._argdef_graph_map or isinstance(self._mod, TracedModule) ): unset_module_tracing() rst = self._mod(*args, **kwargs) outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) set_module_tracing() if self._is_builtin: self._body = None elif tree_def in self._argdef_graph_map: self._body = self._argdef_graph_map[tree_def] else: self._mod._is_top = False self._body = self._mod.graph else: self_node = None orig_self = NodeMixin.get(self) top_graph = active_module_tracer().current_scope() graph_prefix_name = top_graph._name if top_graph._prefix_name: graph_prefix_name = "{}_{}".format( top_graph._prefix_name, graph_prefix_name.lstrip("_") ) module_name = orig_self._orig_name if top_graph._module_name: module_name = "{}.{}".format(top_graph._module_name, module_name) self._body = InternalGraph( orig_self._name, prefix_name=graph_prefix_name, module_name=module_name ) active_module_tracer().push_scope(self._body) # rebind self to new input node if self_node: NodeMixin.wrap_safe(self, self_node) active_module_tracer().current_scope()._add_input(self_node) else: NodeMixin.wrap_safe( self, self_node if self_node else Input.make("self", NodeMixin.get_wrapped_type(self), ""), ) origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph index_args, index_kwargs = tree_def.unflatten( [ ArgsIndex(0), *list(ArgsIndex(i + 1) for i in range(len(origin_inp_node))), ] ) key2idx = getcallargs(type(self._mod).forward, *index_args, **index_kwargs) idx2key = {} for k, v in key2idx.items(): if isinstance(v, ArgsIndex): idx2key[v.index] = k else: flatten_argidx, _ = tree_flatten(v) for _i, v in enumerate(flatten_argidx): if isinstance(v, ArgsIndex): idx2key[v.index] = k + "_%d" % _i def wrap(x, name): if isinstance(x, (RawTensor, NodeMixin)): NodeMixin.wrap( x, lambda: Input.make( type=NodeMixin.get_wrapped_type(x), name=name ), ) return x args = [self] for i, v in enumerate(inputs[1:]): args.append(wrap(v, idx2key[i + 1])) args, kwargs = tree_def.unflatten(args) active_module_tracer().patcher.auto_patch( getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) ) rst = type(self._mod).forward(*args, **kwargs) if _convert_node_flag(): rst = _node_to_tensor(rst)[0][0] outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) for i in ( outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) ): active_module_tracer().current_scope()._add_output(NodeMixin.get(i)) NodeMixin.wrap_safe(self, orig_self) for arg, node in zip(inputs[1:], origin_inp_node): if node: NodeMixin.wrap_safe(arg, node) active_module_tracer().pop_scope() # rebind output to outer graph callnode.out_def = out_def callnode.add_outputs(outputs) self._argdef_graph_map[callnode.arg_def] = self._body self._argdef_outdef_map[callnode.arg_def] = out_def return rst def __setattr__(self, name, value): object.__setattr__(self, name, value) def __repr__(self): return repr(self._mod) def __getattr__(self, name): if name not in self._mod.__dict__: attr = getattr(type(self._mod), name).__get__(self, type(self)) else: attr = getattr(self._mod, name) full_name = None if id(attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(attr)] if isinstance(attr, Module): attr = TracedModuleBuilder(attr) if isinstance(attr, (Module, RawTensor)): setattr(self, name, attr) active_module_tracer().id2name[id(attr)] = full_name if full_name: scope_name = active_module_tracer().current_scope()._module_name if scope_name: full_name = full_name[len(scope_name) + 1 :] else: full_name = name else: full_name = name NodeMixin.wrap( attr, lambda: GetAttr.make( NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr), orig_name=full_name, ), ) return attr def __getattribute__(self, name): if name in TracedModuleBuilder.__builder_attributes__: return object.__getattribute__(self, name) else: wrapped = object.__getattribute__(self, name) if name in self._mod.__dict__: mod_attr = getattr(self._mod, name) if not isinstance(mod_attr, Module) and wrapped is not mod_attr: wrapped = mod_attr setattr(self, name, wrapped) if isinstance(mod_attr, Module): assert mod_attr is wrapped._mod else: assert mod_attr is wrapped full_name = None if id(mod_attr) in active_module_tracer().id2name: full_name = active_module_tracer().id2name[id(mod_attr)] scope_name = active_module_tracer().current_scope()._module_name if full_name and scope_name: full_name = full_name[len(scope_name) + 1 :] else: full_name = name else: full_name = name # assert not self._is_builtin if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( wrapped, lambda: GetAttr.make( NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(wrapped), orig_name=full_name, ), ) return wrapped class _expr_iter: def __init__(self, graph: InternalGraph, recursive: bool = True): self.graph = graph self.recursive = recursive def __iter__(self): for expr in self.graph._exprs: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr if self.recursive and expr.graph is not None: yield from expr.graph.exprs(self.recursive) else: yield expr class _node_iter: def __init__(self, graph: InternalGraph, recursive: bool = True) -> None: nodes = [] node_ids = set() for expr in graph.exprs(recursive): for n in expr.inputs + expr.outputs: if n._id in node_ids: continue nodes.append(n) node_ids.add(n._id) self.nodes = list(sorted(nodes, key=lambda x: x._id)) def __iter__(self): for node in self.nodes: yield node class BaseFilter: def __init__(self, expr_iter: Iterable): self._iter = expr_iter def __iter__(self): return iter(self._iter) def as_list(self): return list(self) def as_dict(self): return collections.OrderedDict((i._id, i) for i in self) def as_unique(self): rst = self.as_list() assert len(rst) == 1, "{} elements found".format(len(rst)) (expr,) = self return expr def as_count(self): return sum(1 for _ in self) class ExprFilter(BaseFilter): def call_function(self, func): return ExprFilterCallFunction(self, func) def call_method(self, method): return ExprFilterCallMethod(self, method) def expr_id(self, expr_id: List[int]): return ExprFilterExprId(self, expr_id) class NodeFilter(BaseFilter): def type(self, owner_type, node_type): return NodeFilterType(self, owner_type, node_type) def node_id(self, node_id: List[int]): return NodeFilterNodeId(self, node_id) def name(self, name: str, ignorecase: bool = True): return NodeFilterName(self, name, ignorecase) class NodeFilterType(NodeFilter): def __init__(self, expr_iter, owner_type, node_type): super().__init__(expr_iter) self.owner_type = owner_type self.node_type = node_type def __iter__(self): for node in self._iter: if not isinstance(node, self.node_type): continue if not hasattr(node, "owner"): continue if isinstance(node.owner, self.owner_type): yield node class NodeFilterNodeId(NodeFilter): def __init__(self, expr_iter, node_id: List[int]): super().__init__(expr_iter) if not isinstance(node_id, Sequence): node_id = [node_id] self.node_id = node_id def __iter__(self): for node in self._iter: if node._id in self.node_id: yield node class NodeFilterName(NodeFilter): _re = None def __init__(self, node_iter, pattern, ignorecase): super().__init__(node_iter) self.pattern = pattern self._re = self.make_re(pattern, ignorecase) @classmethod def make_re(cls, pattern, ignorecase=True): assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) assert isinstance(ignorecase, bool) flags = 0 if ignorecase: flags |= re.IGNORECASE return re.compile(fnmatch.translate(pattern), flags=flags) def __iter__(self): for i in self._iter: graph = i.top_graph name = "{}_{}".format(graph._name, i._name.lstrip("_")) if graph._prefix_name: name = "{}_{}".format(graph._prefix_name, name.lstrip("_")) if self.pattern == name or self._re.match(name): yield i class ExprFilterCallFunction(ExprFilter): def __init__(self, expr_iter, func: Callable = None): super().__init__(expr_iter) self.func = func def __iter__(self): for expr in self._iter: if not isinstance(expr, CallFunction): continue if self.func is None or expr.func == self.func: yield expr class ExprFilterCallMethod(ExprFilter): def __init__(self, expr_iter, method: str = None): super().__init__(expr_iter) self.method = method def __iter__(self): for expr in self._iter: if not isinstance(expr, CallMethod): continue if self.method is None or expr.method == self.method: yield expr class ExprFilterExprId(ExprFilter): def __init__(self, expr_iter, expr_id: List[int]): super().__init__(expr_iter) if not isinstance(expr_id, Sequence): expr_id = [expr_id] self.expr_id = expr_id def __iter__(self): for expr in self._iter: if expr._id in self.expr_id: yield expr class TracedModule(Module): """ `TracedModule` is the Module created by tracing normal module. It owns an argdef to graph(InternalGraph) map. The forward method of `TracedModule` will get a graph from `argdef_graph_map` according to the argdef of input args/kwargs and interpret it. """ # m_node = None # type: ModuleNode argdef_graph_map = None argdef_outdef_map = None def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False): super(TracedModule, self).__init__() self.argdef_graph_map = argdef_graph_map self.argdef_outdef_map = argdef_outdef_map self._is_top = is_top self.watch_points = [] self.watch_node_value = {} self.end_points = [] self.is_qat = is_qat def forward(self, *args, **kwargs): inputs, treedef = tree_flatten(((self, *args), kwargs)) assert treedef in self.argdef_graph_map inputs = filter( lambda i: isinstance(i, (Module, TracedModuleBuilder, RawTensor)), inputs ) # allow TracedModuleBuilder for retrace. outputs = self.argdef_graph_map[treedef].interpret(*inputs) if self.watch_points: self.watch_node_value = {} for n in self.watch_points: self.watch_node_value[n] = n.top_graph._rst.pop(n) if self.end_points: return outputs out_def = self.argdef_outdef_map[treedef] outputs = out_def.unflatten(outputs) return outputs def set_watch_points(self, nodes): if not isinstance(nodes, Sequence): nodes = [nodes] self.watch_points = nodes for n in nodes: n.top_graph._watch_point.append(n) def clear_watch_points(self): for n in self.watch_points: n.top_graph._watch_point = [] self.watch_points = [] self.watch_node_value = {} def set_end_points(self, nodes): if not isinstance(nodes, Sequence): nodes = [nodes] self.end_points = nodes graphs = list(self.argdef_graph_map.values()) for n in nodes: assert n.top_graph in graphs n.top_graph._end_point.append(n) def clear_end_points(self): for n in self.end_points: n.top_graph._end_point = [] self.end_points = [] @property def graph(self) -> InternalGraph: if self._is_top: self._update_ref() assert len(self.argdef_graph_map) == 1 return list(self.argdef_graph_map.values())[0] def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None): for inp_def, graph in self.argdef_graph_map.items(): if top_graph is not None: graph._top_graph = weakref.ref(top_graph) for n in graph._inputs + graph.outputs: n._top_graph = weakref.ref(graph) graph._inputs[0]._owner = weakref.ref(self) for i, n in enumerate(graph._inputs): n.actual_node = [] if actual_node_map is not None and inp_def in actual_node_map.keys(): n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i]) node2obj = {} next_actual_node_map = collections.defaultdict( lambda: collections.defaultdict(list) ) node2obj[graph._inputs[0]] = self for expr in graph._exprs: for n in expr.inputs + expr.outputs: n._top_graph = weakref.ref(graph) expr._top_graph = weakref.ref(graph) if isinstance(expr, GetAttr) and isinstance( expr.outputs[0], ModuleNode ): obj = getattr(node2obj[expr.inputs[0]], expr.name) expr.outputs[0]._owner = weakref.ref(obj) node2obj[expr.outputs[0]] = obj if isinstance(expr, Constant) and isinstance( expr.outputs[0], ModuleNode ): obj = expr.value expr.outputs[0]._owner = weakref.ref(obj) node2obj[expr.outputs[0]] = obj if ( isinstance(expr, CallMethod) and expr.method == "__call__" and isinstance(expr.inputs[0], ModuleNode) ): obj = node2obj[expr.inputs[0]] if expr.arg_def is not None: next_actual_node_map[obj][expr.arg_def].append(expr.inputs) for obj in node2obj.values(): if obj is self: continue mnode_map = None if obj in next_actual_node_map.keys(): mnode_map = next_actual_node_map[obj] if isinstance(obj, TracedModule): obj._update_ref(mnode_map, graph) def flatten(self): """ Get a new module, which eliminates ``GetAttr`` and has no hierarchy. :return: :class:`TracedModule` """ new_module = copy.deepcopy(self) assert active_module_tracer() is None id2name = _init_id2name(new_module, "self") set_active_module_tracer(module_tracer(lambda x: x, {})) active_module_tracer().push_scope(new_module.graph) def _flatten_subgraph( graph: InternalGraph, module: Module, call=None, prefix_name="", module_name="", ): if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_": prefix_name += "_" if isinstance(module_name, str) and module_name: module_name += "." if graph is None or module.is_qat: assert not isinstance(module, TracedModule) or module.is_qat const = Constant(module, id2name[id(module)]) m_node = call.inputs[0] if m_node.top_graph != active_module_tracer().current_scope(): m_node._name = ( active_module_tracer() .current_scope() ._create_unique_name(prefix_name) ) m_node._orig_name = id2name[id(module)][5:] const.outputs[0] = m_node const.outputs[0].expr = const return [const, call] if call is not None: graph = copy.deepcopy(graph) exprs = [] node2obj = {} node2obj[graph._inputs[0]] = module if call: node2obj[call.inputs[0]] = module # replace inputs for submodule's exprx if call: repl_dict = dict(zip(graph._inputs, call.inputs)) for ind, out in enumerate(graph.outputs): if isinstance(out.expr, Input): assert out in repl_dict call_out = call.outputs[ind] for expr in call.outputs[ind].users: for index, inp in enumerate(expr.inputs): if inp is call_out: expr.inputs[index] = repl_dict[out] continue repl_dict[out] = call.outputs[ind] graph._replace_inputs_outputs(repl_dict, prefix_name, module_name) for expr in graph._exprs: if isinstance(expr, GetAttr): # replace GetAttr with Constant if isinstance(expr.outputs[0], TensorNode): const = Constant(getattr(node2obj[expr.inputs[0]], expr.name)) const.outputs = expr.outputs const.outputs[0].expr = const exprs.append(const) elif isinstance(expr.outputs[0], ModuleNode): node2obj[expr.outputs[0]] = getattr( node2obj[expr.inputs[0]], expr.name ) elif isinstance(expr, CallMethod): obj_node = expr.inputs[0] if isinstance(obj_node, ModuleNode): pre_expr = expr.inputs[0].expr if isinstance(pre_expr, GetAttr): (obj,) = pre_expr.interpret(node2obj[pre_expr.inputs[0]]) expr_graph = ( obj.argdef_graph_map[expr.arg_def] if hasattr(obj, "argdef_graph_map") else None ) exprs.extend( _flatten_subgraph( expr_graph, obj, expr, prefix_name + obj_node._name.lstrip("_"), module_name + obj_node._orig_name, ) ) else: # module has been replaced. assert isinstance(pre_expr, Constant) exprs.append(expr) else: exprs.append(expr) else: exprs.append(expr) if call is not None: for i in call.inputs: i.users.remove(call) return exprs new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) new_module.graph.compile() set_active_module_tracer(None) for _id, expr in enumerate(new_module.graph._exprs): expr._id = _id total_node_id = 0 for i in new_module.graph._inputs: i._id = total_node_id total_node_id += 1 for expr in new_module.graph._exprs: for o in expr.outputs: o._id = total_node_id total_node_id += 1 return new_module def __getstate__(self): d = self.__dict__ for k in Module.__dict__: d.pop(k, None) return d def cpp_apply_module_trace(opdef, *args): return Apply.apply_module_trace_hook(opdef, *args) def register_as_builtin(mod_cls: Type[Module]) -> None: """ Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module. param mod_cls: the Module class which will be threated as builtin module in tracing """ module_tracer.register_as_builtin(mod_cls) def wrap(func: Callable): """ Call this function to register func as a builtin function. """ assert callable(func), "func must be a callable" assert hasattr(func, "__code__") fn_name = func.__code__.co_name currentframe = inspect.currentframe() assert currentframe is not None f = currentframe.f_back assert f is not None assert ( f.f_code.co_name == "" ), "wrap must be called at the top level of a module" Patcher._builtin_functions.append((f.f_globals, fn_name)) return func def _register_all_builtin_module(): for sub_mod in [M, M.qat, M.quantized]: for m in getmembers(sub_mod): if ( isclass(m[1]) and issubclass(m[1], M.Module) and m[1] is not M.Sequential and m[1] is not M.ModuleList ): module_tracer.register_as_builtin(m[1]) module_tracer.register_as_builtin(Observer) module_tracer.register_as_builtin(MinMaxObserver) module_tracer.register_as_builtin(SyncMinMaxObserver) module_tracer.register_as_builtin(ExponentialMovingAverageObserver) module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver) module_tracer.register_as_builtin(HistogramObserver) module_tracer.register_as_builtin(PassiveObserver) module_tracer.register_as_builtin(LSQ) module_tracer.register_as_builtin(TQT) module_tracer.register_as_builtin(FakeQuantize) module_tracer.register_as_builtin(TM_FakeQuant) def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: """ Traces module ``mod`` and returns corresponding TracedModule. param mod: the module will be converted to TracedModule param input: the positional arguments passed to forward method of ``mod`` param kwargs: the keyword arguments passed to forward method of ``mod`` """ assert active_module_tracer() is None assert isinstance(mod, Module) try: use_sym_shape = set_symbolic_shape(True) set_module_tracing() set_active_module_tracer( module_tracer(_wrapped_function, _init_id2name(mod, "self")) ) with active_module_tracer().patcher: global_scope = InternalGraph(name="") active_module_tracer().push_scope(global_scope) builder = TracedModuleBuilder(mod, True) name = mod._name if mod._name else mod.__class__.__name__ NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self")) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " if isinstance(i, RawTensor): NodeMixin.wrap_safe( i, Input.make("arg_{}".format(_), NodeMixin.get_wrapped_type(i)) ) builder(*args, **kwargs) active_module_tracer().pop_scope() return builder.build() finally: set_symbolic_shape(use_sym_shape) set_active_module_tracer(None) unset_module_tracing()