# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Pytree helpers: flatten nested containers into a list of leaves plus a
``TreeDef`` describing the structure, and rebuild them from that pair."""
import collections
import collections.abc
from collections import OrderedDict
from typing import Callable, NamedTuple

import numpy as np

from ...core._imperative_rt.common import CompNode
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._wrap import Device
from ...core.tensor.dtype import QuantDtypeMeta
from ...module import Module
from ...quantization.utils import LSQParams, QParams, QuantMode
from ...tensor import Parameter, Tensor
from .node import ModuleNode, Node, NodeMixin, TensorNode

# Maps a container type (matched by exact type) to the NodeType pair used to
# split it into children and rebuild it; filled by register_supported_type.
SUPPORTED_TYPE = {}

# An object is treated as a pytree leaf if type(obj) is in SUPPORTED_LEAF_TYPE,
# or, when obj is itself a type, if obj is in SUPPORTED_LEAF_TYPE.
SUPPORTED_LEAF_TYPE = {
    RawTensor,
    Tensor,
    Parameter,
    str,
    int,
    float,
    bool,
    QuantDtypeMeta,
    CompNode,
    Device,
    type(None),
    type(Ellipsis),
    QuantMode,
}

# An object is treated as a pytree leaf if isinstance(obj, SUPPORTED_LEAF_CLS)
# holds, or, when obj is itself a type, if issubclass(obj, SUPPORTED_LEAF_CLS).
SUPPORTED_LEAF_CLS = [Module, Node, NodeMixin, np.dtype, np.ndarray, np.number]

NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])


def register_supported_type(type, flatten=None, unflatten=None):
    """Register ``type`` with the pytree machinery.

    If both ``flatten`` and ``unflatten`` are truthy, ``type`` becomes a
    container: ``flatten(obj)`` must return ``(children, aux_data)`` and
    ``unflatten(children, aux_data)`` must rebuild the object. Otherwise
    (including when only one of the two is given) ``type`` is registered as a
    leaf class, matched with ``isinstance``/``issubclass``.
    """
    if flatten and unflatten:
        SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
    elif type not in SUPPORTED_LEAF_CLS:
        # Skip duplicates so repeated registration does not grow the list.
        SUPPORTED_LEAF_CLS.append(type)


def _dict_flatten(inp):
    """Flatten a dict into ``(values, keys)`` ordered by sorted key.

    Sorting makes the result independent of insertion order; as a
    consequence the keys must be mutually comparable.
    """
    items = sorted(inp.items())
    results = [value for _, value in items]
    aux_data = tuple(key for key, _ in items)
    return results, aux_data


def _dict_unflatten(inps, aux_data):
    """Rebuild a dict by pairing the keys in ``aux_data`` with ``inps``."""
    return dict(zip(aux_data, inps))


def _ordereddict_flatten(inp):
    """Flatten an OrderedDict into ``(values, keys)``, keeping insertion order."""
    return list(inp.values()), tuple(inp.keys())


def _ordereddict_unflatten(inps, aux_data):
    """Rebuild an OrderedDict by pairing the keys in ``aux_data`` with ``inps``."""
    return OrderedDict(zip(aux_data, inps))


def qparams_flatten(inp):
    """Flatten a slot-based qparams object into ``(slot values, slot names)``.

    Iterates ``inp.__slots__``; slots that were never assigned are reported
    as ``None``.
    """
    aux_data = tuple(inp.__slots__)
    results = [getattr(inp, key, None) for key in aux_data]
    return results, aux_data


def qparams_unflatten(inp, aux_data, qparam_type=QParams):
    """Rebuild a qparams object from slot values and slot names.

    Args:
        inp: slot values, in the same order as ``aux_data``.
        aux_data: slot names produced by ``qparams_flatten``.
        qparam_type: class to instantiate (defaults to ``QParams``). Its
            ``__init__`` is bypassed and each slot is assigned directly.
    """
    obj = qparam_type.__new__(qparam_type)
    for key, value in zip(aux_data, inp):
        setattr(obj, key, value)
    return obj


register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
register_supported_type(dict, _dict_flatten, _dict_unflatten)
register_supported_type(
    collections.OrderedDict, _ordereddict_flatten, _ordereddict_unflatten
)
register_supported_type(
    slice,
    lambda x: ([x.start, x.stop, x.step], None),
    lambda x, aux_data: slice(x[0], x[1], x[2]),
)
register_supported_type(QParams, qparams_flatten, qparams_unflatten)
# NOTE(review): LSQParams is imported but not registered in this file, so
# tree_flatten will reject an LSQParams instance. If it is slot-based like
# QParams, it could be registered with
# ``lambda inp, aux: qparams_unflatten(inp, aux, LSQParams)`` -- confirm.


def _is_leaf(obj):
    """Return True if ``obj`` (an instance or a type) is a supported leaf."""
    # Rebuilt on every call because register_supported_type may extend the list.
    leaf_classes = tuple(SUPPORTED_LEAF_CLS)
    if isinstance(obj, type):
        return issubclass(obj, leaf_classes) or obj in SUPPORTED_LEAF_TYPE
    return isinstance(obj, leaf_classes) or type(obj) in SUPPORTED_LEAF_TYPE


def _leaf_type(node):
    """Return the type(s) a LeafDef should accept for ``node``.

    Tensors and tensor nodes map to ``(Tensor, TensorNode)``, modules and
    module nodes to ``(Module, ModuleNode, NodeMixin)``, anything else to its
    exact type.
    """
    if isinstance(node, (RawTensor, TensorNode)):
        return (Tensor, TensorNode)
    elif isinstance(node, (NodeMixin, Module)):
        return (Module, ModuleNode, NodeMixin)
    else:
        return type(node)


def _is_const_leaf(node):
    """Return False for tensors, NodeMixins and Modules, True for other leaves."""
    if isinstance(node, (RawTensor, NodeMixin, Module)):
        return False
    return True


def tree_flatten(
    values,
    leaf_type: Callable = _leaf_type,
    is_leaf: Callable = _is_leaf,
    is_const_leaf: Callable = _is_const_leaf,
):
    """Flatten a nested structure into a list of leaves and a TreeDef.

    Objects whose exact type is registered in SUPPORTED_TYPE are recursively
    decomposed; every other object must satisfy ``is_leaf``.

    Args:
        values: the (possibly nested) object to flatten.
        leaf_type: maps a leaf to the type(s) stored in its LeafDef.
        is_leaf: predicate asserted on every non-container object.
        is_const_leaf: selects leaves whose value is recorded in
            ``LeafDef.const_val`` (and so takes part in equality/hashing).

    Returns:
        ``(leaves, treedef)``; ``treedef.unflatten(leaves)`` rebuilds the
        structure.

    Raises:
        AssertionError: if a non-container object is not a leaf.
    """
    if type(values) not in SUPPORTED_TYPE:
        assert is_leaf(values), values
        node = LeafDef(leaf_type(values))
        if is_const_leaf(values):
            if isinstance(values, np.ndarray):
                # NOTE(review): str() summarizes large arrays with "...", so
                # distinct large arrays can yield equal const_val -- confirm
                # this is acceptable for TreeDef comparison.
                node.const_val = str(values)
            else:
                node.const_val = values
        return [values], node

    leaves = []
    children_defs = []
    children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
    for child in children_values:
        child_leaves, child_def = tree_flatten(
            child, leaf_type, is_leaf, is_const_leaf
        )
        leaves.extend(child_leaves)
        children_defs.append(child_def)
    return leaves, TreeDef(type(values), aux_data, children_defs)


class TreeDef:
    """Structure of a flattened container.

    Records the container type, its auxiliary data (e.g. dict keys), and one
    child TreeDef per element.
    """

    def __init__(self, type, aux_data, children_defs):
        self.type = type
        self.aux_data = aux_data
        self.children_defs = children_defs
        # Total number of leaves under this node.
        self.num_leaves = sum(ch.num_leaves for ch in children_defs)

    def unflatten(self, leaves):
        """Rebuild the container from ``leaves`` given in flatten order."""
        assert len(leaves) == self.num_leaves, (len(leaves), self.num_leaves)
        start = 0
        children = []
        for ch in self.children_defs:
            end = start + ch.num_leaves
            children.append(ch.unflatten(leaves[start:end]))
            start = end
        return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)

    def __hash__(self):
        return hash(
            (
                self.type,
                self.aux_data,
                self.num_leaves,
                tuple(hash(x) for x in self.children_defs),
            )
        )

    # Ordering compares hashes only: arbitrary, but consistent with __hash__.
    def __lt__(self, other):
        return self.__hash__() < other.__hash__()

    def __gt__(self, other):
        return self.__hash__() > other.__hash__()

    def __eq__(self, other):
        # Let Python fall back (to False) instead of raising AttributeError
        # when compared with unrelated objects such as None.
        if not isinstance(other, TreeDef):
            return NotImplemented
        return (
            self.type == other.type
            and self.aux_data == other.aux_data
            and self.num_leaves == other.num_leaves
            and self.children_defs == other.children_defs
        )

    def __repr__(self):
        return "{}[{}]".format(self.type.__name__, self.children_defs)


class LeafDef(TreeDef):
    """TreeDef for a single leaf.

    Stores the accepted leaf type(s) and, for constant leaves, the value in
    ``const_val`` (``None`` otherwise).
    """

    def __init__(self, type):
        # Wrap a single type in a 1-tuple; unflatten passes self.type to
        # isinstance() and __repr__ iterates it.
        if not isinstance(type, collections.abc.Sequence):
            type = (type,)
        super().__init__(type, None, [])
        self.num_leaves = 1
        self.const_val = None

    def unflatten(self, leaves):
        """Return the single leaf after checking it matches the recorded type."""
        assert len(leaves) == 1
        assert isinstance(leaves[0], self.type), self.type
        return leaves[0]

    def __eq__(self, other):
        # Non-LeafDef operands are delegated (TreeDef.__eq__ then yields False).
        if not isinstance(other, LeafDef):
            return NotImplemented
        return self.type == other.type and self.const_val == other.const_val

    def __hash__(self):
        return hash((self.type, self.const_val))

    def __repr__(self):
        return "Leaf({}[{}])".format(
            ", ".join(t.__name__ for t in self.type), self.const_val
        )