@@ -9,6 +9,7 @@
 from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
 from . import compat
 from ._passes import optimize
+from .pytree import register_supported_type
 from .traced_module import (
     TracedModule,
     _register_all_builtin_module,
@@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace)
 __all__ = [
     "register_as_builtin",
+    "register_supported_type",
     "trace_module",
     "wrap",
     "TracedModule",
@@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape
 from ...logger import get_logger
 from ...tensor import Tensor
 from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
-from ..node import Node, TensorNode
+from ..node import Node, NodeMixin, TensorNode
 from .matcher import PatternMatcher
 from .pass_base import BackwardPass, ForwardPass, register_pass
 from .pattern import is_op
@@ -21,6 +21,12 @@ from .utils import get_const_value
 logger = get_logger(__name__)


+def _as_const_node(x):
+    node = Constant.make(x)
+    NodeMixin.wrap(x, node)
+    return node
+
+
 @register_pass("AttrToConstant")
 class AttrToConstant(BackwardPass):
     r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
@@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass):
         orig_node = expr.outputs[0]
         name = orig_node.name
         with graph.insert_exprs(expr):
-            const_node = Constant.make(value, name=name)
+            const_node = _as_const_node(value)
         graph.replace_node({orig_node: const_node})
         graph.compile()
+        name = orig_node.name
+        const_node.name = name
         return const_node.expr
@@ -53,7 +59,7 @@ class FixInputShape(BackwardPass):
         shape = Tensor(expr.inputs[0].shape, dtype="int32")
         graph = expr.top_graph
         with graph.insert_exprs(expr):
-            const_shape = Constant.make(shape)
+            const_shape = _as_const_node(shape)
         graph.replace_node({expr.outputs[0]: const_shape})
         graph.compile()
         const_shape.name = expr.outputs[0].name
@@ -73,7 +79,7 @@ class FlodConstant(ForwardPass):
         const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
         graph = expr.top_graph
         with graph.insert_exprs(expr):
-            const_node = Constant.make(const_var)
+            const_node = _as_const_node(const_var)
         graph.replace_node({expr.outputs[0]: const_node})
         graph.compile()
         const_node.name = expr.outputs[0].name
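
All three passes above funnel their rewrites through the new `_as_const_node` helper. For readers of those hunks, an annotated copy of the helper (the comments are editorial interpretation of what the two calls do, not part of the patch):

    def _as_const_node(x):
        # Emit a Constant expr into the active graph and get its output node.
        node = Constant.make(x)
        # Bind the live value x to the new node, so later traced calls that
        # receive the raw value can resolve it back to this graph node.
        NodeMixin.wrap(x, node)
        return node

This is presumably why the plain `Constant.make(...)` calls were replaced: the node alone was created before, without the value-to-node binding.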
@@ -10,7 +10,7 @@ import collections
 from collections import OrderedDict, defaultdict
 from functools import partial
 from inspect import FullArgSpec
-from typing import Callable, NamedTuple
+from typing import Any, Callable, List, NamedTuple, Tuple

 import numpy as np
@@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = { | |||||
int, | int, | ||||
float, | float, | ||||
bool, | bool, | ||||
bytes, | |||||
bytearray, | |||||
QuantDtypeMeta, | QuantDtypeMeta, | ||||
CompNode, | CompNode, | ||||
Device, | Device, | ||||
@@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [
 NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


-def register_supported_type(type, flatten=None, unflatten=None):
+def register_supported_type(
+    type,
+    flatten_fn: Callable[[Any], Tuple[List, Any]] = None,
+    unflatten_fn: Callable[[List, Any], Any] = None,
+):
+    r"""Call this function to register the ``type`` as a built-in type. The registered ``type``
+    can be used and serialized correctly in :py:class:`TracedModule`.
+
+    Examples:
+
+        .. code-block::
+
+            def dict_flatten(obj: Dict):
+                context, values = [], []
+                # obj.keys() needs to be sortable
+                keys = sorted(obj.keys())
+                for key in keys:
+                    values.append(obj[key])
+                    context.append(key)
+                return values, tuple(context)
+
+            def dict_unflatten(values: List, context: Any):
+                return dict(zip(context, values))
+
+            register_supported_type(dict, dict_flatten, dict_unflatten)
+
+    Args:
+        type: the type that needs to be registered.
+        flatten_fn: a function that should take an object created from ``type`` and return a
+            flat list of values. It can also return some context that is used in reconstructing
+            the object. Default: None
+        unflatten_fn: a function that should take a flat list of values and some context
+            (returned by ``flatten_fn``). It returns the object by reconstructing
+            it from the list and the context. Default: None
+    """
     tp_info = (type.__module__, type.__qualname__)
-    if flatten and unflatten:
+    if flatten_fn and unflatten_fn:
         USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
     else:
         USER_REGISTERED_LEAF_TYPE.append(tp_info)
-    _register_supported_type(type, flatten, unflatten)
+    _register_supported_type(type, flatten_fn, unflatten_fn)


-def _register_supported_type(type, flatten=None, unflatten=None):
-    if flatten and unflatten:
-        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
+def _register_supported_type(type, flatten_fn=None, unflatten_fn=None):
+    if flatten_fn and unflatten_fn:
+        SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn)
     else:
         SUPPORTED_LEAF_CLS.append(type)
@@ -131,6 +166,7 @@ _register_supported_type(
 _register_supported_type(
     OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
 )
 _register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
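
For illustration, a minimal usage sketch of the renamed keyword arguments (`Point` and `Token` are hypothetical user types, not part of this patch): passing both functions registers a container type whose values are flattened and traced through, while passing neither registers an opaque leaf type.

    from collections import namedtuple

    import megengine.traced_module as tm

    Point = namedtuple("Point", ["x", "y"])  # hypothetical container type

    # flatten_fn returns (values, context); unflatten_fn rebuilds the object
    # from that flat list and context.
    tm.register_supported_type(
        Point,
        flatten_fn=lambda p: ([p.x, p.y], None),
        unflatten_fn=lambda values, context: Point(*values),
    )

    class Token:  # hypothetical leaf type, stored as-is
        pass

    tm.register_supported_type(Token)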
@@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import (
 )
 from ..core._trace_option import set_symbolic_shape
 from ..module import Module
+from ..module import external as MExternal
 from ..module.qat import QATModule
 from ..quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
 from ..quantization.observer import (
@@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node():
     for method in get_tensor_wrapable_method():
         patch = PatchedFn(TensorNode, method)
         if type(getattr(Tensor, method)) == property:
+            # Only support property.getter
             patch.set_func(property(_any_method(method, patch.origin_fn)))
         else:
             patch.set_func(_any_method(method, patch.origin_fn))
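
The new comment records a real constraint of this patching scheme: rebinding a `property` attribute to a fresh `property` built from a single function keeps only the getter, so any setter or deleter on the original descriptor is not carried over. A generic Python sketch of the effect, independent of MegEngine's internals:

    class T:
        _shape = (2, 3)

        @property
        def shape(self):
            return self._shape

        @shape.setter
        def shape(self, value):
            self._shape = value

    # Rebinding the class attribute with a getter-only property drops the setter.
    T.shape = property(lambda self: ("wrapped", self._shape))

    t = T()
    print(t.shape)    # ('wrapped', (2, 3))
    t.shape = (4, 5)  # raises AttributeError: the new property has no setter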
@@ -351,14 +353,14 @@ class _InsertExprs:
             assert (
                 node.top_graph == self.graph
             ), "The input node ({}) is not in the graph ({})".format(node, self.graph)
-            if isinstance(node, TensorNode) and node.expr in self.graph._exprs:
+            if node.expr in self.graph._exprs:
                 max_inp_expr_idx = max(
                     max_inp_expr_idx, self.graph._exprs.index(node.expr)
                 )
         max_inp_expr_idx += 1

         insert_index = -1
-        if self.expr is not None:
+        if self.expr in self.graph._exprs:
             insert_index = self.graph._exprs.index(self.expr)
         insert_index += 1
@@ -2070,7 +2072,8 @@ class TracedModule(Module):
         for inp_def, graph in self.argdef_graph_map.items():
             if top_graph is not None:
                 graph._top_graph = weakref.ref(top_graph)
-            for n in graph._inputs + graph.outputs:
+            for n in graph._inputs + graph._outputs:
+                n.expr._top_graph = weakref.ref(graph)
                 n._top_graph = weakref.ref(graph)
             graph._inputs[0]._owner = weakref.ref(self)
             for i, n in enumerate(graph._inputs):
@@ -2375,7 +2378,7 @@ def wrap(func: Callable):
 def _register_all_builtin_module():
-    for sub_mod in [M, M.qat, M.quantized]:
+    for sub_mod in [M, M.qat, M.quantized, MExternal]:
         for m in getmembers(sub_mod):
             if (
                 isclass(m[1])
@@ -126,10 +126,12 @@ def _check_obj_attr(obj):
     for _, v in obj.items():
         leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
         for leaf in leafs:
-            assert _check_leaf_type(
-                leaf
-            ), "Type {} is not supported by traced module".format(
-                leaf if isinstance(leaf, type) else type(leaf)
+            assert _check_leaf_type(leaf), (
+                "Type {} is not supported in TracedModule serialization by default. "
+                "If you want to save this object to file, please call tm.register_supported_type({}) "
+                "before saving.".format(
+                    leaf if isinstance(leaf, type) else type(leaf), type(leaf).__name__
+                )
             )
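
Together with the pytree changes, the reworded assertion now tells users how to unblock serialization themselves. A hedged sketch of the flow it enables (`MyMeta` is a hypothetical attribute type; the commented-out save call assumes megengine's usual pickle-based save API):

    import megengine.traced_module as tm

    class MyMeta:  # hypothetical type stored on a traced-module attribute
        pass

    # Without registration, saving fails with the new message, e.g.:
    #   AssertionError: Type <class 'MyMeta'> is not supported in TracedModule
    #   serialization by default. If you want to save this object to file,
    #   please call tm.register_supported_type(MyMeta) before saving.
    tm.register_supported_type(MyMeta)  # register as a leaf type
    # megengine.save(traced_net, "model.tm")  # now succeeds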