Browse Source

fix(mge/traced_module): fix some bugs

GitOrigin-RevId: 88f98829ce
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
3a21920955
5 changed files with 69 additions and 20 deletions
  1. +2
    -0
      imperative/python/megengine/traced_module/__init__.py
  2. +11
    -5
      imperative/python/megengine/traced_module/_passes/const_pass.py
  3. +43
    -7
      imperative/python/megengine/traced_module/pytree.py
  4. +7
    -4
      imperative/python/megengine/traced_module/traced_module.py
  5. +6
    -4
      imperative/python/megengine/traced_module/utils.py

+ 2
- 0
imperative/python/megengine/traced_module/__init__.py View File

@@ -9,6 +9,7 @@
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .pytree import register_supported_type
from .traced_module import (
TracedModule,
_register_all_builtin_module,
@@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace)

__all__ = [
"register_as_builtin",
"register_supported_type",
"trace_module",
"wrap",
"TracedModule",


+ 11
- 5
imperative/python/megengine/traced_module/_passes/const_pass.py View File

@@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape
from ...logger import get_logger
from ...tensor import Tensor
from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
from ..node import Node, TensorNode
from ..node import Node, NodeMixin, TensorNode
from .matcher import PatternMatcher
from .pass_base import BackwardPass, ForwardPass, register_pass
from .pattern import is_op
@@ -21,6 +21,12 @@ from .utils import get_const_value
logger = get_logger(__name__)


def _as_const_node(x):
node = Constant.make(x)
NodeMixin.wrap(x, node)
return node


@register_pass("AttrToConstant")
class AttrToConstant(BackwardPass):
r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
@@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass):
orig_node = expr.outputs[0]
name = orig_node.name
with graph.insert_exprs(expr):
const_node = Constant.make(value, name=name)
const_node = _as_const_node(value)
graph.replace_node({orig_node: const_node})
graph.compile()
name = orig_node.name
const_node.name = name
return const_node.expr


@@ -53,7 +59,7 @@ class FixInputShape(BackwardPass):
shape = Tensor(expr.inputs[0].shape, dtype="int32")
graph = expr.top_graph
with graph.insert_exprs(expr):
const_shape = Constant.make(shape)
const_shape = _as_const_node(shape)
graph.replace_node({expr.outputs[0]: const_shape})
graph.compile()
const_shape.name = expr.outputs[0].name
@@ -73,7 +79,7 @@ class FlodConstant(ForwardPass):
const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
graph = expr.top_graph
with graph.insert_exprs(expr):
const_node = Constant.make(const_var)
const_node = _as_const_node(const_var)
graph.replace_node({expr.outputs[0]: const_node})
graph.compile()
const_node.name = expr.outputs[0].name


+ 43
- 7
imperative/python/megengine/traced_module/pytree.py View File

@@ -10,7 +10,7 @@ import collections
from collections import OrderedDict, defaultdict
from functools import partial
from inspect import FullArgSpec
from typing import Callable, NamedTuple
from typing import Any, Callable, List, NamedTuple, Tuple

import numpy as np

@@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = {
int,
float,
bool,
bytes,
bytearray,
QuantDtypeMeta,
CompNode,
Device,
@@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten=None, unflatten=None):
def register_supported_type(
type,
flatten_fn: Callable[[Any], Tuple[List, Any]] = None,
unflatten_fn: Callable[[List, Any], Any] = None,
):
r"""Call this function to register the ``type`` as a built-in type. The registered ``type``
can be used and serialized correctly in :py:class:`TracedModule`.

Examples:
.. code-block::

def dict_flatten(obj: Dict):
context, values = [], []
# obj.keys() needs to be sortable
keys = sorted(obj.keys())
for key in keys:
values.append(obj[key])
context.append(key)
return values, tuple(context)
def dict_unflatten(values: List, context: Any):
return dict(zip(context, values))
register_supported_type(dict, dict_flatten, dict_unflatten)

Args:
type: the type that needs to be registered.
flatten_fn: a function that should take an object created from ``type`` and return a
flat list of values. It can also return some context that is used in reconstructing
the object. Default: None
unflatten_fn: a function that should take a flat list of values and some context
(returned by flatten_fn). It returns the object by reconstructing
it from the list and the context. Default: None
"""
tp_info = (type.__module__, type.__qualname__)
if flatten and unflatten:
if flatten_fn and unflatten_fn:
USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
else:
USER_REGISTERED_LEAF_TYPE.append(tp_info)
_register_supported_type(type, flatten, unflatten)
_register_supported_type(type, flatten_fn, unflatten_fn)


def _register_supported_type(type, flatten=None, unflatten=None):
if flatten and unflatten:
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
def _register_supported_type(type, flatten_fn=None, unflatten_fn=None):
if flatten_fn and unflatten_fn:
SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn)
else:
SUPPORTED_LEAF_CLS.append(type)

@@ -131,6 +166,7 @@ _register_supported_type(
_register_supported_type(
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
)

_register_supported_type(
slice,
lambda x: ([x.start, x.stop, x.step], None),


+ 7
- 4
imperative/python/megengine/traced_module/traced_module.py View File

@@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import (
)
from ..core._trace_option import set_symbolic_shape
from ..module import Module
from ..module import external as MExternal
from ..module.qat import QATModule
from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
from ..quantization.observer import (
@@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node():
for method in get_tensor_wrapable_method():
patch = PatchedFn(TensorNode, method)
if type(getattr(Tensor, method)) == property:
# Only support property.getter
patch.set_func(property(_any_method(method, patch.origin_fn)))
else:
patch.set_func(_any_method(method, patch.origin_fn))
@@ -351,14 +353,14 @@ class _InsertExprs:
assert (
node.top_graph == self.graph
), "The input node ({}) is not in the graph ({})".format(node, self.graph)
if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
if node.expr in self.graph._exprs:
max_inp_expr_idx = max(
max_inp_expr_idx, self.graph._exprs.index(node.expr)
)
max_inp_expr_idx += 1

insert_index = -1
if self.expr is not None:
if self.expr in self.graph._exprs:
insert_index = self.graph._exprs.index(self.expr)
insert_index += 1

@@ -2070,7 +2072,8 @@ class TracedModule(Module):
for inp_def, graph in self.argdef_graph_map.items():
if top_graph is not None:
graph._top_graph = weakref.ref(top_graph)
for n in graph._inputs + graph.outputs:
for n in graph._inputs + graph._outputs:
n.expr._top_graph = weakref.ref(graph)
n._top_graph = weakref.ref(graph)
graph._inputs[0]._owner = weakref.ref(self)
for i, n in enumerate(graph._inputs):
@@ -2375,7 +2378,7 @@ def wrap(func: Callable):

def _register_all_builtin_module():

for sub_mod in [M, M.qat, M.quantized]:
for sub_mod in [M, M.qat, M.quantized, MExternal]:
for m in getmembers(sub_mod):
if (
isclass(m[1])


+ 6
- 4
imperative/python/megengine/traced_module/utils.py View File

@@ -126,10 +126,12 @@ def _check_obj_attr(obj):
for _, v in obj.items():
leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
for leaf in leafs:
assert _check_leaf_type(
leaf
), "Type {} is not supported by traced module".format(
leaf if isinstance(leaf, type) else type(leaf)
assert _check_leaf_type(leaf), (
"Type {} is not supported in TracedModule serialization by default. "
"If you want to save this object to file, please call tm.register_supported_type({}) "
"before saving.".format(
leaf if isinstance(leaf, type) else type(leaf), type(leaf).__name__
)
)




Loading…
Cancel
Save