@@ -9,6 +9,7 @@ | |||
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
from . import compat | |||
from ._passes import optimize | |||
from .pytree import register_supported_type | |||
from .traced_module import ( | |||
TracedModule, | |||
_register_all_builtin_module, | |||
@@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace) | |||
__all__ = [ | |||
"register_as_builtin", | |||
"register_supported_type", | |||
"trace_module", | |||
"wrap", | |||
"TracedModule", | |||
@@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape | |||
from ...logger import get_logger | |||
from ...tensor import Tensor | |||
from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr | |||
from ..node import Node, TensorNode | |||
from ..node import Node, NodeMixin, TensorNode | |||
from .matcher import PatternMatcher | |||
from .pass_base import BackwardPass, ForwardPass, register_pass | |||
from .pattern import is_op | |||
@@ -21,6 +21,12 @@ from .utils import get_const_value | |||
logger = get_logger(__name__) | |||
def _as_const_node(x): | |||
node = Constant.make(x) | |||
NodeMixin.wrap(x, node) | |||
return node | |||
@register_pass("AttrToConstant") | |||
class AttrToConstant(BackwardPass): | |||
r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr.""" | |||
@@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass): | |||
orig_node = expr.outputs[0] | |||
name = orig_node.name | |||
with graph.insert_exprs(expr): | |||
const_node = Constant.make(value, name=name) | |||
const_node = _as_const_node(value) | |||
graph.replace_node({orig_node: const_node}) | |||
graph.compile() | |||
name = orig_node.name | |||
const_node.name = name | |||
return const_node.expr | |||
@@ -53,7 +59,7 @@ class FixInputShape(BackwardPass): | |||
shape = Tensor(expr.inputs[0].shape, dtype="int32") | |||
graph = expr.top_graph | |||
with graph.insert_exprs(expr): | |||
const_shape = Constant.make(shape) | |||
const_shape = _as_const_node(shape) | |||
graph.replace_node({expr.outputs[0]: const_shape}) | |||
graph.compile() | |||
const_shape.name = expr.outputs[0].name | |||
@@ -73,7 +79,7 @@ class FlodConstant(ForwardPass): | |||
const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0] | |||
graph = expr.top_graph | |||
with graph.insert_exprs(expr): | |||
const_node = Constant.make(const_var) | |||
const_node = _as_const_node(const_var) | |||
graph.replace_node({expr.outputs[0]: const_node}) | |||
graph.compile() | |||
const_node.name = expr.outputs[0].name | |||
@@ -10,7 +10,7 @@ import collections | |||
from collections import OrderedDict, defaultdict | |||
from functools import partial | |||
from inspect import FullArgSpec | |||
from typing import Callable, NamedTuple | |||
from typing import Any, Callable, List, NamedTuple, Tuple | |||
import numpy as np | |||
@@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = { | |||
int, | |||
float, | |||
bool, | |||
bytes, | |||
bytearray, | |||
QuantDtypeMeta, | |||
CompNode, | |||
Device, | |||
@@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [ | |||
NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)]) | |||
def register_supported_type(type, flatten=None, unflatten=None): | |||
def register_supported_type( | |||
type, | |||
flatten_fn: Callable[[Any], Tuple[List, Any]] = None, | |||
unflatten_fn: Callable[[List, Any], Any] = None, | |||
): | |||
r"""Call this function to register the ``type`` as a built-in type. The registered ``type`` | |||
can be used and serialized correctly in :py:class:`TracedModule`. | |||
Examples: | |||
.. code-block:: | |||
def dict_flatten(obj: Dict): | |||
context, values = [], [] | |||
# obj.keys() needs to be sortable | |||
keys = sorted(obj.keys()) | |||
for key in keys: | |||
values.append(obj[key]) | |||
context.append(key) | |||
return values, tuple(context) | |||
def dict_unflatten(values: List, context: Any): | |||
return dict(zip(context, values)) | |||
register_supported_type(dict, dict_flatten, dict_unflatten) | |||
Args: | |||
type: the type that needs to be registered. | |||
flatten_fn: a function that should take an object created from ``type`` and return a | |||
flat list of values. It can also return some context that is used in reconstructing | |||
the object. Default: None | |||
unflatten_fn: a function that should take a flat list of values and some context | |||
(returned by flatten_fn). It returns the object by reconstructing | |||
it from the list and the context. Default: None | |||
""" | |||
tp_info = (type.__module__, type.__qualname__) | |||
if flatten and unflatten: | |||
if flatten_fn and unflatten_fn: | |||
USER_REGISTERED_CONTAINER_TYPE.append(tp_info) | |||
else: | |||
USER_REGISTERED_LEAF_TYPE.append(tp_info) | |||
_register_supported_type(type, flatten, unflatten) | |||
_register_supported_type(type, flatten_fn, unflatten_fn) | |||
def _register_supported_type(type, flatten=None, unflatten=None): | |||
if flatten and unflatten: | |||
SUPPORTED_TYPE[type] = NodeType(flatten, unflatten) | |||
def _register_supported_type(type, flatten_fn=None, unflatten_fn=None): | |||
if flatten_fn and unflatten_fn: | |||
SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn) | |||
else: | |||
SUPPORTED_LEAF_CLS.append(type) | |||
@@ -131,6 +166,7 @@ _register_supported_type( | |||
_register_supported_type( | |||
OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict) | |||
) | |||
_register_supported_type( | |||
slice, | |||
lambda x: ([x.start, x.stop, x.step], None), | |||
@@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import ( | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..module import Module | |||
from ..module import external as MExternal | |||
from ..module.qat import QATModule | |||
from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize | |||
from ..quantization.observer import ( | |||
@@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node(): | |||
for method in get_tensor_wrapable_method(): | |||
patch = PatchedFn(TensorNode, method) | |||
if type(getattr(Tensor, method)) == property: | |||
# Only support property.getter | |||
patch.set_func(property(_any_method(method, patch.origin_fn))) | |||
else: | |||
patch.set_func(_any_method(method, patch.origin_fn)) | |||
@@ -351,14 +353,14 @@ class _InsertExprs: | |||
assert ( | |||
node.top_graph == self.graph | |||
), "The input node ({}) is not in the graph ({})".format(node, self.graph) | |||
if isinstance(node, TensorNode) and node.expr in self.graph._exprs: | |||
if node.expr in self.graph._exprs: | |||
max_inp_expr_idx = max( | |||
max_inp_expr_idx, self.graph._exprs.index(node.expr) | |||
) | |||
max_inp_expr_idx += 1 | |||
insert_index = -1 | |||
if self.expr is not None: | |||
if self.expr in self.graph._exprs: | |||
insert_index = self.graph._exprs.index(self.expr) | |||
insert_index += 1 | |||
@@ -2070,7 +2072,8 @@ class TracedModule(Module): | |||
for inp_def, graph in self.argdef_graph_map.items(): | |||
if top_graph is not None: | |||
graph._top_graph = weakref.ref(top_graph) | |||
for n in graph._inputs + graph.outputs: | |||
for n in graph._inputs + graph._outputs: | |||
n.expr._top_graph = weakref.ref(graph) | |||
n._top_graph = weakref.ref(graph) | |||
graph._inputs[0]._owner = weakref.ref(self) | |||
for i, n in enumerate(graph._inputs): | |||
@@ -2375,7 +2378,7 @@ def wrap(func: Callable): | |||
def _register_all_builtin_module(): | |||
for sub_mod in [M, M.qat, M.quantized]: | |||
for sub_mod in [M, M.qat, M.quantized, MExternal]: | |||
for m in getmembers(sub_mod): | |||
if ( | |||
isclass(m[1]) | |||
@@ -126,10 +126,12 @@ def _check_obj_attr(obj): | |||
for _, v in obj.items(): | |||
leafs, _ = tree_flatten(v, is_leaf=lambda _: True) | |||
for leaf in leafs: | |||
assert _check_leaf_type( | |||
leaf | |||
), "Type {} is not supported by traced module".format( | |||
leaf if isinstance(leaf, type) else type(leaf) | |||
assert _check_leaf_type(leaf), ( | |||
"Type {} is not supported in TracedModule serialization by default. " | |||
"If you want to save this object to file, please call tm.register_supported_type({}) " | |||
"before saving.".format( | |||
leaf if isinstance(leaf, type) else type(leaf), type(leaf).__name__ | |||
) | |||
) | |||