From a3f9073c2c0d982c4bdcd4b58f6a510293009103 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 10 Aug 2021 16:30:04 +0800
Subject: [PATCH] feat(traced_module): update graph transform and add
 _module_name

GitOrigin-RevId: ef63ae0fd0dcdd69c3566e19f8a34d85422a1e1e
---
 .../experimental/traced_module/__init__.py         |   1 -
 .../megengine/experimental/traced_module/expr.py   | 106 ++--
 .../experimental/traced_module/module_tracer.py    |  42 +-
 .../megengine/experimental/traced_module/node.py   |  81 ++-
 .../megengine/experimental/traced_module/pytree.py |   2 +-
 .../experimental/traced_module/traced_module.py    | 543 ++++++++++++++++-----
 .../test/unit/traced_module/test_modification.py   |   7 +-
 7 files changed, 582 insertions(+), 200 deletions(-)

diff --git a/imperative/python/megengine/experimental/traced_module/__init__.py b/imperative/python/megengine/experimental/traced_module/__init__.py
index bda9fe92..1a34651b 100644
--- a/imperative/python/megengine/experimental/traced_module/__init__.py
+++ b/imperative/python/megengine/experimental/traced_module/__init__.py
@@ -14,7 +14,6 @@ from .traced_module import (
     register_as_builtin,
     trace_module,
     wrap,
-    wrap_tensors,
 )
 
 _register_all_builtin_module()
diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py
index 24bbee79..a5d119c5 100644
--- a/imperative/python/megengine/experimental/traced_module/expr.py
+++ b/imperative/python/megengine/experimental/traced_module/expr.py
@@ -33,17 +33,6 @@ def rstrip(s: str, __chars: str):
     return s
 
 
-def lstrip(s: str, __chars: str):
-    __chars = re.escape(__chars)
-    s = re.sub(r"^(?:%s)+(?P<s>.*)$" % __chars, "\g<s>", s)
-    return s
-
-
-def strip(s: str, __chars: str):
-    s = lstrip(rstrip(s, __chars), __chars)
-    return s
-
-
 class Expr:
     """
     ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
@@ -89,27 +78,40 @@ class Expr: outputs = (outputs,) name = None + orig_name = None if isinstance(self, CallMethod): name = self.inputs[0]._name - assert name is not None + orig_name = self.inputs[0]._orig_name + assert isinstance(name, str), "The name of ({}) must be a str".format( + self.inputs[0] + ) + assert isinstance( + orig_name, str + ), "The orig_name of ({}) must be a str".format(self.inputs[0]) name = rstrip(name, "_out") if self.method == "__call__": name += "_out" + orig_name += "_out" else: - strip_method = strip(self.method, "_") + strip_method = self.method.strip("_") name = "%s_out" % strip_method + orig_name = name elif isinstance(self, CallFunction): name = self.func.__name__ + "_out" elif isinstance(self, Apply): name = str(self.opdef).lower() + "_out" for i in outputs: - assert isinstance(i, RawTensor) + assert isinstance(i, RawTensor), "The output must be a Tensor" o_name = ( active_module_tracer().current_scope()._create_unique_name(name) ) self.outputs.append( - NodeMixin.get_wrapped_type(i)(expr=self, name=o_name) + NodeMixin.get_wrapped_type(i)( + expr=self, + name=o_name, + orig_name=orig_name if orig_name else o_name, + ) ) for i, node in zip(outputs, self.outputs,): @@ -125,21 +127,26 @@ class Expr: else: return inputs, {} - def _replace_nodes(self, repl_dict: Dict[Node, Node], nodes: List[Node]): + def replace_inputs(self, repl_dict: Dict[Node, Node]): while repl_dict: node, repl_node = repl_dict.popitem() assert type(node) == type(repl_node) - assert node in nodes - index = nodes.index(node) - nodes[index] = repl_node + assert node in self.inputs, "({}) is not in the ({})".format(node, self) + assert ( + repl_node.top_graph == node.top_graph + ), "({}) and ({}) are not in the same graph".format(node, repl_node) + graph = self.top_graph + repl_expr_idx = graph._exprs.index(repl_node.expr) + self_idx = graph._exprs.index(self) + assert ( + repl_expr_idx < self_idx + ), "({}) must be generated before ({})".format(repl_node, self) + idx = self.inputs.index(node) + self.inputs[idx] = repl_node + user_idx = node.users.index(self) + assert user_idx >= 0 + node.users.pop(user_idx) repl_node.users.append(self) - node.users.pop(self) - - def replace_inputs(self, repl_dict: Dict[Node, Node]): - self._replace_nodes(repl_dict, self.inputs) - - def replace_outputs(self, repl_dict: Dict[Node, Node]): - self._replace_nodes(repl_dict, self.outputs) @property def kwargs(self): @@ -159,7 +166,8 @@ class Expr: def __getstate__(self): state = self.__dict__.copy() - state.pop("_top_graph", None) + if "_top_graph" in state: + state.pop("_top_graph") return state @@ -167,12 +175,14 @@ class Expr: class Input(Expr): name = None - def __init__(self, name=None, type=None): + def __init__(self, name=None, type=None, orig_name=None): super().__init__() self.inputs = [] node_cls = type if type else Node + if orig_name is None: + orig_name = name self.outputs = [ - node_cls(self, name=name), + node_cls(self, name=name, orig_name=orig_name), ] self.name = name @@ -184,7 +194,7 @@ class Input(Expr): active_module_tracer().current_scope()._create_unique_name(oup_node._name) ) oup_node._name = name - active_module_tracer().current_scope().add_input(oup_node) + active_module_tracer().current_scope()._add_input(oup_node) return expr.outputs[0] def __repr__(self): @@ -195,7 +205,7 @@ class Input(Expr): class GetAttr(Expr): name = None - def __init__(self, module, name, type=None): + def __init__(self, module, name, type=None, orig_name=None): super().__init__() assert isinstance(module, 
ModuleNode) self.inputs = [ @@ -205,7 +215,7 @@ class GetAttr(Expr): self.name = name node_cls = type if type else Node self.outputs = [ - node_cls(self, name=name), + node_cls(self, name=name, orig_name=orig_name), ] @classmethod @@ -218,7 +228,7 @@ class GetAttr(Expr): module = module.expr.inputs[0] oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name) expr.outputs[0]._name = oup_name - active_module_tracer().current_scope().insert(expr) + active_module_tracer().current_scope()._insert(expr) return expr.outputs[0] def interpret(self, *inputs): @@ -255,7 +265,7 @@ class CallMethod(Expr): @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) - active_module_tracer().current_scope().insert(expr) + active_module_tracer().current_scope()._insert(expr) return expr @property @@ -315,7 +325,7 @@ class Apply(Expr): @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) - active_module_tracer().current_scope().insert(expr) + active_module_tracer().current_scope()._insert(expr) return expr def interpret(self, *inputs): @@ -382,7 +392,7 @@ class CallFunction(Expr): @classmethod def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) - active_module_tracer().current_scope().insert(expr) + active_module_tracer().current_scope()._insert(expr) return expr def interpret(self, *inputs): @@ -423,7 +433,7 @@ class Constant(Expr): self.inputs = [] node_cls = NodeMixin.get_wrapped_type(c) self.outputs = [ - node_cls(self, name=name), + node_cls(self, name=name, orig_name=name), ] self.outputs[0]._name = name if name else "const_" + str(self._id) @@ -431,9 +441,23 @@ class Constant(Expr): def make(cls, *args, **kwargs): expr = cls(*args, **kwargs) name = "const_module" if isinstance(expr.value, Module) else "const_tensor" - name = active_module_tracer().current_scope()._create_unique_name(name) + full_name = name + if ( + isinstance(expr.value, RawTensor) + and id(expr.value) in active_module_tracer().id2name + ): + full_name = active_module_tracer().id2name[id(expr.value)] + scope_name = active_module_tracer().current_scope()._module_name + if full_name and scope_name: + full_name = ("self." 
+ full_name)[len(scope_name) + 1 :] + else: + full_name = name + else: + full_name = name + name = active_module_tracer().current_scope()._create_unique_name(full_name) expr.outputs[0]._name = name - active_module_tracer().current_scope().insert(expr) + expr.outputs[0]._orig_name = full_name + active_module_tracer().current_scope()._insert(expr) return expr.outputs[0] def interpret(self, *inputs): @@ -453,7 +477,9 @@ class Constant(Expr): ) def __getstate__(self): - state = super().__getstate__() + state = self.__dict__.copy() + if "_top_graph" in state: + state.pop("_top_graph") if isinstance(self.value, RawTensor): state["value"] = Tensor(self.value) return state diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py index 7ef693b5..49632d29 100644 --- a/imperative/python/megengine/experimental/traced_module/module_tracer.py +++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py @@ -84,6 +84,34 @@ BUILTIN_ARRAY_METHOD = [ "__setitem__", ] +BUILTIN_TENSOR_WRAP_METHOD = [ + "T", + "to", + "size", + "shape", + "detach", + "device", + "dtype", + "grad", + "item", + "name", + "ndim", + "numpy", + "qparams", + "set_value", + "reset_zero", + "requires_grad", + "_reset", + "_isscalar", + "_setscalar", + "_tuple_shape", + "_unsetscalar", +] + + +def get_tensor_wrapable_method(): + return BUILTIN_TENSOR_WRAP_METHOD + BUILTIN_ARRAY_METHOD + def active_module_tracer(): return _active_module_tracer @@ -101,9 +129,10 @@ class module_tracer: _active_scopes = None - def __init__(self, wrap_fn): + def __init__(self, wrap_fn, id2name): self._active_scopes = [] self.patcher = Patcher(wrap_fn) + self.id2name = id2name @classmethod def register_as_builtin(cls, mod): @@ -127,6 +156,10 @@ class module_tracer: return None +class NotExist: + pass + + class PatchedFn: frame_dict = None name = None @@ -138,14 +171,17 @@ class PatchedFn: self.origin_fn = ( self.frame_dict[name] if isinstance(frame_dict, collections.abc.Mapping) - else getattr(frame_dict, name) + else getattr(frame_dict, name, NotExist) ) def set_func(self, func): if isinstance(self.frame_dict, collections.abc.Mapping): self.frame_dict[self.name] = func else: - setattr(self.frame_dict, self.name, func) + if func is not NotExist: + setattr(self.frame_dict, self.name, func) + else: + delattr(self.frame_dict, self.name) class Patcher: diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index fb64a8bf..15c89e01 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -30,14 +30,17 @@ class Node: _id = None _top_graph = None # type: weakref.ReferenceType _name = None + _orig_name = None _format_spec = "" - def __init__(self, expr: "Expr", name: str = None): + def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): self.expr = expr self.users = [] # List[Expr] self._id = Node.__total_id Node.__total_id += 1 self._name = name + self._orig_name = orig_name + self.actual_node = [] # type: List[Node] def __setstate__(self, d): self.__dict__ = d @@ -48,7 +51,7 @@ class Node: return self.__format__(format_spec) def __format__(self, format_spec: str) -> str: - if format_spec == "" or format_spec is None: + if not format_spec: format_spec = Node._format_spec name = self._name if name is None: @@ -100,9 +103,8 @@ class ModuleNode(Node): 
module_type = Module # type: Type[Module] _owner = None # type: weakref.ReferenceType - def __init__(self, expr: "Expr", name: str = None): - super().__init__(expr, name) - self.actual_mnode = [] + def __init__(self, expr: "Expr", name: str = None, orig_name: str = None): + super().__init__(expr, name, orig_name) def __getstate__(self): return { @@ -110,6 +112,7 @@ class ModuleNode(Node): "users": self.users, "_id": self._id, "_name": self._name, + "_orig_name": self._orig_name, "module_type": self.module_type, } @@ -125,23 +128,67 @@ class TensorNode(Node): ``TensorNode`` represents the Tensor objects. """ - shape = None # type: Tuple[int] - dtype = None # type: numpy.dtype - qparams = None - device = None + _shape = None # type: Tuple[int] + _dtype = None # type: numpy.dtype + _qparams = None + _device = None + _value = None # type: Tensor def __getstate__(self): return { "expr": self.expr, "users": self.users, "_id": self._id, - "qparams": self.qparams, - "shape": self.shape, - "dtype": self.dtype, - "device": self.device, + "_qparams": self._qparams, + "_shape": self._shape, + "_dtype": self._dtype, + "_device": self._device, "_name": self._name, + "_orig_name": self._orig_name, } + @property + def shape(self): + return self._shape + + @shape.setter + def shape(self, shape): + self._shape = shape + + @property + def dtype(self): + return self._dtype + + @dtype.setter + def dtype(self, dtype): + self._dtype = dtype + + @property + def device(self): + return self._device + + @device.setter + def device(self, device): + self._device = device + + @property + def qparams(self): + return self._qparams + + @qparams.setter + def qparams(self, qparams): + self._qparams = qparams + + @property + def value(self): + return self._value + + @value.setter + def value(self, value): + if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None: + setattr(value, "_NodeMixin__node", None) + self._value = value + class NodeMixin(abc.ABC): __node = None @@ -156,13 +203,13 @@ class NodeMixin(abc.ABC): assert isinstance(node, TensorNode) assert isinstance(value, RawTensor) if isinstance(value, RawTensor): - node.dtype = value.dtype - node.shape = ( + node._dtype = value.dtype + node._shape = ( value._tuple_shape if isinstance(value, Tensor) else value.shape ) - node.device = value.device + node._device = value.device if hasattr(value, "_qparams") and value._qparams is not None: - node.qparams = value.qparams + node._qparams = value.qparams @classmethod def wrap(cls, value, node): diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py index 8382adc8..686b651e 100644 --- a/imperative/python/megengine/experimental/traced_module/pytree.py +++ b/imperative/python/megengine/experimental/traced_module/pytree.py @@ -133,7 +133,7 @@ def _is_leaf(obj): def _leaf_type(node): if isinstance(node, (RawTensor, TensorNode)): return (Tensor, TensorNode, ArgsIndex) - elif isinstance(node, (NodeMixin, Module)): + elif isinstance(node, (NodeMixin, Module, ModuleNode)): return (Module, ModuleNode, NodeMixin, ArgsIndex) else: return (type(node), ArgsIndex) diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index bb5cce04..66ee1fa3 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -9,14 +9,19 @@ import 
builtins import collections import copy +import ctypes import fnmatch import functools +import inspect import keyword import re import weakref from inspect import getcallargs, getmembers, isclass, ismethod +from itertools import chain from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union +from megengine import tensor + from ... import functional as F from ... import get_logger from ... import module as M @@ -44,8 +49,10 @@ from ...tensor import Tensor from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input from .fake_quant import FakeQuantize as TM_FakeQuant from .module_tracer import ( + PatchedFn, Patcher, active_module_tracer, + get_tensor_wrapable_method, module_tracer, set_active_module_tracer, ) @@ -70,46 +77,267 @@ def _is_leaf(node): return isinstance(node, RawTensor) -def wrap_tensors(tensors: Tensor, nodes: TensorNode): - inp_tensors = copy.deepcopy(tensors) - inp_tensors, inp_def_v = tree_flatten(inp_tensors) - inp_nodes, inp_def_n = tree_flatten(nodes) - for v, n in zip(inp_tensors, inp_nodes): - if isinstance(n, TensorNode) and isinstance(v, Tensor): - NodeMixin.wrap_safe(v, n) - return inp_def_v.unflatten(inp_tensors) +_enable_node_to_tensor = False + + +def _convert_node_flag(): + return _enable_node_to_tensor + + +def _set_convert_node_flag(flag: bool = False): + global _enable_node_to_tensor + pre_flag = _enable_node_to_tensor + _enable_node_to_tensor = flag + return pre_flag + + +def _node_to_tensor(*args, **kwargs): + tensors = [] + nodes, tree_def = tree_flatten((args, kwargs)) + for n in nodes: + if isinstance(n, TensorNode): + if n.top_graph is not None: + active_module_tracer().current_scope()._add_input(n) + value = n.value + if value is None: + flag = _set_convert_node_flag(False) + unset_module_tracing() + value = F.zeros(shape=n._shape, dtype=n._dtype) + set_module_tracing() + _set_convert_node_flag(flag) + orig_n = NodeMixin.get(value, None) + if orig_n is None or "setitem" not in orig_n._name: + NodeMixin.wrap_safe(value, n) + tensors.append(value) + else: + tensors.append(n) + tensors = tree_def.unflatten(tensors) + return tensors + + +def _tensor_to_node(tensors): + if tensors is None: + return None + nodes = [] + tensors, out_def = tree_flatten(tensors) + for t in tensors: + if isinstance(t, Tensor): + n = NodeMixin.get(t, None) + if isinstance(n, TensorNode): + n.value = t + nodes.append(n) + else: + nodes.append(t) + else: + nodes.append(t) + nodes = out_def.unflatten(nodes) + return nodes + + +def _wrap_method_to_tensor_node(): + def _any_method(name): + def _any(*args, **kwargs): + args, kwargs = _node_to_tensor(*args, **kwargs) + attr = getattr(args[0], name) + outs = attr + if callable(attr): + outs = attr(*(args[1:]), **kwargs) + if name == "__setitem__": + _node_to_tensor(outs) + return None + outs = _tensor_to_node(outs) + return outs + + return _any + + tensor_method_patch = [] + for method in get_tensor_wrapable_method(): + patch = PatchedFn(TensorNode, method) + if type(getattr(Tensor, method)) == property: + patch.set_func(property(_any_method(method))) + else: + patch.set_func(_any_method(method)) + tensor_method_patch.append(patch) + return tensor_method_patch + + +def _convert_node_and_tensor(orig_func): + @functools.wraps(orig_func) + def _convert(*args, **kwargs): + if _convert_node_flag() and is_tracing_module(): + args, kwargs = _node_to_tensor(*args, **kwargs) + rst = orig_func(*args, **kwargs, method_func=_convert) + rst = _tensor_to_node(rst) + return rst + else: + rst = 
orig_func(*args, **kwargs) + return rst + + return _convert + + +def _wrap_mnode_getattr(orig_getattr): + @functools.wraps(orig_getattr) + def wraped_fn(self, name): + obj = self.owner + if self.top_graph is not None: + active_module_tracer().current_scope()._add_input(self) + attr = getattr(obj, name) + node = attr + full_name = None + if id(attr) in active_module_tracer().id2name: + full_name = active_module_tracer().id2name[id(attr)] + + if not isinstance(attr, TracedModuleBuilder): + if isinstance(attr, Module): + attr = TracedModuleBuilder(attr) + setattr(obj, name, attr) + active_module_tracer().id2name[id(attr)] = full_name + + if isinstance(attr, (NodeMixin, RawTensor)): + if full_name: + scope_name = active_module_tracer().current_scope()._module_name + if scope_name: + full_name = full_name[len(scope_name) + 1 :] + else: + full_name = name + else: + full_name = name + NodeMixin.wrap( + attr, + lambda: GetAttr.make( + self, + name, + type=NodeMixin.get_wrapped_type(attr), + orig_name=full_name, + ), + ) + if isinstance(attr, (NodeMixin, RawTensor)): + node = NodeMixin.get(attr) + if isinstance(node, ModuleNode): + node._owner = weakref.ref(attr) + return node + + return wraped_fn + + +def _wrap_mnode_call(orig_call): + @functools.wraps(orig_call) + def wraped_fn(self, *args, **kwargs): + obj = self.owner + if self.top_graph is not None: + active_module_tracer().current_scope()._add_input(self) + rst = obj(*args, **kwargs) + return rst + + return wraped_fn + + +def _init_id2name(mod: Module, prefix: str = ""): + id2name = { + id(m): "%s.%s" % (prefix, key) + for key, m in chain( + mod.named_modules(), mod.named_parameters(), mod.named_buffers() + ) + } + return id2name class _InsertExprs: - def __init__(self, graph, expr: Optional[Expr] = None, after: bool = True): + def __init__(self, graph, expr: Optional[Expr] = None): self.graph = graph - self.global_scope = InternalGraph() + self.global_scope = InternalGraph( + graph._name, graph._prefix_name, graph._module_name + ) self.global_scope._used_names.update(graph._used_names) self.expr = expr - self.after = after + self._tensor_method_patch = None def __enter__(self): self.use_sym_shape = set_symbolic_shape(True) set_module_tracing() + _set_convert_node_flag(True) assert active_module_tracer() is None - set_active_module_tracer(module_tracer(_wrapped_function)) + module = self.graph.inputs[0].owner + _wrap_func = lambda x: _convert_node_and_tensor(_wrapped_function(x)) + set_active_module_tracer( + module_tracer(_wrap_func, _init_id2name(module, self.graph._module_name)) + ) active_module_tracer().patcher.__enter__() + for cls, name, func in [ + [ModuleNode, "__getattr__", _wrap_mnode_getattr], + [ModuleNode, "__call__", _wrap_mnode_call], + [TracedModuleBuilder, "__call__", _convert_node_and_tensor], + ]: + active_module_tracer().patcher.patch_function(cls, name, func) + self._tensor_method_patch = _wrap_method_to_tensor_node() active_module_tracer().push_scope(self.global_scope) def __exit__(self, ty, va, tr): + if va is not None: + return False set_symbolic_shape(self.use_sym_shape) unset_module_tracing() active_module_tracer().patcher.__exit__(ty, va, tr) + _set_convert_node_flag(False) + + while self._tensor_method_patch: + pf = self._tensor_method_patch.pop() + pf.set_func(pf.origin_fn) + + module = self.graph.inputs[0].owner + + for mod, parent in module.modules(with_parent=True): + name = mod._name + if isinstance(mod, TracedModuleBuilder): + mod = mod.build() + if hasattr(mod, "graph"): + for node in mod.graph.nodes(): + 
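+                        # each TensorNode caches a real Tensor in .value while the
+                        # insert context is active; clear it so the rebuilt module
+                        # does not retain tracing-time tensors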
node.value = None + setattr(parent, name, mod) set_active_module_tracer(None) - index = len(self.graph._exprs) if self.after else 0 + + for node in self.global_scope.nodes(): + node.value = None + + extra_inp_nodes = set(self.global_scope.inputs) + max_inp_expr_idx = -1 + for node in extra_inp_nodes: + assert ( + node.top_graph == self.graph + ), "The input node ({}) is not in the graph ({})".format(node, self.graph) + if isinstance(node, TensorNode) and node.expr in self.graph._exprs: + max_inp_expr_idx = max( + max_inp_expr_idx, self.graph._exprs.index(node.expr) + ) + max_inp_expr_idx += 1 + + insert_index = -1 if self.expr is not None: - index = self.graph._exprs.index(self.expr) - if self.after: - index += 1 + insert_index = self.graph._exprs.index(self.expr) + insert_index += 1 + + if insert_index < max_inp_expr_idx: + insert_index = max_inp_expr_idx + + anchor_index = insert_index - 1 + if anchor_index >= 0: + logger.info( + "The new expr will be inserted after ( {} )".format( + self.graph._exprs[anchor_index] + ) + ) + for expr in self.global_scope._exprs: - self.graph._exprs.insert(index, expr) - index += 1 + self.graph._exprs.insert(insert_index, expr) + insert_index += 1 + self.graph._used_names.update(self.global_scope._used_names) + graph = self.graph + while graph.top_graph is not None: + graph = graph.top_graph + graph.inputs[0].owner._update_ref() + return True class InternalGraph: @@ -125,8 +353,9 @@ class InternalGraph: _exprs = None # type: List[Expr] _inputs = None # type: List[Node] _outputs = None # type: List[Node] + _top_graph = None - def __init__(self, name: str = None, prefix_name: str = ""): + def __init__(self, name: str = None, prefix_name: str = "", module_name: str = ""): self._exprs = [] self._inputs = [] self._outputs = [] @@ -136,12 +365,13 @@ class InternalGraph: self._rst = collections.defaultdict(list) self._name = name self._prefix_name = prefix_name + self._module_name = module_name - def insert(self, expr): + def _insert(self, expr): self._exprs.append(expr) def _create_unique_name(self, name: str) -> str: - assert isinstance(name, str) + assert isinstance(name, str), "The name must be a str" name = re.sub("[^0-9a-zA-Z_]+", "_", name) if name[0].isdigit(): name = "_{}".format(name) @@ -166,40 +396,45 @@ class InternalGraph: return self._outputs @property - def expr_filter(self): - return ExprFilter(_expr_iter(self)) + def top_graph(self): + if self._top_graph: + return self._top_graph() + return None - @property - def node_filter(self): - return NodeFilter(_node_iter(self)) + def exprs(self, recursive=True): + return ExprFilter(_expr_iter(self, recursive)) + + def nodes(self, recursive=True): + return NodeFilter(_node_iter(self, recursive)) - def get_function_by_type(self, func: Callable = None): - return self.expr_filter.call_function(func) + def get_function_by_type(self, func: Callable = None, recursive=True): + return self.exprs(recursive).call_function(func) - def get_method_by_type(self, method: str = None): - return self.expr_filter.call_method(method) + def get_method_by_type(self, method: str = None, recursive=True): + return self.exprs(recursive).call_method(method) - def get_expr_by_id(self, expr_id: List[int] = None): - return self.expr_filter.expr_id(expr_id) + def get_expr_by_id(self, expr_id: List[int] = None, recursive=True): + return self.exprs(recursive).expr_id(expr_id) - def get_module_by_type(self, module_cls: Module): + def get_module_by_type(self, module_cls: Module, recursive=True): assert issubclass(module_cls, Module) - 
return self.node_filter.type(module_cls, ModuleNode) + return self.nodes(recursive).type(module_cls, ModuleNode) - def get_node_by_id(self, node_id: List[int] = None): - return self.node_filter.node_id(node_id) + def get_node_by_id(self, node_id: List[int] = None, recursive=True): + return self.nodes(recursive).node_id(node_id) - def get_node_by_name(self, name: str = None, ignorecase: bool = True): - return self.node_filter.name(name, ignorecase) + def get_node_by_name( + self, name: str = None, ignorecase: bool = True, recursive=True + ): + return self.nodes(recursive).name(name, ignorecase) - def add_input(self, i): + def _add_input(self, i): self._inputs.append(i) - def add_output(self, o): + def _add_output(self, o): self._outputs.append(o) - def _replace_inputs_outputs_and_add_prefixname(self, repl_dict, prefix_name=""): - + def _replace_inputs_outputs(self, repl_dict, prefix_name="", module_name=""): for node, repl_node in repl_dict.items(): assert node in self._inputs or node in self._outputs for i in node.users: @@ -212,12 +447,15 @@ class InternalGraph: for idx, o in enumerate(self._outputs): if o in repl_dict: + repl_dict[o]._orig_name = "{}{}".format(module_name, o._orig_name) self._outputs[idx] = repl_dict[o] for expr in self._exprs: for idx, i in enumerate(expr.inputs): - assert i._name is not None + assert isinstance( + i._name, str + ), "The node ({}) name must be a str".format(i) if i in repl_dict: expr.inputs[idx] = repl_dict[i] elif isinstance(i, TensorNode) and prefix_name not in i._name: @@ -227,9 +465,12 @@ class InternalGraph: .current_scope() ._create_unique_name(prefix_name + i._name.lstrip("_")) ) + i._orig_name = "{}{}".format(module_name, i._orig_name) for idx, o in enumerate(expr.outputs): - assert o._name is not None + assert isinstance( + o._name, str + ), "The node ({}) name must be a str".format(i) if o in repl_dict: expr.outputs[idx] = repl_dict[o] expr.outputs[idx].expr = expr @@ -240,6 +481,7 @@ class InternalGraph: .current_scope() ._create_unique_name(prefix_name + o._name.lstrip("_")) ) + o._orig_name = "{}{}".format(module_name, o._orig_name) def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: if not isinstance(nodes, Sequence): @@ -263,7 +505,7 @@ class InternalGraph: def reset_inputs(self, *args, **kwargs): forma_mnode = self.inputs[0] - actual_mnodes = forma_mnode.actual_mnode + actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: for c_expr in n.users: @@ -318,7 +560,7 @@ class InternalGraph: def add_input_node(self, shape, dtype="float32", name="args"): forma_mnode = self.inputs[0] - actual_mnodes = forma_mnode.actual_mnode + actual_mnodes = forma_mnode.actual_node moudle = forma_mnode.owner assert moudle._is_top, "add_input_node only support the top-level graph" @@ -378,7 +620,7 @@ class InternalGraph: moudle = forma_mnode.owner assert moudle._is_top, "reset_outputs only support the top-level graph" - actual_mnodes = forma_mnode.actual_mnode + actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: for c_expr in n.users: @@ -406,7 +648,6 @@ class InternalGraph: self._outputs[:] = outputs moudle.argdef_outdef_map[tree_def] = out_def - return actual_nodes def add_output_node(self, node: TensorNode): @@ -415,7 +656,7 @@ class InternalGraph: moudle = forma_mnode.owner assert moudle._is_top, "add_output_node only support the top-level graph" - actual_mnodes = forma_mnode.actual_mnode + actual_mnodes = forma_mnode.actual_node call_nodes = [] for n in actual_mnodes: @@ -455,74 +696,35 @@ 
class InternalGraph: return actual_out_nodes - def insert_function(self, func: Callable, *args, **kwargs): - assert isinstance(func, Callable) - - inp_nodes, inp_def = tree_flatten((args, kwargs)) - - insert_idx = -1 - for i in inp_nodes: - if isinstance(i, TensorNode) and i.expr in self._exprs: - insert_idx = max(insert_idx, self._exprs.index(i.expr)) - - fake_inp_val = list( - F.zeros(shape=i.shape, dtype=i.dtype) if isinstance(i, TensorNode) else i - for i in inp_nodes - ) - - for v, n in zip(fake_inp_val, inp_nodes): - if isinstance(n, TensorNode): - NodeMixin.wrap_safe(v, n) - - fake_args, fake_kwargs = inp_def.unflatten(fake_inp_val) - - insert_point = self.insert_exprs_before() - if insert_idx != -1: - insert_point = self.insert_exprs_after(self._exprs[insert_idx]) - - with insert_point: - rst = func(*fake_args, **fake_kwargs) - - if rst is None: - return None - - outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) - node_outputs = [] - for out in outputs: - assert isinstance(out, RawTensor) - node_outputs.append(NodeMixin.get(out, None)) - - node_outputs = out_def.unflatten(node_outputs) - return node_outputs - - def insert_exprs_after(self, expr: Optional[Expr] = None): + def insert_exprs(self, expr: Optional[Expr] = None): if expr is not None: assert expr.top_graph == self, "Expr to insert after is not in graph." - return _InsertExprs(self, expr, after=True) - - def insert_exprs_before(self, expr: Optional[Expr] = None): - if expr is not None: - assert expr.top_graph == self, "Expr to insert before is not in graph." - return _InsertExprs(self, expr, after=False) + return _InsertExprs(self, expr) def replace_node(self, repl_dict: Dict[Node, Node]): while repl_dict: node, repl_node = repl_dict.popitem() # check graph inputs and outputs - assert node not in self.inputs, "Cannot replace inputs" + # assert node not in self.inputs, "Cannot replace inputs" for i, n in enumerate(self.outputs): if n is node: self.outputs[i] = repl_node # update users of node and repl_node # update inputs of expr in node.users + graph = repl_node.top_graph + assert graph is not None + index = graph._exprs.index(repl_node.expr) dep_exprs = self.get_dep_exprs(repl_node) i = 0 while i < len(node.users): n = node.users[i] + if n in graph._exprs and index >= graph._exprs.index(n): + i += 1 + continue if n in dep_exprs: logger.info("Find a loop: ignore this replacement once") logger.info("node: %s" % node.__repr__()) - logger.info("repl_node: %s" % repl_node.__repr__()) + logger.info("expr: %s" % n.__repr__()) i += 1 continue repl_node.users.append(n) @@ -598,6 +800,12 @@ class InternalGraph: Node.set_format_spec(saved_format_spec) return res + def __getstate__(self): + state = self.__dict__.copy() + if "_top_graph" in state: + state.pop("_top_graph") + return state + def _get_meth_name(obj, func): tp = obj if isinstance(obj, type) else type(obj) @@ -611,6 +819,9 @@ def _get_meth_name(obj, func): def _wrapped_function(orig_func): @functools.wraps(orig_func) def wrapped_fn(*args, **kwargs): + method_func = wrapped_fn + if "method_func" in kwargs: + method_func = kwargs.pop("method_func") if is_tracing_module(): unset_module_tracing() inputs, tree_def = tree_flatten((args, kwargs)) @@ -618,9 +829,11 @@ def _wrapped_function(orig_func): if not NodeMixin.get(i, None): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) - meth_name = _get_meth_name(args[0], wrapped_fn) if args else None - arg_type = args[0] if isinstance(args[0], type) else type(args[0]) - if meth_name and 
issubclass(arg_type, RawTensor): + meth_name, arg_type = None, None + if args: + meth_name = _get_meth_name(args[0], method_func) + arg_type = args[0] if isinstance(args[0], type) else type(args[0]) + if meth_name and arg_type and issubclass(arg_type, RawTensor): self = inputs[0] if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): @@ -799,6 +1012,9 @@ class TracedModuleBuilder(NodeMixin): def __call__(self, *args, **kwargs): assert isinstance(self._mod, Module) # prepare args and kwargs for inner graph + if "method_func" in kwargs: + kwargs.pop("method_func") + def mark_constant(x): node = NodeMixin.get(x, None) if node is None: # capture as constant @@ -829,9 +1045,6 @@ class TracedModuleBuilder(NodeMixin): else: self._mod._is_top = False self._body = self._mod.graph - name = NodeMixin.get(self)._name - if name: - self._body._name = name else: self_node = None orig_self = NodeMixin.get(self) @@ -841,19 +1054,24 @@ class TracedModuleBuilder(NodeMixin): graph_prefix_name = "{}_{}".format( top_graph._prefix_name, graph_prefix_name.lstrip("_") ) - self._body = InternalGraph(orig_self._name, prefix_name=graph_prefix_name) + module_name = orig_self._orig_name + if top_graph._module_name: + module_name = "{}.{}".format(top_graph._module_name, module_name) + self._body = InternalGraph( + orig_self._name, prefix_name=graph_prefix_name, module_name=module_name + ) active_module_tracer().push_scope(self._body) # rebind self to new input node if self_node: NodeMixin.wrap_safe(self, self_node) - active_module_tracer().current_scope().add_input(self_node) + active_module_tracer().current_scope()._add_input(self_node) else: NodeMixin.wrap_safe( self, self_node if self_node - else Input.make("self", NodeMixin.get_wrapped_type(self)), + else Input.make("self", NodeMixin.get_wrapped_type(self), ""), ) origin_inp_node = [NodeMixin.get(i, None) for i in inputs[1:]] # prepare args and kwargs for inner graph @@ -893,12 +1111,13 @@ class TracedModuleBuilder(NodeMixin): getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) ) rst = type(self._mod).forward(*args, **kwargs) + if _convert_node_flag(): + rst = _node_to_tensor(rst)[0][0] outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) for i in ( outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) ): - active_module_tracer().current_scope().add_output(NodeMixin.get(i)) - NodeMixin.get(self, None).actual_mnode.append(orig_self) + active_module_tracer().current_scope()._add_output(NodeMixin.get(i)) NodeMixin.wrap_safe(self, orig_self) for arg, node in zip(inputs[1:], origin_inp_node): if node: @@ -923,14 +1142,33 @@ class TracedModuleBuilder(NodeMixin): attr = getattr(type(self._mod), name).__get__(self, type(self)) else: attr = getattr(self._mod, name) + full_name = None + + if id(attr) in active_module_tracer().id2name: + full_name = active_module_tracer().id2name[id(attr)] + if isinstance(attr, Module): attr = TracedModuleBuilder(attr) + if isinstance(attr, (Module, RawTensor)): setattr(self, name, attr) + active_module_tracer().id2name[id(attr)] = full_name + + if full_name: + scope_name = active_module_tracer().current_scope()._module_name + if scope_name: + full_name = full_name[len(scope_name) + 1 :] + else: + full_name = name + else: + full_name = name NodeMixin.wrap( attr, lambda: GetAttr.make( - NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr) + NodeMixin.get(self), + name, + type=NodeMixin.get_wrapped_type(attr), + orig_name=full_name, ), ) return attr @@ -951,7 +1189,16 
@@ class TracedModuleBuilder(NodeMixin): assert mod_attr is wrapped._mod else: assert mod_attr is wrapped - + full_name = None + if id(mod_attr) in active_module_tracer().id2name: + full_name = active_module_tracer().id2name[id(mod_attr)] + scope_name = active_module_tracer().current_scope()._module_name + if full_name and scope_name: + full_name = full_name[len(scope_name) + 1 :] + else: + full_name = name + else: + full_name = name # assert not self._is_builtin if isinstance(wrapped, (NodeMixin, RawTensor)): NodeMixin.wrap( @@ -960,6 +1207,7 @@ class TracedModuleBuilder(NodeMixin): NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(wrapped), + orig_name=full_name, ), ) @@ -967,24 +1215,25 @@ class TracedModuleBuilder(NodeMixin): class _expr_iter: - def __init__(self, graph: InternalGraph): + def __init__(self, graph: InternalGraph, recursive: bool = True): self.graph = graph + self.recursive = recursive def __iter__(self): for expr in self.graph._exprs: if isinstance(expr, CallMethod) and isinstance(expr.inputs[0], ModuleNode): yield expr - if expr.graph is not None: - yield from expr.graph.expr_filter + if self.recursive and expr.graph is not None: + yield from expr.graph.exprs(self.recursive) else: yield expr class _node_iter: - def __init__(self, graph: InternalGraph) -> None: + def __init__(self, graph: InternalGraph, recursive: bool = True) -> None: nodes = [] node_ids = set() - for expr in graph.expr_filter: + for expr in graph.exprs(recursive): for n in expr.inputs + expr.outputs: if n._id in node_ids: continue @@ -1210,14 +1459,17 @@ class TracedModule(Module): assert len(self.argdef_graph_map) == 1 return list(self.argdef_graph_map.values())[0] - def _update_ref(self, actual_node_map: Union[Dict] = None): + def _update_ref(self, actual_node_map: Union[Dict] = None, top_graph=None): for inp_def, graph in self.argdef_graph_map.items(): + if top_graph is not None: + graph._top_graph = weakref.ref(top_graph) for n in graph._inputs + graph.outputs: n._top_graph = weakref.ref(graph) graph._inputs[0]._owner = weakref.ref(self) - graph._inputs[0].actual_mnode = [] - if actual_node_map is not None and inp_def in actual_node_map.keys(): - graph._inputs[0].actual_mnode = actual_node_map[inp_def] + for i, n in enumerate(graph._inputs): + n.actual_node = [] + if actual_node_map is not None and inp_def in actual_node_map.keys(): + n.actual_node = list(list(zip(*(actual_node_map[inp_def])))[i]) node2obj = {} next_actual_node_map = collections.defaultdict( lambda: collections.defaultdict(list) @@ -1246,7 +1498,7 @@ class TracedModule(Module): ): obj = node2obj[expr.inputs[0]] if expr.arg_def is not None: - next_actual_node_map[obj][expr.arg_def].append(expr.inputs[0]) + next_actual_node_map[obj][expr.arg_def].append(expr.inputs) for obj in node2obj.values(): if obj is self: @@ -1255,7 +1507,7 @@ class TracedModule(Module): if obj in next_actual_node_map.keys(): mnode_map = next_actual_node_map[obj] if isinstance(obj, TracedModule): - obj._update_ref(mnode_map) + obj._update_ref(mnode_map, graph) def flatten(self): """ @@ -1264,21 +1516,25 @@ class TracedModule(Module): :return: :class:`TracedModule` """ new_module = copy.deepcopy(self) - module2name = {} assert active_module_tracer() is None - set_active_module_tracer(module_tracer(lambda x: x)) + id2name = _init_id2name(new_module, "self") + set_active_module_tracer(module_tracer(lambda x: x, {})) active_module_tracer().push_scope(new_module.graph) - for n, m in new_module.named_modules(): - module2name[id(m)] = n def 
_flatten_subgraph( - graph: InternalGraph, module: Module, call=None, prefix_name="" + graph: InternalGraph, + module: Module, + call=None, + prefix_name="", + module_name="", ): - if graph is not None and prefix_name and prefix_name[-1] != "_": + if isinstance(prefix_name, str) and prefix_name and prefix_name[-1] != "_": prefix_name += "_" + if isinstance(module_name, str) and module_name: + module_name += "." if graph is None or module.is_qat: assert not isinstance(module, TracedModule) or module.is_qat - const = Constant(module, "self.%s" % module2name[id(module)]) + const = Constant(module, id2name[id(module)]) m_node = call.inputs[0] if m_node.top_graph != active_module_tracer().current_scope(): m_node._name = ( @@ -1286,6 +1542,7 @@ class TracedModule(Module): .current_scope() ._create_unique_name(prefix_name) ) + m_node._orig_name = id2name[id(module)][5:] const.outputs[0] = m_node const.outputs[0].expr = const return [const, call] @@ -1312,7 +1569,7 @@ class TracedModule(Module): continue repl_dict[out] = call.outputs[ind] - graph._replace_inputs_outputs_and_add_prefixname(repl_dict, prefix_name) + graph._replace_inputs_outputs(repl_dict, prefix_name, module_name) for expr in graph._exprs: if isinstance(expr, GetAttr): @@ -1344,6 +1601,7 @@ class TracedModule(Module): obj, expr, prefix_name + obj_node._name.lstrip("_"), + module_name + obj_node._orig_name, ) ) else: @@ -1358,7 +1616,6 @@ class TracedModule(Module): if call is not None: for i in call.inputs: i.users.remove(call) - return exprs new_module.graph._exprs = _flatten_subgraph(new_module.graph, new_module) @@ -1396,7 +1653,22 @@ def register_as_builtin(mod_cls: Type[Module]) -> None: module_tracer.register_as_builtin(mod_cls) -wrap = _wrapped_function +def wrap(func: Callable): + """ + Call this function to register func as a builtin function. 
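+
+    For example, ``wrap`` is expected to be applied at the top level of a
+    module (``my_func`` below is illustrative)::
+
+        @wrap
+        def my_func(x):
+            return x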
+ """ + assert callable(func), "func must be a callable" + assert hasattr(func, "__code__") + fn_name = func.__code__.co_name + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + assert ( + f.f_code.co_name == "" + ), "wrap must be called at the top level of a module" + Patcher._builtin_functions.append((f.f_globals, fn_name)) + return func def _register_all_builtin_module(): @@ -1438,14 +1710,15 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: try: use_sym_shape = set_symbolic_shape(True) set_module_tracing() - set_active_module_tracer(module_tracer(_wrapped_function)) - + set_active_module_tracer( + module_tracer(_wrapped_function, _init_id2name(mod, "self")) + ) with active_module_tracer().patcher: global_scope = InternalGraph(name="") active_module_tracer().push_scope(global_scope) builder = TracedModuleBuilder(mod, True) name = mod._name if mod._name else mod.__class__.__name__ - NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode)) + NodeMixin.wrap_safe(builder, Input.make(name, ModuleNode, orig_name="self")) inputs, _ = tree_flatten((args, kwargs)) for _, i in enumerate(inputs): # assert isinstance(i, Tensor), "not support " diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 5ccdcfa9..593a5448 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -64,9 +64,10 @@ def test_search(): def test_insert(): traced_module, x, expect = _init_block() graph = traced_module.graph - relu_node = graph.get_function_by_type(F.relu).as_unique().outputs - neg_node = graph.insert_function(lambda x: F.neg(x), *relu_node) - graph.replace_node({relu_node[0]: neg_node}) + relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] + with graph.insert_exprs(): + neg_out = F.neg(relu_out) + graph.replace_node({relu_out: neg_out}) graph.compile() np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)