You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

module.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from abc import ABCMeta, abstractmethod
  9. from collections import OrderedDict
  10. from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
  11. import numpy as np
  12. from .._internal.dtype import is_quantize
  13. from ..core import Buffer, Parameter, Tensor
  14. from ..logger import get_logger
  15. from ..utils.hook import HookHandler
  # Module-level logger; load_state_dict() uses it to warn about unused or
  # missing keys when strict=False.
  16. logger = get_logger(__name__)
  def _expand_structure(key, obj):
      """Flatten ``obj`` into a list of ``(dotted_key, leaf)`` pairs.

      Leaves are :class:`~.Tensor` and :class:`~.Module` instances. Lists and
      tuples are expanded using their indices as key components; dicts are
      expanded in sorted key order. Any other object contributes no entries.

      :param key: key of ``obj`` itself; every returned key starts with it.
      :param obj: the object to expand.
      :return: a list of ``(key, leaf)`` pairs, with nested keys joined by ``"."``
          (e.g. ``"layers.0.weight"``).
      :raises AssertionError: if a dict key leading (directly or through nested
          containers) to a Tensor or Module is not a ``str``.
      """
      if isinstance(obj, (Tensor, Module)):
          return [(key, obj)]
      if not isinstance(obj, (list, tuple, dict)):
          return []

      if isinstance(obj, dict):
          targets = ((k, obj[k]) for k in sorted(obj))
      else:
          targets = ((str(k), v) for k, v in enumerate(obj))

      ret = []
      for k, o in targets:
          # Recurse with a str key: for a non-str key whose value is a nested
          # container, the child would otherwise fail on ``key + "." + kt`` with a
          # TypeError before the key-type check below could report the real problem.
          sub_ret = _expand_structure(k if isinstance(k, str) else str(k), o)
          if sub_ret and not isinstance(k, str):
              raise AssertionError(
                  "keys for Tensor and Module must be str, error key: {}".format(k)
              )
          ret.extend((key + "." + kt, vt) for kt, vt in sub_ret)
      return ret
  def _is_parameter(obj: Any) -> bool:
      """Predicate used when scanning modules: is ``obj`` a :class:`~.Parameter`?"""
      is_param = isinstance(obj, Parameter)
      return is_param
  def _is_buffer(obj: Any) -> bool:
      """Predicate used when scanning modules: is ``obj`` a :class:`~.Buffer`?"""
      is_buf = isinstance(obj, Buffer)
      return is_buf
  def _is_module(obj: Any) -> bool:
      """Predicate used when scanning modules: is ``obj`` a :class:`Module`?"""
      is_mod = isinstance(obj, Module)
      return is_mod
  class Module(metaclass=ABCMeta):
      """Base Module class.

      Subclasses implement :meth:`forward`. Calling a module instance runs the
      registered forward pre-hooks, then :meth:`forward`, then the registered
      forward hooks. :class:`~.Parameter`, :class:`~.Buffer` and submodule
      attributes are discovered by scanning the instance ``__dict__``, including
      values nested inside lists, tuples and dicts.
      """

      def __init__(self):
          # runtime attributes
          self.training = True
          self.quantize_disabled = False

          # hooks; OrderedDict so that __call__ invokes them in insertion order
          self._forward_pre_hooks = OrderedDict()
          self._forward_hooks = OrderedDict()

      @abstractmethod
      def forward(self, inputs):
          pass

      def register_forward_pre_hook(self, hook: Callable) -> HookHandler:
          """Register a hook to handle forward inputs.

          Before :meth:`forward` runs, the hook is called as ``hook(module, inputs)``,
          where ``inputs`` is the tuple of positional arguments. Note that keyword
          arguments of the call are not passed to the hook; they are forwarded to
          :meth:`forward` unchanged.

          :param hook: a function that receives ``module`` and ``inputs``, then
              returns a modified ``inputs`` or ``None``. A non-tuple return value is
              wrapped into a one-element tuple.
          :return: a handler with :meth:`~.HookHandler.remove` interface to delete
              the hook.
          """
          return HookHandler(self._forward_pre_hooks, hook)

      def register_forward_hook(self, hook: Callable) -> HookHandler:
          """Register a hook to handle forward results.

          After :meth:`forward` runs, the hook is called as
          ``hook(module, inputs, outputs)``, where ``inputs`` are the positional
          inputs after any pre-hook modification.

          :param hook: a function that receives ``module``, ``inputs`` and
              ``outputs``, then returns a modified ``outputs`` or ``None``.
          :return: a handler with :meth:`~.HookHandler.remove` interface to delete
              the hook.
          """
          return HookHandler(self._forward_hooks, hook)

      def __call__(self, *inputs, **kwargs):
          """Run forward pre-hooks, :meth:`forward`, then forward hooks.

          A hook returning ``None`` leaves the inputs/outputs unchanged.
          """
          for hook in self._forward_pre_hooks.values():
              modified_inputs = hook(self, inputs)
              if modified_inputs is not None:
                  if not isinstance(modified_inputs, tuple):
                      modified_inputs = (modified_inputs,)
                  inputs = modified_inputs

          outputs = self.forward(*inputs, **kwargs)

          for hook in self._forward_hooks.values():
              modified_outputs = hook(self, inputs, outputs)
              if modified_outputs is not None:
                  outputs = modified_outputs
          return outputs

      def _flatten(
          self,
          *,
          recursive: bool = True,
          with_key: bool = False,
          with_parent: bool = False,
          prefix: Optional[str] = None,
          predicate: Callable[[Any], bool] = lambda _: True,
          seen: Optional[Set[int]] = None
      ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
          """Scans the module object and returns an iterable for the :class:`~.Tensor`
          and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
          calls of this function with same arguments, the order of objects within the
          returned iterable is guaranteed to be identical, as long as all the involved
          module objects' ``__dict__`` does not change throughout those calls.

          Each object is yielded at most once (tracked by ``id`` in ``seen``), even if
          it is reachable through several attributes.

          :param recursive: Whether to recursively scan all the submodules.
          :param with_key: Whether to yield keys along with yielded objects.
          :param with_parent: Whether to yield the module that directly owns each
              object along with yielded objects.
          :param prefix: The prefix prepended to the yielded keys.
          :param predicate: The predicate function applied to scanned objects.
          :param seen: A set of ``id`` values of objects already traversed; shared
              across the recursive calls.
          """
          if seen is None:
              seen = {id(self)}
          module_dict = vars(self)
          _prefix = "" if prefix is None else prefix + "."
          for key in sorted(module_dict):
              for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                  leaf_id = id(leaf)
                  # skip objects reachable through more than one attribute, and
                  # references back to modules already being traversed
                  if leaf_id in seen:
                      continue
                  seen.add(leaf_id)

                  if predicate(leaf):
                      if with_key and with_parent:
                          yield _prefix + expanded_key, leaf, self
                      elif with_key:
                          yield _prefix + expanded_key, leaf
                      elif with_parent:
                          yield leaf, self
                      else:
                          yield leaf

                  # descend into submodules even when they fail ``predicate``, so that
                  # e.g. parameters() still reaches parameters of nested modules
                  if recursive and isinstance(leaf, Module):
                      yield from leaf._flatten(
                          recursive=recursive,
                          with_key=with_key,
                          with_parent=with_parent,
                          prefix=_prefix + expanded_key if with_key else None,
                          predicate=predicate,
                          seen=seen,
                      )

      def parameters(
          self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
      ) -> Iterable[Parameter]:
          r"""Returns an iterable for the :class:`~.Parameter` of the module.

          :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
              attribute of returned :class:`.Parameter`. ``None`` for no limitation.
          :param recursive: If ``True``, returns all :class:`~.Parameter` within this
              module, else only returns :class:`~.Parameter` that are direct attributes
              of this module.
          :param kwargs: extra keyword arguments forwarded to :meth:`_flatten`.
          """

          def predicate(obj) -> bool:
              return _is_parameter(obj) and (
                  requires_grad is None or obj.requires_grad == requires_grad
              )

          yield from self._flatten(
              with_key=False, predicate=predicate, recursive=recursive, **kwargs
          )

      def named_parameters(
          self,
          requires_grad: Optional[bool] = None,
          prefix: Optional[str] = None,
          recursive: bool = True,
          **kwargs
      ) -> Iterable[Tuple[str, Parameter]]:
          """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
          ``key`` is the dotted path from this module to the :class:`~.Parameter` .

          :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
              attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
          :param prefix: The prefix prepended to the keys.
          :param recursive: If ``True``, returns all :class:`~.Parameter` within this
              module, else only returns :class:`~.Parameter` that are direct attributes
              of this module.
          :param kwargs: extra keyword arguments forwarded to :meth:`_flatten`.
          """

          def predicate(obj) -> bool:
              return _is_parameter(obj) and (
                  requires_grad is None or obj.requires_grad == requires_grad
              )

          yield from self._flatten(
              with_key=True,
              prefix=prefix,
              predicate=predicate,
              recursive=recursive,
              **kwargs,
          )

      def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]:
          """Returns an iterable for the :class:`~.Buffer` of the module.

          :param recursive: If ``True``, returns all :class:`~.Buffer` within this
              module, else only returns :class:`~.Buffer` that are direct attributes
              of this module.
          :param kwargs: extra keyword arguments forwarded to :meth:`_flatten`.
          """
          yield from self._flatten(
              with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
          )

      def named_buffers(
          self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
      ) -> Iterable[Tuple[str, Buffer]]:
          """Returns an iterable for key :class:`~.Buffer` pairs of the module, where
          ``key`` is the dotted path from this module to the :class:`~.Buffer` .

          :param prefix: The prefix prepended to the keys.
          :param recursive: If ``True``, returns all :class:`~.Buffer` within this
              module, else only returns :class:`~.Buffer` that are direct attributes
              of this module.
          :param kwargs: extra keyword arguments forwarded to :meth:`_flatten`.
          """
          yield from self._flatten(
              with_key=True,
              prefix=prefix,
              predicate=_is_buffer,
              recursive=recursive,
              **kwargs,
          )

      def children(self, **kwargs) -> "Iterable[Module]":
          """Returns an iterable for all the submodules that are direct attributes of this
          module.
          """
          yield from self._flatten(
              with_key=False, predicate=_is_module, recursive=False, **kwargs
          )

      def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
          """Returns an iterable of key-submodule pairs for all the submodules that are
          direct attributes of this module, where 'key' is the attribute name of
          submodules.
          """
          yield from self._flatten(
              with_key=True, predicate=_is_module, recursive=False, **kwargs
          )

      def modules(self, **kwargs) -> "Iterable[Module]":
          """Returns an iterable for all the modules within this module, including itself.

          With ``with_parent=True``, items are ``(module, parent)`` pairs and this
          module itself is paired with ``None``.
          """
          if kwargs.get("with_parent"):
              yield self, None
          else:
              yield self
          yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

      def named_modules(
          self, prefix: Optional[str] = None, **kwargs
      ) -> "Iterable[Tuple[str, Module]]":
          """Returns an iterable of key-module pairs for all the modules within this
          module, including itself, where 'key' is the dotted path from this module to the
          submodules.

          :param prefix: The prefix prepended to the path. This module itself is
              yielded with key ``prefix`` (or ``""`` when ``prefix`` is ``None``).
          """
          root_key = "" if prefix is None else prefix
          if kwargs.get("with_parent"):
              yield root_key, self, None
          else:
              yield root_key, self
          yield from self._flatten(
              with_key=True, prefix=prefix, predicate=_is_module, **kwargs
          )

      def apply(self, fn: "Callable[[Module], Any]") -> None:
          """Apply function ``fn`` to all the modules within this module, including
          itself.

          :param fn: The function to be applied on modules.
          """
          for it in self.modules():
              fn(it)

      def zero_grad(self) -> None:
          """Set all parameters' grads to zero. Parameters whose ``grad`` is ``None``
          are left untouched.
          """
          for param in self.parameters():
              if param.grad is not None:
                  param.grad.reset_zero()

      def train(self, mode: bool = True, recursive: bool = True) -> None:
          """Set training mode of all the modules within this module (including itself) to
          ``mode``. This effectively sets the ``training`` attributes of those modules
          to ``mode``, but only has effect on certain modules (e.g.
          :class:`~.BatchNorm2d`, :class:`~.Dropout`, :class:`~.Observer`)

          :param mode: the training mode to be set on modules.
          :param recursive: whether to recursively call submodules' ``train()``.
          """
          if not recursive:
              self.training = mode
              return

          def fn(module: Module) -> None:
              module.train(mode, recursive=False)

          self.apply(fn)

      def eval(self) -> None:
          """Set training mode of all the modules within this module (including itself) to
          ``False``. See :meth:`~.Module.train` for details.
          """
          self.train(False)

      def disable_quantize(self, value=True):
          r"""Set the ``quantize_disabled`` attribute of this module and all of its
          submodules to ``value``, then return this module.

          Returning the module allows chained use, e.g.
          ``net = Net().disable_quantize()``.

          :param value: the value assigned to ``quantize_disabled``.
          :return: this module.
          """

          def fn(module: Module) -> None:
              module.quantize_disabled = value

          self.apply(fn)
          # return the module as documented so the call can be chained
          return self

      def replace_param(
          self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
      ):
          """Replace module's parameters with `params`, used by :class:`~.ParamPack` to
          speedup multimachine training.

          Parameters that are direct attributes of this module and of its submodules
          are numbered in traversal (sorted attribute name) order starting at
          ``start_pos``; a parameter whose number is a key of ``params`` is replaced
          by the corresponding value, whose shape must match.

          :param params: mapping from parameter position to replacement parameter.
          :param start_pos: position assigned to the first parameter visited.
          :param seen: ``id`` values of already-visited objects, shared across the
              recursive calls.
          :return: the number of parameter positions consumed by this module,
              including its submodules.
          """
          offset = 0
          if seen is None:
              seen = {id(self)}
          module_dict = vars(self)
          for key in sorted(module_dict):
              hash_id = id(module_dict[key])
              if hash_id in seen:
                  continue
              seen.add(hash_id)
              if isinstance(module_dict[key], Parameter):
                  if start_pos + offset in params:
                      assert module_dict[key].shape == params[start_pos + offset].shape
                      # writes straight into the instance __dict__
                      module_dict[key] = params[start_pos + offset]
                  # every parameter consumes a position, replaced or not
                  offset += 1
              if isinstance(module_dict[key], Module):
                  offset += module_dict[key].replace_param(
                      params, start_pos + offset, seen
                  )
          return offset

      def state_dict(self, rst=None, prefix="", keep_var=False):
          r"""Returns a dictionary containing whole states of the module.

          Keys are dotted attribute paths of every :class:`~.Parameter` and
          :class:`~.Buffer` in this module and its submodules.

          :param rst: dictionary to fill; a new :class:`OrderedDict` is created when
              ``None``.
          :param prefix: string prepended to every key.
          :param keep_var: if ``True``, store the parameter/buffer objects themselves;
              otherwise store the result of their ``numpy()`` method.
          :return: ``rst``.
          """

          def is_state(obj):
              return _is_parameter(obj) or _is_buffer(obj)

          if rst is None:
              rst = OrderedDict()

          for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
              full_key = prefix + k
              assert full_key not in rst, "duplicated state: {}".format(full_key)
              if keep_var:
                  rst[full_key] = v
              else:
                  rst[full_key] = v.numpy()

          for k, submodule in self._flatten(
              recursive=False,
              with_key=True,
              predicate=lambda obj: isinstance(obj, Module),
          ):
              submodule.state_dict(rst, prefix + k + ".", keep_var)

          return rst

      def load_state_dict(
          self,
          state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
          strict=True,
      ):
          r"""Load a given dictionary created by :func:`state_dict` into this module.
          If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys
          returned by :func:`state_dict`.

          Users can also pass a closure: ``Function[key: str, var: Tensor] -> Optional[np.ndarray]``
          as a ``state_dict``, in order to handle complex situations. For example, load everything
          except for the final linear classifier:

          .. code-block::

              state_dict = {...}  #  Dict[str, np.ndarray]
              model.load_state_dict({
                  k: None if k.startswith('fc') else v
                  for k, v in state_dict.items()
              }, strict=False)

          Here returning ``None`` means skipping parameter ``k``.

          To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading:

          .. code-block::

              state_dict = {...}
              def reshape_accordingly(k, v):
                  return state_dict[k].reshape(v.shape)
              model.load_state_dict(reshape_accordingly)

          We can also perform inplace re-initialization or pruning:

          .. code-block::

              def reinit_and_pruning(k, v):
                  if 'bias' in k:
                      M.init.zero_(v)
                  if 'conv' in k:
                      return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32")
              model.load_state_dict(reinit_and_pruning, strict=False)

          :param state_dict: a dict mapping keys to ``np.ndarray`` values, or a
              closure as described above.
          :param strict: if ``True``, raise :class:`KeyError` when a key of a dict
              ``state_dict`` is not loaded or when a state of this module is skipped;
              otherwise only log warnings.
          :raises ValueError: if ``state_dict`` is neither a dict nor callable.
          """
          unused = []
          if isinstance(state_dict, dict):
              unused = state_dict.keys()

              def closure(k, _):  # var unused
                  return state_dict[k] if k in state_dict else None

          elif callable(state_dict):
              closure = state_dict
          else:
              raise ValueError(
                  "`state_dict` must load a dict or callable, got {}".format(
                      type(state_dict)
                  )
              )

          loaded, skipped = self._load_state_dict_with_closure(closure)
          unused = set(unused) - loaded

          if len(unused) != 0:
              if strict:
                  raise KeyError(
                      "Unused params violate `strict=True`, unused={}".format(unused)
                  )
              else:
                  logger.warning(
                      "Unused params in `strict=False` mode, unused={}".format(unused)
                  )

          if len(skipped) != 0:
              if strict:
                  raise KeyError(
                      "Missing params violate `strict=True`, missing={}".format(skipped)
                  )
              else:
                  logger.warning(
                      "Missing params in `strict=False` mode, missing={}".format(skipped)
                  )

      def _load_state_dict_with_closure(self, closure):
          """Advance state_dict load through callable ``closure`` whose signature is
          ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``.

          :return: a ``(loaded, skipped)`` pair of key sets; a key is skipped when
              ``closure`` returns ``None`` for it.
          """
          assert callable(closure), "closure must be a function"

          loaded = []
          skipped = []

          local_state_dict = self.state_dict(keep_var=True)
          for k, var in local_state_dict.items():
              to_be_load = closure(k, var)
              if to_be_load is None:
                  skipped.append(k)
                  continue
              assert isinstance(
                  to_be_load, np.ndarray
              ), "closure should return a `np.ndarray`, now `{}` get {}".format(
                  k, to_be_load
              )
              assert (
                  var.shape == to_be_load.shape
              ), "param `{}` shape mismatch, should be {}, get {}".format(
                  k, var.shape, to_be_load.shape
              )
              # For quantized dtype, the initialized dtype
              # scale/zero_points maybe invalid, use pretrained dtype instead.
              if is_quantize(to_be_load.dtype) and is_quantize(var.dtype):
                  var.dtype = to_be_load.dtype
              var.set_value(to_be_load)
              loaded.append(k)

          return set(loaded), set(skipped)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。