|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import collections
- from collections import OrderedDict, defaultdict
- from functools import partial
- from inspect import FullArgSpec
- from typing import Any, Callable, Dict, List, NamedTuple, Tuple
-
- import numpy as np
-
- from ..core._imperative_rt import OpDef
- from ..core._imperative_rt.common import CompNode
- from ..core._imperative_rt.core2 import Tensor as RawTensor
- from ..core._wrap import Device
- from ..core.tensor.dtype import QuantDtypeMeta
- from ..distributed import Group
- from ..module import Module
- from ..quantization.utils import LSQParams, QParams, QuantMode
- from ..tensor import Parameter, Tensor
- from .node import ModuleNode, Node, NodeMixin, TensorNode
-
-
class ArgsIndex:
    """Small marker object carrying an argument index and a display name.

    It is listed in ``SUPPORTED_LEAF_TYPE``, so instances are pytree leaves.

    Args:
        index: the stored index. Default: 0
        name: the text returned by ``repr``. Default: ""
    """

    def __init__(self, index=0, name="") -> None:
        self.name = name
        self.index = index

    def __repr__(self) -> str:
        # The representation is exactly the stored name.
        label = self.name
        return label
-
-
# Registry of container types the pytree helpers can flatten/unflatten: maps a
# type to a ``NodeType(flatten, unflatten)`` pair. Filled in by
# ``_register_supported_type``.
SUPPORTED_TYPE = {}

# If ``type(obj)`` (or ``obj`` itself, when ``obj`` is a type) is in
# SUPPORTED_LEAF_TYPE, the object can be treated as a leaf node of the pytree.
# This is an exact-type set lookup, so subclasses are NOT matched here.
SUPPORTED_LEAF_TYPE = {
    RawTensor,
    Tensor,
    Parameter,
    str,
    int,
    float,
    bool,
    bytes,
    bytearray,
    QuantDtypeMeta,
    CompNode,
    Device,
    type(None),
    type(Ellipsis),
    QuantMode,
    ArgsIndex,
    Group,
    FullArgSpec,
}

# (module, qualname) pairs recorded by the public ``register_supported_type``,
# kept separately for leaf registrations and container registrations.
USER_REGISTERED_LEAF_TYPE = []
USER_REGISTERED_CONTAINER_TYPE = []
# If isinstance(obj, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS)
# is True, the object can be treated as a leaf node of the pytree. Unlike
# SUPPORTED_LEAF_TYPE this is checked with issubclass, so subclasses match.
# ``_register_supported_type`` appends leaf-only registrations to this list.
SUPPORTED_LEAF_CLS = [
    Module,
    Node,
    NodeMixin,
    np.dtype,
    np.ndarray,
    np.number,
    np.bool_,
    OpDef,
]
-
class NodeType(NamedTuple):
    """The flatten/unflatten callables registered for a container type."""

    flatten: Callable
    unflatten: Callable
-
-
def register_supported_type(
    type,
    flatten_fn: Callable[[Any], Tuple[List, Any]] = None,
    unflatten_fn: Callable[[List, Any], Any] = None,
):
    r"""Register ``type`` as a built-in type, so that objects of it can be used and
    serialized correctly in :py:class:`TracedModule`.

    When both ``flatten_fn`` and ``unflatten_fn`` are given, ``type`` is registered
    as a container that is flattened into its children. Otherwise (either one
    missing), it is registered as a leaf.

    Examples:
        .. code-block::

            def dict_flatten(obj: Dict):
                # the keys of obj need to be sortable
                keys = sorted(obj.keys())
                values = [obj[key] for key in keys]
                return values, tuple(keys)

            def dict_unflatten(values: List, context: Any):
                return dict(zip(context, values))

            register_supported_type(dict, dict_flatten, dict_unflatten)

    Args:
        type: the type to register.
        flatten_fn: takes an object of ``type`` and returns ``(values, context)``.
            ``values`` is a flat list of children; ``context`` is whatever is needed
            to rebuild the object. Default: None
        unflatten_fn: takes ``(values, context)`` as produced by ``flatten_fn`` and
            returns the reconstructed object. Default: None
    """
    is_container = bool(flatten_fn and unflatten_fn)
    registry = (
        USER_REGISTERED_CONTAINER_TYPE if is_container else USER_REGISTERED_LEAF_TYPE
    )
    # Registrations are recorded as (module, qualified name) pairs.
    registry.append((type.__module__, type.__qualname__))
    _register_supported_type(type, flatten_fn, unflatten_fn)
