You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. import collections
  2. import copy
  3. import inspect
  4. from collections.abc import MutableMapping, MutableSequence
  5. from inspect import FullArgSpec
  6. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
  7. from .. import get_logger
  8. from ..module import Module
  9. from ..tensor import Tensor
# Module-level logger, obtained from the package's ``get_logger`` helper and
# keyed by this module's ``__name__``.
logger = get_logger(__name__)
  11. def replace_container_with_module_container(container):
  12. has_module = False
  13. module_container = None
  14. if isinstance(container, Dict):
  15. m_dic = copy.copy(container)
  16. for key, value in container.items():
  17. if isinstance(value, Module):
  18. has_module = True
  19. elif isinstance(value, (List, Dict)):
  20. (
  21. _has_module,
  22. _module_container,
  23. ) = replace_container_with_module_container(value)
  24. m_dic[key] = _module_container
  25. if _has_module:
  26. has_module = True
  27. if not all(isinstance(v, Module) for v in m_dic.values()):
  28. return has_module, None
  29. else:
  30. return has_module, _ModuleDict(m_dic)
  31. elif isinstance(container, List):
  32. m_list = copy.copy(container)
  33. for ind, value in enumerate(container):
  34. if isinstance(value, Module):
  35. has_module = True
  36. elif isinstance(value, (List, Dict)):
  37. (
  38. _has_module,
  39. _module_container,
  40. ) = replace_container_with_module_container(value)
  41. m_list[ind] = _module_container
  42. if _has_module:
  43. has_module = True
  44. if not all(isinstance(v, Module) for v in m_list):
  45. return has_module, None
  46. else:
  47. return has_module, _ModuleList(m_list)
  48. return has_module, module_container
  49. def _convert_kwargs_to_args(
  50. argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
  51. ):
  52. # is_bounded = True when func is a method and provided args don't include 'self'
  53. arg_specs = (
  54. inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
  55. )
  56. func_name = argspecs.__qualname__ if isinstance(argspecs, Callable) else "function"
  57. assert isinstance(arg_specs, FullArgSpec)
  58. arg_specs_args = arg_specs.args
  59. arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else []
  60. arg_specs_kwonlyargs = arg_specs.kwonlyargs
  61. arg_specs_kwonlydefaults = (
  62. arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
  63. )
  64. if is_bounded:
  65. arg_specs_args = arg_specs.args[1:]
  66. new_args = []
  67. new_kwargs = {}
  68. new_args.extend(args)
  69. if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()):
  70. repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
  71. raise TypeError(
  72. "{} got multiple values for argument {}".format(
  73. func_name, ", ".join(repeated_arg_name)
  74. )
  75. )
  76. if len(new_args) < len(arg_specs_args):
  77. for ind in range(len(new_args), len(arg_specs_args)):
  78. arg_name = arg_specs_args[ind]
  79. if arg_name in kwargs:
  80. new_args.append(kwargs[arg_name])
  81. else:
  82. index = ind - len(arg_specs_args) + len(arg_specs_defaults)
  83. if index >= len(arg_specs_defaults) or index < 0:
  84. raise TypeError(
  85. "{} missing required positional arguments: {}".format(
  86. func_name, arg_name
  87. )
  88. )
  89. new_args.append(arg_specs_defaults[index])
  90. for kwarg_name in arg_specs_kwonlyargs:
  91. if kwarg_name in kwargs:
  92. new_kwargs[kwarg_name] = kwargs[kwarg_name]
  93. else:
  94. if kwarg_name not in arg_specs_kwonlydefaults:
  95. raise TypeError(
  96. "{} missing required keyword-only argument: {}".format(
  97. func_name, kwarg_name
  98. )
  99. )
  100. new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name]
  101. for k, v in kwargs.items():
  102. if k not in arg_specs_args and k not in arg_specs_kwonlyargs:
  103. if arg_specs.varkw is None:
  104. raise TypeError(
  105. "{} got an unexpected keyword argument {}".format(func_name, k)
  106. )
  107. new_kwargs[k] = v
  108. return tuple(new_args), new_kwargs
  109. def _check_obj_attr(obj):
  110. # check if all the attributes of a obj is serializable
  111. from .pytree import tree_flatten
  112. from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef
  113. from .expr import Expr
  114. from .traced_module import TracedModule, InternalGraph, NameSpace
  115. def _check_leaf_type(leaf):
  116. leaf_type = leaf if isinstance(leaf, type) else type(leaf)
  117. traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
  118. return (
  119. issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types))
  120. or leaf_type in SUPPORTED_LEAF_TYPE
  121. )
  122. for _, v in obj.items():
  123. leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
  124. for leaf in leafs:
  125. assert _check_leaf_type(leaf), (
  126. "Type {} is not supported in TracedModule serialization by default. "
  127. "If you want to save this object to file, please call tm.register_supported_type({}) "
  128. "before saving.".format(
  129. leaf if isinstance(leaf, type) else type(leaf), type(leaf).__name__
  130. )
  131. )
  132. def _check_builtin_module_attr(mod):
  133. from .pytree import _is_leaf as _check_leaf_type
  134. from .pytree import tree_flatten
  135. # check if all the attributes of a builtin module is serializable
  136. is_non_serializable_module = lambda m: isinstance(
  137. m, Module
  138. ) and not _check_builtin_module_attr(m)
  139. for k, v in mod.__dict__.items():
  140. if k == "_m_dump_modulestate":
  141. continue
  142. if is_non_serializable_module(v):
  143. return False
  144. elif not isinstance(v, Module):
  145. leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
  146. for leaf in leafs:
  147. if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
  148. logger.warn(
  149. "Type {} is not supported by traced module".format(
  150. leaf if isinstance(leaf, type) else type(leaf)
  151. )
  152. )
  153. return False
  154. return True
  155. class _ModuleList(Module, MutableSequence):
  156. r"""A List-like container.
  157. Using a ``ModuleList``, one can visit, add, delete and modify submodules
  158. just like an ordinary python list.
  159. """
  160. def __init__(self, modules: Optional[Iterable[Module]] = None):
  161. super().__init__()
  162. self._size = 0
  163. if modules is None:
  164. return
  165. for mod in modules:
  166. self.append(mod)
  167. @classmethod
  168. def _ikey(cls, idx):
  169. return "{}".format(idx)
  170. def _check_idx(self, idx):
  171. L = len(self)
  172. if idx < 0:
  173. idx = L + idx
  174. if idx < 0 or idx >= L:
  175. raise IndexError("list index out of range")
  176. return idx
  177. def __getitem__(self, idx: int):
  178. if isinstance(idx, slice):
  179. idx = range(self._size)[idx]
  180. if not isinstance(idx, Sequence):
  181. idx = [
  182. idx,
  183. ]
  184. rst = []
  185. for i in idx:
  186. i = self._check_idx(i)
  187. key = self._ikey(i)
  188. try:
  189. rst.append(getattr(self, key))
  190. except AttributeError:
  191. raise IndexError("list index out of range")
  192. return rst if len(rst) > 1 else rst[0]
  193. def __setattr__(self, key, value):
  194. # clear mod name to avoid warning in Module's setattr
  195. if isinstance(value, Module):
  196. value._short_name = None
  197. super().__setattr__(key, value)
  198. def __setitem__(self, idx: int, mod: Module):
  199. if not isinstance(mod, Module):
  200. raise ValueError("invalid sub-module")
  201. idx = self._check_idx(idx)
  202. setattr(self, self._ikey(idx), mod)
  203. def __delitem__(self, idx):
  204. idx = self._check_idx(idx)
  205. L = len(self)
  206. for orig_idx in range(idx + 1, L):
  207. new_idx = orig_idx - 1
  208. self[new_idx] = self[orig_idx]
  209. delattr(self, self._ikey(L - 1))
  210. self._size -= 1
  211. def __len__(self):
  212. return self._size
  213. def insert(self, idx, mod: Module):
  214. assert isinstance(mod, Module)
  215. L = len(self)
  216. if idx < 0:
  217. idx = L - idx
  218. # clip idx to (0, L)
  219. if idx > L:
  220. idx = L
  221. elif idx < 0:
  222. idx = 0
  223. for new_idx in range(L, idx, -1):
  224. orig_idx = new_idx - 1
  225. key = self._ikey(new_idx)
  226. setattr(self, key, self[orig_idx])
  227. key = self._ikey(idx)
  228. setattr(self, key, mod)
  229. self._size += 1
  230. def forward(self):
  231. raise RuntimeError("ModuleList is not callable")
  232. class _ModuleDict(Module, MutableMapping):
  233. r"""A Dict-like container.
  234. Using a ``ModuleDict``, one can visit, add, delete and modify submodules
  235. just like an ordinary python dict.
  236. """
  237. def __init__(self, modules: Optional[Dict[str, Module]] = None):
  238. super().__init__()
  239. self._module_keys = []
  240. if modules is not None:
  241. self.update(modules)
  242. def __delitem__(self, key):
  243. delattr(self, key)
  244. assert key in self._module_keys
  245. self._module_keys.remove(key)
  246. def __getitem__(self, key):
  247. return getattr(self, key)
  248. def __setattr__(self, key, value):
  249. # clear mod name to avoid warning in Module's setattr
  250. if isinstance(value, Module):
  251. value._short_name = None
  252. super().__setattr__(key, value)
  253. def __setitem__(self, key, value):
  254. if not isinstance(value, Module):
  255. raise ValueError("invalid sub-module")
  256. setattr(self, key, value)
  257. if key not in self._module_keys:
  258. self._module_keys.append(key)
  259. def __iter__(self):
  260. return iter(self.keys())
  261. def __len__(self):
  262. return len(self._module_keys)
  263. def items(self):
  264. return [(key, getattr(self, key)) for key in self._module_keys]
  265. def values(self):
  266. return [getattr(self, key) for key in self._module_keys]
  267. def keys(self):
  268. return self._module_keys
  269. def forward(self):
  270. raise RuntimeError("ModuleList is not callable")
  271. def assign_attr(obj: Union[Module, Tensor], module: Module, target: str):
  272. *prefix, name = target.split(".")
  273. for item in prefix:
  274. module = getattr(module, item)
  275. if not isinstance(module, Module):
  276. raise AttributeError("`{}` is not an Module".format(item))
  277. setattr(module, name, obj)
  278. def get_subattr(module: Module, target: str):
  279. # todo : remove this import
  280. from .node import ModuleNode
  281. if target == "":
  282. return module
  283. *prefix, name = target.split(".")
  284. for item in prefix:
  285. module = getattr(module, item)
  286. if not isinstance(module, (Module, ModuleNode)):
  287. raise AttributeError("`{}` is not an Module".format(item))
  288. return getattr(module, name)