You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pytree.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import collections
  9. from collections import OrderedDict, defaultdict
  10. from functools import partial
  11. from inspect import FullArgSpec
  12. from typing import Any, Callable, Dict, List, NamedTuple, Tuple
  13. import numpy as np
  14. from ..core._imperative_rt import OpDef
  15. from ..core._imperative_rt.common import CompNode
  16. from ..core._imperative_rt.core2 import Tensor as RawTensor
  17. from ..core._wrap import Device
  18. from ..core.tensor.dtype import QuantDtypeMeta
  19. from ..distributed import Group
  20. from ..module import Module
  21. from ..quantization.utils import LSQParams, QParams, QuantMode
  22. from ..tensor import Parameter, Tensor
  23. from .node import ModuleNode, Node, NodeMixin, TensorNode
  24. class ArgsIndex:
  25. def __init__(self, index=0, name="") -> None:
  26. self.index = index
  27. self.name = name
  28. def __repr__(self) -> str:
  29. return self.name
  30. SUPPORTED_TYPE = {}
  31. # if type(object) or obj in SUPPORTED_LEAF_TYPE, the object could be treated as leaf node of pytree
  32. SUPPORTED_LEAF_TYPE = {
  33. RawTensor,
  34. Tensor,
  35. Parameter,
  36. str,
  37. int,
  38. float,
  39. bool,
  40. bytes,
  41. bytearray,
  42. QuantDtypeMeta,
  43. CompNode,
  44. Device,
  45. type(None),
  46. type(Ellipsis),
  47. QuantMode,
  48. ArgsIndex,
  49. Group,
  50. FullArgSpec,
  51. }
  52. USER_REGISTERED_LEAF_TYPE = []
  53. USER_REGISTERED_CONTAINER_TYPE = []
  54. # if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
  55. SUPPORTED_LEAF_CLS = [
  56. Module,
  57. Node,
  58. NodeMixin,
  59. np.dtype,
  60. np.ndarray,
  61. np.number,
  62. np.bool_,
  63. OpDef,
  64. ]
  65. NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
  66. def register_supported_type(
  67. type,
  68. flatten_fn: Callable[[Any], Tuple[List, Any]] = None,
  69. unflatten_fn: Callable[[List, Any], Any] = None,
  70. ):
  71. r"""Call this function to register the ``type`` as a built-in type. The registered ``type``
  72. can be used and serialized correctly in :py:class:`TracedModule`.
  73. Examples:
  74. .. code-block::
  75. def dict_flatten(obj: Dict):
  76. context, values = [], []
  77. # obj.keys() needs to be sortable
  78. keys = sorted(obj.keys())
  79. for key in keys:
  80. values.append(obj[key])
  81. context.append(key)
  82. return values, tuple(context)
  83. def dict_unflatten(values: List, context: Any):
  84. return dict(zip(context, values))
  85. register_supported_type(dict, dict_flatten, dict_unflatten)
  86. Args:
  87. type: the type that needs to be registered.
  88. flatten_fn: a function that should take an object created from ``type`` and return a
  89. flat list of values. It can also return some context that is used in reconstructing
  90. the object. Default: None
  91. unflatten_fn: a function that should take a flat list of values and some context
  92. (returned by flatten_fn). It returns the object by reconstructing
  93. it from the list and the context. Default: None
  94. """
  95. tp_info = (type.__module__, type.__qualname__)
  96. if flatten_fn and unflatten_fn:
  97. USER_REGISTERED_CONTAINER_TYPE.append(tp_info)
  98. else:
  99. USER_REGISTERED_LEAF_TYPE.append(tp_info)
  100. _register_supported_type(type, flatten_fn, unflatten_fn)
  101. def _register_supported_type(type, flatten_fn=None, unflatten_fn=None):
  102. if flatten_fn and unflatten_fn:
  103. SUPPORTED_TYPE[type] = NodeType(flatten_fn, unflatten_fn)
  104. else:
  105. SUPPORTED_LEAF_CLS.append(type)
  106. def _dict_flatten(ordered, inp):
  107. aux_data = []
  108. results = []
  109. dict_items = inp.items() if ordered else sorted(inp.items())
  110. for key, value in dict_items:
  111. results.append(value)
  112. aux_data.append(key)
  113. return results, tuple(aux_data)
  114. def _dict_unflatten(dict_type, inps, aux_data):
  115. return dict_type(zip(aux_data, inps))
  116. def qparams_flatten(inp):
  117. aux_data = []
  118. results = []
  119. for key in inp.__slots__:
  120. aux_data.append(key)
  121. results.append(getattr(inp, key, None))
  122. return results, tuple(aux_data)
  123. def qparams_unflatten(qparam_type, inp, aux_data):
  124. obj = qparam_type.__new__(qparam_type)
  125. for k, v in zip(aux_data, inp):
  126. setattr(obj, k, v)
  127. return obj
  128. _register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
  129. _register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
  130. _register_supported_type(
  131. dict, partial(_dict_flatten, False), partial(_dict_unflatten, dict)
  132. )
  133. _register_supported_type(
  134. defaultdict, partial(_dict_flatten, False), partial(_dict_unflatten, defaultdict)
  135. )
  136. _register_supported_type(
  137. OrderedDict, partial(_dict_flatten, True), partial(_dict_unflatten, OrderedDict)
  138. )
  139. _register_supported_type(
  140. slice,
  141. lambda x: ([x.start, x.stop, x.step], None),
  142. lambda x, aux_data: slice(x[0], x[1], x[2]),
  143. )
  144. _register_supported_type(QParams, qparams_flatten, partial(qparams_unflatten, QParams))
  145. _register_supported_type(
  146. LSQParams, qparams_flatten, partial(qparams_unflatten, LSQParams)
  147. )
  148. def _is_leaf(obj):
  149. obj_type = obj if isinstance(obj, type) else type(obj)
  150. return (
  151. issubclass(obj_type, tuple(SUPPORTED_LEAF_CLS))
  152. or obj_type in SUPPORTED_LEAF_TYPE
  153. )
  154. def _leaf_type(node):
  155. if isinstance(node, (RawTensor, TensorNode)):
  156. return (Tensor, TensorNode, ArgsIndex)
  157. elif isinstance(node, (NodeMixin, Module, ModuleNode)):
  158. return (Module, ModuleNode, NodeMixin, ArgsIndex)
  159. else:
  160. return (type(node), ArgsIndex)
  161. def _is_const_leaf(node):
  162. if isinstance(node, (RawTensor, NodeMixin, Module)):
  163. return False
  164. return True
  165. def tree_flatten(
  166. values,
  167. leaf_type: Callable = _leaf_type,
  168. is_leaf: Callable = _is_leaf,
  169. is_const_leaf: Callable = _is_const_leaf,
  170. ):
  171. r"""Flattens a pytree into a list of values and a :class:`TreeDef` that can be used
  172. to reconstruct the pytree.
  173. """
  174. if type(values) not in SUPPORTED_TYPE:
  175. assert is_leaf(
  176. values
  177. ), 'doesn\'t support {} type, MUST use "register_supported_type" method to register self-defined type'.format(
  178. values
  179. )
  180. node = LeafDef(leaf_type(values))
  181. if is_const_leaf(values):
  182. node.const_val = values
  183. return [values,], node
  184. rst = []
  185. children_defs = []
  186. children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
  187. for v in children_values:
  188. v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
  189. rst.extend(v_list)
  190. children_defs.append(treedef)
  191. return rst, TreeDef(type(values), aux_data, children_defs)
  192. class TreeDef:
  193. r"""A ``TreeDef`` represents the structure of a pytree.
  194. Args:
  195. type: the type of root Node of the pytree.
  196. aux_data: some const data that is useful in unflattening the pytree.
  197. children_defs: ``TreeDef`` for each child of the root Node.
  198. num_leaves: the number of leaves.
  199. """
  200. def __init__(self, type, aux_data, children_defs):
  201. self.type = type
  202. self.aux_data = aux_data
  203. self.children_defs = children_defs
  204. self.num_leaves = sum(ch.num_leaves for ch in children_defs)
  205. def unflatten(self, leaves):
  206. r"""Given a list of values and a ``TreeDef``, builds a pytree.
  207. This is the inverse operation of ``tree_flatten``.
  208. """
  209. assert len(leaves) == self.num_leaves
  210. start = 0
  211. children = []
  212. for ch in self.children_defs:
  213. children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
  214. start += ch.num_leaves
  215. return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
  216. def __hash__(self):
  217. return hash(
  218. tuple(
  219. [
  220. self.type,
  221. self.aux_data,
  222. self.num_leaves,
  223. tuple([hash(x) for x in self.children_defs]),
  224. ]
  225. )
  226. )
  227. def __ne__(self, other) -> bool:
  228. return not self.__eq__(other)
  229. def __eq__(self, other) -> bool:
  230. return (
  231. self.type == other.type
  232. and self.aux_data == other.aux_data
  233. and self.num_leaves == other.num_leaves
  234. and self.children_defs == other.children_defs
  235. )
  236. def _args_kwargs_repr(self):
  237. if (
  238. len(self.children_defs) == 2
  239. and issubclass(self.children_defs[0].type, (List, Tuple))
  240. and issubclass(self.children_defs[1].type, Dict)
  241. ):
  242. args_def = self.children_defs[0]
  243. content = ", ".join(repr(i) for i in args_def.children_defs)
  244. kwargs_def = self.children_defs[1]
  245. if kwargs_def.aux_data:
  246. content += ", "
  247. content += ", ".join(
  248. str(i) + "=" + repr(j)
  249. for i, j in zip(kwargs_def.aux_data, kwargs_def.children_defs)
  250. )
  251. return content
  252. else:
  253. return repr(self)
  254. def __repr__(self):
  255. format_str = self.type.__name__ + "({})"
  256. aux_data_delimiter = "="
  257. if issubclass(self.type, List):
  258. format_str = "[{}]"
  259. if issubclass(self.type, Tuple):
  260. format_str = "({})"
  261. if issubclass(self.type, Dict):
  262. format_str = "{{{}}}"
  263. aux_data_delimiter = ":"
  264. if self.aux_data:
  265. content = ", ".join(
  266. repr(i) + aux_data_delimiter + repr(j)
  267. for i, j in zip(self.aux_data, self.children_defs)
  268. )
  269. else:
  270. content = ", ".join(repr(i) for i in self.children_defs)
  271. return format_str.format(content)
  272. class LeafDef(TreeDef):
  273. def __init__(self, type):
  274. if not isinstance(type, collections.abc.Sequence):
  275. type = (type,)
  276. super().__init__(type, None, [])
  277. self.num_leaves = 1
  278. self.const_val = None
  279. def unflatten(self, leaves):
  280. assert len(leaves) == 1
  281. assert isinstance(leaves[0], self.type), self.type
  282. return leaves[0]
  283. def __ne__(self, other) -> bool:
  284. return not self.__eq__(other)
  285. def __eq__(self, other):
  286. if isinstance(self.const_val, np.ndarray):
  287. return self.type == other.type and (self.const_val == other.const_val).all()
  288. return self.type == other.type and self.const_val == other.const_val
  289. def __hash__(self):
  290. if isinstance(self.const_val, np.ndarray):
  291. return hash(tuple([self.type, str(self.const_val)]))
  292. return hash(tuple([self.type, self.const_val]))
  293. def __repr__(self):
  294. return "{}".format(
  295. self.const_val
  296. if self.const_val is not None or type(None) in self.type
  297. else self.type[0].__name__
  298. )