-
-
def _register_supported_type(type, flatten_fn=None, unflatten_fn=None):
    """Add ``type`` to the module registries without recording it as user-defined.

    With both callables, ``type`` becomes a container in ``SUPPORTED_TYPE``.
    Otherwise it is appended to ``SUPPORTED_LEAF_CLS`` and treated as a leaf.
    """
    if not (flatten_fn and unflatten_fn):
        SUPPORTED_LEAF_CLS.append(type)
        return
    SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn)
-
-
def _dict_flatten(ordered, inp):
    """Split a mapping into a list of its values and a tuple of its keys.

    Args:
        ordered: if true, keep the mapping's own iteration order. Otherwise the
            items are sorted, so the keys must be mutually comparable.
        inp: the mapping to flatten.

    Returns:
        ``(values, keys)``: a list of values and a tuple of keys in matching order.
    """
    pairs = list(inp.items()) if ordered else sorted(inp.items())
    keys = tuple(key for key, _ in pairs)
    values = [value for _, value in pairs]
    return values, keys
-
-
def _dict_unflatten(dict_type, inps, aux_data):
    """Rebuild a mapping of ``dict_type`` from flattened values and keys.

    Args:
        dict_type: the mapping class to construct, e.g. ``dict``, ``OrderedDict``
            or ``defaultdict``.
        inps: the values, in the same order as ``aux_data``.
        aux_data: the keys, as produced by ``_dict_flatten``.

    Returns:
        A ``dict_type`` instance mapping each key in ``aux_data`` to its value.
    """
    items = zip(aux_data, inps)
    if issubclass(dict_type, defaultdict):
        # defaultdict's first positional argument is the default factory, so
        # passing the items directly raises TypeError. The flattened aux data
        # holds only the keys, so the original factory cannot be restored and
        # None is used.
        return dict_type(None, items)
    return dict_type(items)
-
-
def qparams_flatten(inp):
    """Flatten a ``__slots__``-based params object into values and slot names.

    Returns:
        ``(values, names)``. ``values`` holds one entry per slot in ``__slots__``
        order, with ``None`` for unset slots; ``names`` is the tuple of slot names.
    """
    names = tuple(inp.__slots__)
    values = [getattr(inp, slot, None) for slot in names]
    return values, names
-
-
def qparams_unflatten(qparam_type, inp, aux_data):
    """Rebuild a ``qparam_type`` object from flattened slot values.

    The object is created with ``__new__``, so its ``__init__`` is not run. Each
    name in ``aux_data`` is then set to the value at the same position in ``inp``.
    """
    restored = qparam_type.__new__(qparam_type)
    for name_and_value in zip(aux_data, inp):
        setattr(restored, *name_and_value)
    return restored
-
-
# Built-in sequences: a list or tuple is its own list of children, so no aux
# data is needed to rebuild it.
_register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
_register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
# Mappings: dict and defaultdict are flattened in sorted item order (so their
# keys must be mutually comparable); OrderedDict keeps its own iteration order.
# NOTE(review): the aux data is only the tuple of keys, so a defaultdict's
# default_factory is not recorded by this flatten.
_register_supported_type(
    dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
)
_register_supported_type(
    defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
)
_register_supported_type(
    OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
)

# A slice is flattened into its [start, stop, step] fields.
_register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)

# Quantization parameter objects are flattened slot by slot via
# qparams_flatten; the slot names are kept as aux data.
_register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
_register_supported_type(
    LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
)
-
-
def _is_leaf(obj):
    """Return True if ``obj`` may be a leaf of a pytree.

    A type is checked as-is; any other object is checked through ``type(obj)``.
    The check passes for subclasses of ``SUPPORTED_LEAF_CLS`` or exact members of
    ``SUPPORTED_LEAF_TYPE``.
    """
    cls = obj if isinstance(obj, type) else type(obj)
    if issubclass(cls, tuple(SUPPORTED_LEAF_CLS)):
        return True
    return cls in SUPPORTED_LEAF_TYPE
