You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import collections
  9. import copy
  10. import inspect
  11. from collections.abc import MutableMapping, MutableSequence
  12. from inspect import FullArgSpec
  13. from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Union
  14. from .. import get_logger
  15. from ..module import Module
  16. from ..tensor import Parameter, Tensor
# Module-level logger obtained from the package's ``get_logger``; used below to
# warn about attribute types that traced module cannot serialize.
logger = get_logger(__name__)
  18. def replace_container_with_module_container(container):
  19. has_module = False
  20. module_container = None
  21. if isinstance(container, Dict):
  22. m_dic = copy.copy(container)
  23. for key, value in container.items():
  24. if isinstance(value, Module):
  25. has_module = True
  26. elif isinstance(value, (List, Dict)):
  27. (
  28. _has_module,
  29. _module_container,
  30. ) = replace_container_with_module_container(value)
  31. m_dic[key] = _module_container
  32. if _has_module:
  33. has_module = True
  34. if not all(isinstance(v, Module) for v in m_dic.values()):
  35. return has_module, None
  36. else:
  37. return has_module, _ModuleDict(m_dic)
  38. elif isinstance(container, List):
  39. m_list = copy.copy(container)
  40. for ind, value in enumerate(container):
  41. if isinstance(value, Module):
  42. has_module = True
  43. elif isinstance(value, (List, Dict)):
  44. (
  45. _has_module,
  46. _module_container,
  47. ) = replace_container_with_module_container(value)
  48. m_list[ind] = _module_container
  49. if _has_module:
  50. has_module = True
  51. if not all(isinstance(v, Module) for v in m_list):
  52. return has_module, None
  53. else:
  54. return has_module, _ModuleList(m_list)
  55. return has_module, module_container
  56. def _convert_kwargs_to_args(
  57. argspecs: Union[Callable, FullArgSpec], args, kwargs, is_bounded=False
  58. ):
  59. # is_bounded = True when func is a method and provided args don't include 'self'
  60. arg_specs = (
  61. inspect.getfullargspec(argspecs) if isinstance(argspecs, Callable) else argspecs
  62. )
  63. assert isinstance(arg_specs, FullArgSpec)
  64. arg_specs_args = arg_specs.args
  65. if is_bounded:
  66. arg_specs_args = arg_specs.args[1:]
  67. new_args = []
  68. new_kwargs = {}
  69. new_args.extend(args)
  70. if set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys()):
  71. repeated_arg_name = set(arg_specs_args[0 : len(new_args)]) & set(kwargs.keys())
  72. raise TypeError(
  73. "{} got multiple values for argument {}".format(
  74. func.__qualname__, ", ".join(repeated_arg_name)
  75. )
  76. )
  77. if len(new_args) < len(arg_specs.args):
  78. for ind in range(len(new_args), len(arg_specs_args)):
  79. arg_name = arg_specs_args[ind]
  80. if arg_name in kwargs:
  81. new_args.append(kwargs[arg_name])
  82. else:
  83. index = ind - len(arg_specs_args) + len(arg_specs.defaults)
  84. assert index < len(arg_specs.defaults) and index >= 0
  85. new_args.append(arg_specs.defaults[index])
  86. for kwarg_name in arg_specs.kwonlyargs:
  87. if kwarg_name in kwargs:
  88. new_kwargs[kwarg_name] = kwargs[kwarg_name]
  89. else:
  90. assert kwarg_name in arg_specs.kwonlydefaults
  91. new_kwargs[kwarg_name] = arg_specs.kwonlydefaults[kwarg_name]
  92. for k, v in kwargs.items():
  93. if k not in arg_specs.args and k not in arg_specs.kwonlyargs:
  94. if arg_specs.varkw is None:
  95. raise TypeError(
  96. "{} got an unexpected keyword argument {}".format(
  97. func.__qualname__, k
  98. )
  99. )
  100. new_kwargs[k] = v
  101. return tuple(new_args), new_kwargs
  102. def _check_obj_attr(obj):
  103. # check if all the attributes of a obj is serializable
  104. from .pytree import tree_flatten
  105. from .pytree import SUPPORTED_LEAF_CLS, SUPPORTED_LEAF_TYPE, TreeDef
  106. from .expr import Expr
  107. from .traced_module import TracedModule, InternalGraph, NameSpace
  108. def _check_leaf_type(leaf):
  109. leaf_type = leaf if isinstance(leaf, type) else type(leaf)
  110. traced_module_types = [Expr, TreeDef, TracedModule, InternalGraph, NameSpace]
  111. return (
  112. issubclass(leaf_type, tuple(SUPPORTED_LEAF_CLS + traced_module_types))
  113. or leaf_type in SUPPORTED_LEAF_TYPE
  114. )
  115. for _, v in obj.items():
  116. leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
  117. for leaf in leafs:
  118. assert _check_leaf_type(
  119. leaf
  120. ), "Type {} is not supported by traced module".format(
  121. leaf if isinstance(leaf, type) else type(leaf)
  122. )
  123. def _check_builtin_module_attr(mod):
  124. from .pytree import _is_leaf as _check_leaf_type
  125. from .pytree import tree_flatten
  126. # check if all the attributes of a builtin module is serializable
  127. is_non_serializable_module = lambda m: isinstance(
  128. m, Module
  129. ) and not _check_builtin_module_attr(m)
  130. for k, v in mod.__dict__.items():
  131. if k == "_m_dump_modulestate":
  132. continue
  133. if is_non_serializable_module(v):
  134. return False
  135. elif not isinstance(v, Module):
  136. leafs, _ = tree_flatten(v, is_leaf=lambda _: True)
  137. for leaf in leafs:
  138. if not _check_leaf_type(leaf) or is_non_serializable_module(leaf):
  139. logger.warn(
  140. "Type {} is not supported by traced module".format(
  141. leaf if isinstance(leaf, type) else type(leaf)
  142. )
  143. )
  144. return False
  145. return True
  146. class _ModuleList(Module, MutableSequence):
  147. r"""A List-like container.
  148. Using a ``ModuleList``, one can visit, add, delete and modify submodules
  149. just like an ordinary python list.
  150. """
  151. def __init__(self, modules: Optional[Iterable[Module]] = None):
  152. super().__init__()
  153. self._size = 0
  154. if modules is None:
  155. return
  156. for mod in modules:
  157. self.append(mod)
  158. @classmethod
  159. def _ikey(cls, idx):
  160. return "{}".format(idx)
  161. def _check_idx(self, idx):
  162. L = len(self)
  163. if idx < 0:
  164. idx = L + idx
  165. if idx < 0 or idx >= L:
  166. raise IndexError("list index out of range")
  167. return idx
  168. def __getitem__(self, idx: int):
  169. if isinstance(idx, slice):
  170. idx = range(self._size)[idx]
  171. if not isinstance(idx, Sequence):
  172. idx = [
  173. idx,
  174. ]
  175. rst = []
  176. for i in idx:
  177. i = self._check_idx(i)
  178. key = self._ikey(i)
  179. try:
  180. rst.append(getattr(self, key))
  181. except AttributeError:
  182. raise IndexError("list index out of range")
  183. return rst if len(rst) > 1 else rst[0]
  184. def __setattr__(self, key, value):
  185. # clear mod name to avoid warning in Module's setattr
  186. if isinstance(value, Module):
  187. value._name = None
  188. super().__setattr__(key, value)
  189. def __setitem__(self, idx: int, mod: Module):
  190. if not isinstance(mod, Module):
  191. raise ValueError("invalid sub-module")
  192. idx = self._check_idx(idx)
  193. setattr(self, self._ikey(idx), mod)
  194. def __delitem__(self, idx):
  195. idx = self._check_idx(idx)
  196. L = len(self)
  197. for orig_idx in range(idx + 1, L):
  198. new_idx = orig_idx - 1
  199. self[new_idx] = self[orig_idx]
  200. delattr(self, self._ikey(L - 1))
  201. self._size -= 1
  202. def __len__(self):
  203. return self._size
  204. def insert(self, idx, mod: Module):
  205. assert isinstance(mod, Module)
  206. L = len(self)
  207. if idx < 0:
  208. idx = L - idx
  209. # clip idx to (0, L)
  210. if idx > L:
  211. idx = L
  212. elif idx < 0:
  213. idx = 0
  214. for new_idx in range(L, idx, -1):
  215. orig_idx = new_idx - 1
  216. key = self._ikey(new_idx)
  217. setattr(self, key, self[orig_idx])
  218. key = self._ikey(idx)
  219. setattr(self, key, mod)
  220. self._size += 1
  221. def forward(self):
  222. raise RuntimeError("ModuleList is not callable")
  223. class _ModuleDict(Module, MutableMapping):
  224. r"""A Dict-like container.
  225. Using a ``ModuleDict``, one can visit, add, delete and modify submodules
  226. just like an ordinary python dict.
  227. """
  228. def __init__(self, modules: Optional[Dict[str, Module]] = None):
  229. super().__init__()
  230. self._module_keys = []
  231. if modules is not None:
  232. self.update(modules)
  233. def __delitem__(self, key):
  234. delattr(self, key)
  235. assert key in self._module_keys
  236. self._module_keys.remove(key)
  237. def __getitem__(self, key):
  238. return getattr(self, key)
  239. def __setattr__(self, key, value):
  240. # clear mod name to avoid warning in Module's setattr
  241. if isinstance(value, Module):
  242. value._name = None
  243. super().__setattr__(key, value)
  244. def __setitem__(self, key, value):
  245. if not isinstance(value, Module):
  246. raise ValueError("invalid sub-module")
  247. setattr(self, key, value)
  248. if key not in self._module_keys:
  249. self._module_keys.append(key)
  250. def __iter__(self):
  251. return iter(self.keys())
  252. def __len__(self):
  253. return len(self._module_keys)
  254. def items(self):
  255. return [(key, getattr(self, key)) for key in self._module_keys]
  256. def values(self):
  257. return [getattr(self, key) for key in self._module_keys]
  258. def keys(self):
  259. return self._module_keys
  260. def forward(self):
  261. raise RuntimeError("ModuleList is not callable")
  262. def assign_attr(obj: Union[Module, Tensor], module: Module, target: str):
  263. *prefix, name = target.split(".")
  264. for item in prefix:
  265. module = getattr(module, item)
  266. if not isinstance(module, Module):
  267. raise AttributeError("`{}` is not an Module".format(item))
  268. setattr(module, name, obj)
  269. def get_subattr(module: Module, target: str):
  270. # todo : remove this import
  271. from .node import ModuleNode
  272. if target == "":
  273. return module
  274. *prefix, name = target.split(".")
  275. for item in prefix:
  276. module = getattr(module, item)
  277. if not isinstance(module, (Module, ModuleNode)):
  278. raise AttributeError("`{}` is not an Module".format(item))
  279. return getattr(module, name)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。