From 3a21920955c6f760383950cc00f37dfde73ce8be Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Dec 2021 18:48:13 +0800 Subject: [PATCH] fix(mge/traced_module): fix some bugs GitOrigin-RevId: 88f98829cec88df120a05318209f78f065699bf8 --- .../python/megengine/traced_module/__init__.py | 2 + .../megengine/traced_module/_passes/const_pass.py | 16 ++++--- .../python/megengine/traced_module/pytree.py | 50 +++++++++++++++++++--- .../megengine/traced_module/traced_module.py | 11 +++-- imperative/python/megengine/traced_module/utils.py | 10 +++-- 5 files changed, 69 insertions(+), 20 deletions(-) diff --git a/imperative/python/megengine/traced_module/__init__.py b/imperative/python/megengine/traced_module/__init__.py index c906b879..848c968e 100644 --- a/imperative/python/megengine/traced_module/__init__.py +++ b/imperative/python/megengine/traced_module/__init__.py @@ -9,6 +9,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace from . import compat from ._passes import optimize +from .pytree import register_supported_type from .traced_module import ( TracedModule, _register_all_builtin_module, @@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace) __all__ = [ "register_as_builtin", + "register_supported_type", "trace_module", "wrap", "TracedModule", diff --git a/imperative/python/megengine/traced_module/_passes/const_pass.py b/imperative/python/megengine/traced_module/_passes/const_pass.py index 143a704c..0ff3571b 100644 --- a/imperative/python/megengine/traced_module/_passes/const_pass.py +++ b/imperative/python/megengine/traced_module/_passes/const_pass.py @@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape from ...logger import get_logger from ...tensor import Tensor from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr -from ..node import Node, TensorNode +from ..node import Node, NodeMixin, TensorNode from .matcher import PatternMatcher from .pass_base import BackwardPass, ForwardPass, register_pass from .pattern import is_op @@ -21,6 +21,12 @@ from .utils import get_const_value logger = get_logger(__name__) +def _as_const_node(x): + node = Constant.make(x) + NodeMixin.wrap(x, node) + return node + + @register_pass("AttrToConstant") class AttrToConstant(BackwardPass): r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" @@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass): orig_node = expr.outputs[0] name = orig_node.name with graph.insert_exprs(expr): - const_node = Constant.make(value, name=name) + const_node = _as_const_node(value) graph.replace_node({orig_node: const_node}) graph.compile() - name = orig_node.name + const_node.name = name return const_node.expr @@ -53,7 +59,7 @@ class FixInputShape(BackwardPass): shape = Tensor(expr.inputs[0].shape, dtype="int32") graph = expr.top_graph with graph.insert_exprs(expr): - const_shape = Constant.make(shape) + const_shape = _as_const_node(shape) graph.replace_node({expr.outputs[0]: const_shape}) graph.compile() const_shape.name = expr.outputs[0].name @@ -73,7 +79,7 @@ class FlodConstant(ForwardPass): const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] graph = expr.top_graph with graph.insert_exprs(expr): - const_node = Constant.make(const_var) + const_node = _as_const_node(const_var) graph.replace_node({expr.outputs[0]: const_node}) graph.compile() const_node.name = expr.outputs[0].name diff --git a/imperative/python/megengine/traced_module/pytree.py b/imperative/python/megengine/traced_module/pytree.py index 98d19f1e..4a9f1e0b 100644 --- a/imperative/python/megengine/traced_module/pytree.py +++ b/imperative/python/megengine/traced_module/pytree.py @@ -10,7 +10,7 @@ import collections from collections import OrderedDict, defaultdict from functools import partial from inspect import FullArgSpec -from typing import Callable, NamedTuple +from typing import Any, Callable, List, NamedTuple, Tuple import numpy as np @@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = { int, float, bool, + bytes, + bytearray, QuantDtypeMeta, CompNode, Device, @@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [ NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) -def register_supported_type(type, flatten=None, unflatten=None): +def register_supported_type( + type, + flatten_fn: Callable[[Any], Tuple[List, Any]] = None, + unflatten_fn: Callable[[List, Any], Any] = None, +): + r"""Call this function to register the ``type`` as a built-in type. The registered ``type`` + can be used and serialized correctly in :py:class:`TracedModule`. + + Examples: + .. code-block:: + + def dict_flatten(obj: Dict): + context, values = [], [] + # obj.keys() needs to be sortable + keys = sorted(obj.keys()) + for key in keys: + values.append(obj[key]) + context.append(key) + return values, tuple(context) + + def dict_unflatten(values: List, context: Any): + return dict(zip(context, values)) + + register_supported_type(dict, dict_flatten, dict_unflatten) + + Args: + type: the type that needs to be registered. + flatten_fn: a function that should take an object created from ``type`` and return a + flat list of values. It can also return some context that is used in reconstructing + the object. Default: None + unflatten_fn: a function that should take a flat list of values and some context + (returned by flatten_fn). It returns the object by reconstructing + it from the list and the context. Default: None + """ tp_info = (type.__module__, type.__qualname__) - if flatten and unflatten: + if flatten_fn and unflatten_fn: USER_REGISTERED_CONTAINER_TYPE.append(tp_info) else: USER_REGISTERED_LEAF_TYPE.append(tp_info) - _register_supported_type(type, flatten, unflatten) + _register_supported_type(type, flatten_fn, unflatten_fn) -def _register_supported_type(type, flatten=None, unflatten=None): - if flatten and unflatten: - SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) +def _register_supported_type(type, flatten_fn=None, unflatten_fn=None): + if flatten_fn and unflatten_fn: + SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn) else: SUPPORTED_LEAF_CLS.append(type) @@ -131,6 +166,7 @@ _register_supported_type( _register_supported_type( OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict) ) + _register_supported_type( slice, lambda x: ([x.start, x.stop, x.step], None), diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 8092f6c9..058fb4a6 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import ( ) from ..core._trace_option import set_symbolic_shape from ..module import Module +from ..module import external as MExternal from ..module.qat import QATModule from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize from ..quantization.observer import ( @@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node(): for method in get_tensor_wrapable_method(): patch = PatchedFn(TensorNode, method) if type(getattr(Tensor, method)) == property: + # Only support property.getter patch.set_func(property(_any_method(method, patch.origin_fn))) else: patch.set_func(_any_method(method, patch.origin_fn)) @@ -351,14 +353,14 @@ class _InsertExprs: assert ( node.top_graph == self.graph ), "The input node ({}) is not in the graph ({})".format(node, self.graph) - if isinstance(node, TensorNode) and node.expr in self.graph._exprs: + if node.expr in self.graph._exprs: max_inp_expr_idx = max( max_inp_expr_idx, self.graph._exprs.index(node.expr) ) max_inp_expr_idx += 1 insert_index = -1 - if self.expr is not None: + if self.expr in self.graph._exprs: insert_index = self.graph._exprs.index(self.expr) insert_index += 1 @@ -2070,7 +2072,8 @@ class TracedModule(Module): for inp_def, graph in self.argdef_graph_map.items(): if top_graph is not None: graph._top_graph = weakref.ref(top_graph) - for n in graph._inputs + graph.outputs: + for n in graph._inputs + graph._outputs: + n.expr._top_graph = weakref.ref(graph) n._top_graph = weakref.ref(graph) graph._inputs[0]._owner = weakref.ref(self) for i, n in enumerate(graph._inputs): @@ -2375,7 +2378,7 @@ def wrap(func: Callable): def _register_all_builtin_module(): - for sub_mod in [M, M.qat, M.quantized]: + for sub_mod in [M, M.qat, M.quantized, MExternal]: for m in getmembers(sub_mod): if ( isclass(m[1]) diff --git a/imperative/python/megengine/traced_module/utils.py b/imperative/python/megengine/traced_module/utils.py index 21ccb35c..b722fc30 100644 --- a/imperative/python/megengine/traced_module/utils.py +++ b/imperative/python/megengine/traced_module/utils.py @@ -126,10 +126,12 @@ def _check_obj_attr(obj): for _, v in obj.items(): leafs, _ = tree_flatten(v, is_leaf=lambda _: True) for leaf in leafs: - assert _check_leaf_type( - leaf - ), "Type {} is not supported by traced module".format( - leaf if isinstance(leaf, type) else type(leaf) + assert _check_leaf_type(leaf), ( + "Type {} is not supported in TracedModule serialization by default. " + "If you want to save this object to file, please call tm.register_supported_type({}) " + "before saving.".format( + leaf if isinstance(leaf, type) else type(leaf), type(leaf).__name__ + ) )