-
-
def _leaf_type(node):
    """Return the tuple of types stored in the ``LeafDef`` built for ``node``.

    ``LeafDef.unflatten`` checks each leaf with ``isinstance`` against this tuple.
    ``ArgsIndex`` is always accepted alongside the node-specific types.
    """
    if isinstance(node, (RawTensor, TensorNode)):
        accepted = (Tensor, TensorNode)
    elif isinstance(node, (NodeMixin, Module, ModuleNode)):
        accepted = (Module, ModuleNode, NodeMixin)
    else:
        accepted = (type(node),)
    return accepted + (ArgsIndex,)
-
-
def _is_const_leaf(node):
    """Return True when the leaf value ``node`` should be kept as a constant.

    Tensors, node mixins and modules are not constants; every other leaf is.
    """
    return not isinstance(node, (RawTensor, NodeMixin, Module))
-
-
def tree_flatten(
    values,
    leaf_type: Callable = _leaf_type,
    is_leaf: Callable = _is_leaf,
    is_const_leaf: Callable = _is_const_leaf,
):
    r"""Flattens a pytree into a list of values and a :class:`TreeDef` that can be used
    to reconstruct the pytree.

    Args:
        values: the pytree to flatten.
        leaf_type: called on each leaf; its result is passed to ``LeafDef``.
        is_leaf: predicate that must accept every object whose type is not a
            registered container in ``SUPPORTED_TYPE``.
        is_const_leaf: predicate deciding whether a leaf's value is stored in
            ``LeafDef.const_val``.

    Returns:
        ``(leaves, treedef)``: the leaves in depth-first, left-to-right order, and
        the ``TreeDef`` describing the structure.

    Raises:
        AssertionError: if an object is neither a registered container nor
            accepted by ``is_leaf``.
    """
    values_type = type(values)
    if values_type not in SUPPORTED_TYPE:
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``. AssertionError is kept for callers.
        if not is_leaf(values):
            raise AssertionError(
                'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format(
                    values_type
                )
            )
        node = LeafDef(leaf_type(values))
        if is_const_leaf(values):
            node.const_val = values
        return [values], node

    leaves = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[values_type].flatten(values)
    for child in children_values:
        child_leaves, child_def = tree_flatten(
            child, leaf_type, is_leaf, is_const_leaf
        )
        leaves.extend(child_leaves)
        children_defs.append(child_def)

    return leaves, TreeDef(values_type, aux_data, children_defs)
-
-
class TreeDef:
    r"""A ``TreeDef`` represents the structure of a pytree.

    Args:
        type: the type of root Node of the pytree.
        aux_data: some const data that is useful in unflattening the pytree.
        children_defs: ``TreeDef`` for each child of the root Node.

    Attributes:
        num_leaves: the number of leaves, i.e. the sum of ``num_leaves`` over
            ``children_defs``.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        self.num_leaves = sum(ch.num_leaves for ch in children_defs)

    def unflatten(self, leaves):
        r"""Given a list of values and a ``TreeDef``, builds a pytree.
        This is the inverse operation of ``tree_flatten``.
        """
        assert len(leaves) == self.num_leaves
        start = 0
        children = []
        for ch in self.children_defs:
            # Each child consumes the next ``ch.num_leaves`` consecutive leaves.
            children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
            start += ch.num_leaves
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __hash__(self):
        # Hashes exactly the fields compared by __eq__.
        return hash(
            tuple(
                [
                    self.type,
                    self.aux_data,
                    self.num_leaves,
                    tuple([hash(x) for x in self.children_defs]),
                ]
            )
        )

    def __ne__(self, other) -> bool:
        result = self.__eq__(other)
        # Propagate NotImplemented so Python can try the reflected comparison.
        if result is NotImplemented:
            return result
        return not result

    def __eq__(self, other) -> bool:
        # Comparing with a non-TreeDef (e.g. None) used to raise AttributeError.
        if not isinstance(other, TreeDef):
            return NotImplemented
        return (
            self.type == other.type
            and self.aux_data == other.aux_data
            and self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def _args_kwargs_repr(self):
        """Render an ``(args, kwargs)``-shaped tree like call arguments.

        Applies when there are exactly two children: a list/tuple def followed by
        a dict def. The result looks like ``a, b, key=value``. Any other shape
        falls back to ``repr(self)``.
        """
        if len(self.children_defs) == 2:
            args_def, kwargs_def = self.children_defs
            # A LeafDef stores a tuple of types in ``type``; issubclass() raises
            # TypeError on that, so only real classes are tested.
            if (
                isinstance(args_def.type, type)
                and isinstance(kwargs_def.type, type)
                and issubclass(args_def.type, (List, Tuple))
                and issubclass(kwargs_def.type, Dict)
            ):
                parts = [repr(i) for i in args_def.children_defs]
                if kwargs_def.aux_data:
                    parts.extend(
                        str(i) + "=" + repr(j)
                        for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs)
                    )
                # Joining all parts at once avoids a leading ", " when there are
                # keyword arguments but no positional ones.
                return ", ".join(parts)
        return repr(self)

    def __repr__(self):
        format_str = self.type.__name__ + "({})"
        aux_data_delimiter = "="
        if issubclass(self.type, List):
            format_str = "[{}]"
        if issubclass(self.type, Tuple):
            format_str = "({})"
        if issubclass(self.type, Dict):
            format_str = "{{{}}}"
            aux_data_delimiter = ":"
        if self.aux_data:
            # aux_data entries are paired positionally with the children.
            content = ", ".join(
                repr(i) + aux_data_delimiter + repr(j)
                for i, j in zip(self.aux_data, self.children_defs)
            )
        else:
            content = ", ".join(repr(i) for i in self.children_defs)
        return format_str.format(content)
-
-
class LeafDef(TreeDef):
    r"""``TreeDef`` for a single leaf of a pytree.

    Args:
        type: a type, or a sequence of types, that the leaf must be an instance of
            when unflattening. A single type is wrapped into a 1-tuple.
    """

    def __init__(self, type):
        if not isinstance(type, collections.abc.Sequence):
            type = (type,)
        super().__init__(type, None, [])
        self.num_leaves = 1
        # Constant value of the leaf; stays None unless set by the caller.
        self.const_val = None

    def unflatten(self, leaves):
        assert len(leaves) == 1
        assert isinstance(leaves[0], self.type), self.type
        return leaves[0]

    @staticmethod
    def _const_equal(lhs, rhs):
        """Compare two constant values, treating ndarrays by shape and content."""
        lhs_is_array = isinstance(lhs, np.ndarray)
        rhs_is_array = isinstance(rhs, np.ndarray)
        if lhs_is_array or rhs_is_array:
            # ``==`` on arrays broadcasts, so differently-shaped arrays (or an
            # empty array vs. None) could compare equal, and mixing an array
            # with a scalar yields an array whose truth value is ambiguous.
            return lhs_is_array and rhs_is_array and bool(np.array_equal(lhs, rhs))
        return lhs == rhs

    def __ne__(self, other) -> bool:
        result = self.__eq__(other)
        # Propagate NotImplemented so Python can try the reflected comparison.
        if result is NotImplemented:
            return result
        return not result

    def __eq__(self, other):
        # Comparing with a non-LeafDef (e.g. None) used to raise AttributeError.
        if not isinstance(other, LeafDef):
            return NotImplemented
        return self.type == other.type and self._const_equal(
            self.const_val, other.const_val
        )

    def __hash__(self):
        if isinstance(self.const_val, np.ndarray):
            # Arrays equal under np.array_equal always share a shape, while their
            # str() can differ (e.g. int vs float dtype). Hashing the shape keeps
            # hash consistent with __eq__.
            return hash(tuple([self.type, self.const_val.shape]))
        return hash(tuple([self.type, self.const_val]))

    def __repr__(self):
        # Show the constant when there is one (or when None is itself an allowed
        # leaf type); otherwise show the name of the first allowed type.
        return "{}".format(
            self.const_val
            if self.const_val is not None or type(None) in self.type
            else self.type[0].__name__
        )
|