You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

module.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from abc import ABCMeta, abstractmethod
  10. from collections import OrderedDict
  11. from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
  12. import numpy as np
  13. from ..core import Buffer, Parameter, Tensor
  14. from ..logger import get_logger
# Module-level logger; used below to warn about skipped params in state loading.
logger = get_logger(__name__)
  16. def _expand_structure(key, obj):
  17. if isinstance(obj, (list, tuple, dict)):
  18. ret = []
  19. if isinstance(obj, dict):
  20. targets = ((k, obj[k]) for k in sorted(obj))
  21. else:
  22. targets = ((str(k), v) for k, v in enumerate(obj))
  23. for k, o in targets:
  24. ret.extend(_expand_structure(key + "." + k, o))
  25. return ret
  26. else:
  27. return [(key, obj)]
  28. def _is_parameter(obj):
  29. return isinstance(obj, Parameter)
  30. def _is_buffer(obj):
  31. return isinstance(obj, Buffer)
  32. def _is_module(obj):
  33. return isinstance(obj, Module)
  34. class Module(metaclass=ABCMeta):
  35. """Base Module class.
  36. """
  37. def __init__(self):
  38. self.training = True
  39. @abstractmethod
  40. def forward(self, inputs):
  41. pass
  42. def __call__(self, *inputs, **kwargs):
  43. # ToDo: Convert numpy or scalar
  44. # Maybe ToDo: set training phase
  45. # Maybe ToDo: set computing graph
  46. outputs = self.forward(*inputs, **kwargs)
  47. # Maybe ToDo: set connectivity metadata
  48. return outputs
  49. def _flatten(
  50. self,
  51. *,
  52. recursive: bool = True,
  53. with_key: bool = False,
  54. prefix: Optional[str] = None,
  55. predicate: Callable[[Any], bool] = lambda _: True,
  56. seen: Optional[Set[int]] = None
  57. ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
  58. """Scans the module object and returns an iterable for the attributes that
  59. agree with the ``predicate``. For multiple calls of this function with same
  60. arguments, the order of objects within the returned iterable is guaranteed to be
  61. identical, as long as all the involved module objects' ``__dict__`` does not
  62. change thoughout those calls.
  63. :param recursive: Whether to recursively scan all the submodules.
  64. :param with_key: Whether to yield keys along with yielded objects.
  65. :param prefix: The prefix appended to the yielded keys.
  66. :param predicate: The predicate function applied to scanned objects.
  67. :param seen: A dict that records whether a module has been traversed yet.
  68. """
  69. if seen is None:
  70. seen = set([id(self)])
  71. module_dict = vars(self)
  72. _prefix = "" if not prefix else prefix + "."
  73. for key in sorted(module_dict):
  74. for expanded_key, leaf in _expand_structure(key, module_dict[key]):
  75. leaf_id = id(leaf)
  76. if leaf_id in seen:
  77. continue
  78. seen.add(leaf_id)
  79. if predicate(leaf):
  80. if with_key:
  81. yield _prefix + expanded_key, leaf
  82. else:
  83. yield leaf
  84. if recursive and isinstance(leaf, Module):
  85. yield from leaf._flatten(
  86. recursive=recursive,
  87. with_key=with_key,
  88. prefix=None if prefix is None else _prefix + expanded_key,
  89. predicate=predicate,
  90. seen=seen,
  91. )
  92. def parameters(
  93. self, requires_grad: Optional[bool] = None, recursive: bool = True
  94. ) -> Iterable[Parameter]:
  95. r"""Returns an iterable for the :class:`~.Parameter` of the module.
  96. :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
  97. attribute of returned :class:`.Parameter`. ``None`` for
  98. no limitation.
  99. :param recursive: If ``True``, returns all :class:`~.Parameter` within this
  100. module, else only returns :class:`~.Parameter` that are direct
  101. attributes of this module.
  102. """
  103. def predicate(obj) -> bool:
  104. return _is_parameter(obj) and (
  105. requires_grad is None or obj.requires_grad == requires_grad
  106. )
  107. yield from self._flatten(predicate=predicate, recursive=recursive)
  108. def named_parameters(
  109. self,
  110. requires_grad: Optional[bool] = None,
  111. prefix: str = "",
  112. recursive: bool = True,
  113. ) -> Iterable[Tuple[str, Parameter]]:
  114. """Returns an iterable for key :class:`~.Parameter` pairs of the module, where
  115. ``key`` is the dotted path from this module to the :class:`~.Parameter` .
  116. :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
  117. attribute of returned :class:`~.Parameter` . ``None`` for
  118. no limitation.
  119. :param prefix: The prefix prepended to the keys.
  120. :param recursive: If ``True``, returns all :class:`~.Parameter` within this
  121. module, else only returns :class:`~.Parameter` that are direct
  122. attributes of this module.
  123. """
  124. def predicate(obj) -> bool:
  125. return _is_parameter(obj) and (
  126. requires_grad is None or obj.requires_grad == requires_grad
  127. )
  128. yield from self._flatten(
  129. with_key=True, prefix=prefix, predicate=predicate, recursive=recursive
  130. )
  131. def buffers(self, recursive: bool = True) -> Iterable[Buffer]:
  132. """Returns an iterable for the :class:`~.Buffer` of the module.
  133. :param recursive: If ``True``, returns all :class:`~.Buffer` within this
  134. module, else only returns :class:`~.Buffer` that are direct
  135. attributes of this module.
  136. """
  137. yield from self._flatten(predicate=_is_buffer, recursive=recursive)
  138. def replace_param(self,
  139. params: dict,
  140. start_pos: int,
  141. seen: Optional[Set[int]] = None):
  142. offset = 0
  143. if seen is None:
  144. seen = set([id(self)])
  145. module_dict = vars(self)
  146. for key in sorted(module_dict):
  147. hash_id = id(module_dict[key])
  148. if hash_id in seen:
  149. continue
  150. seen.add(hash_id)
  151. if isinstance(module_dict[key], Parameter):
  152. if start_pos + offset in params:
  153. assert module_dict[key].shape == params[start_pos +
  154. offset].shape
  155. module_dict[key] = params[start_pos + offset]
  156. offset += 1
  157. if isinstance(module_dict[key], Module):
  158. offset += module_dict[key].replace_param(params, start_pos + offset, seen)
  159. return offset
  160. def named_buffers(
  161. self, prefix: str = "", recursive: bool = True
  162. ) -> Iterable[Tuple[str, Buffer]]:
  163. """Returns an iterable for key :class:`~.Buffer` pairs of the module, where
  164. ``key`` is the dotted path from this module to the :class:`~.Buffer` .
  165. :param prefix: The prefix prepended to the keys.
  166. :param recursive: If ``True``, returns all :class:`~.Buffer` within this
  167. module, else only returns :class:`~.Buffer` that are direct
  168. attributes of this module.
  169. """
  170. yield from self._flatten(
  171. with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive
  172. )
  173. def children(self) -> "Iterable[Module]":
  174. """Returns an iterable for all the submodules that are direct attributes of this
  175. module.
  176. """
  177. yield from self._flatten(predicate=_is_module, recursive=False)
  178. def named_children(self) -> "Iterable[Tuple[str, Module]]":
  179. """Returns an iterable of key-submodule pairs for all the submodules that are
  180. direct attributes of this module, where 'key' is the attribute name of
  181. submodules.
  182. """
  183. yield from self._flatten(with_key=True, predicate=_is_module, recursive=False)
  184. def modules(self) -> "Iterable[Module]":
  185. """Returns an iterable for all the modules within this module, including itself.
  186. """
  187. yield self
  188. yield from self._flatten(predicate=_is_module)
  189. def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]":
  190. """Returns an iterable of key-module pairs for all the modules within this
  191. module, including itself, where 'key' is the dotted path from this module to the
  192. submodules.
  193. :param prefix: The prefix prepended to the path.
  194. """
  195. yield prefix, self
  196. yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module)
  197. def apply(self, fn: "Callable[[Module], Any]") -> None:
  198. """Apply function ``fn`` to all the modules within this module, including
  199. itself.
  200. :param fn: The function to be applied on modules.
  201. """
  202. for it in self.modules():
  203. fn(it)
  204. def zero_grad(self) -> None:
  205. """Set all parameters' grads to zero
  206. """
  207. for param in self.parameters():
  208. if param.grad is not None:
  209. param.grad.reset_zero()
  210. def train(self, mode: bool = True) -> None:
  211. """Set training mode of all the modules within this module (including itself) to
  212. ``mode``. This effectively sets the ``training`` attributes of those modules
  213. to ``mode``, but only has effect on certain modules (e.g.
  214. :class:`~.BatchNorm2d`, :class:`~.Dropout`)
  215. :param mode: The training mode to be set on modules.
  216. """
  217. self.training = mode
  218. def fn(x) -> None:
  219. x.training = mode
  220. self.apply(fn)
  221. def eval(self) -> None:
  222. """Set training mode of all the modules within this module (including itself) to
  223. ``False``. See :meth:`~.Module.train` for details.
  224. """
  225. self.train(False)
  226. def state_dict(self, rst=None, prefix="", keep_var=False):
  227. r"""Returns a dictionary containing whole states of the module.
  228. """
  229. def is_state(obj):
  230. return _is_parameter(obj) or _is_buffer(obj)
  231. if rst is None:
  232. rst = OrderedDict()
  233. for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
  234. assert prefix + k not in rst, "duplicated state: {}".format(k)
  235. if keep_var:
  236. rst[prefix + k] = v
  237. else:
  238. rst[prefix + k] = v.numpy()
  239. for k, submodule in self._flatten(
  240. recursive=False,
  241. with_key=True,
  242. predicate=lambda obj: isinstance(obj, Module),
  243. ):
  244. submodule.state_dict(rst, prefix + k + ".", keep_var)
  245. return rst
  246. def load_state_dict(
  247. self,
  248. state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
  249. strict=True,
  250. ):
  251. r"""Load a given dictionary created by :func:`state_dict` into this module.
  252. If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match the keys
  253. returned by :func:`state_dict`.
  254. Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]`
  255. as a `state_dict`, in order to handle complex situations. For example, load everything
  256. except for the final linear classifier:
  257. .. code-block::
  258. state_dict = {...} # Dict[str, np.ndarray]
  259. model.load_state_dict({
  260. k: None if k.startswith('fc') else v
  261. for k, v in state_dict.items()
  262. }, strict=False)
  263. Here returning `None` means skipping parameter `k`.
  264. To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before loading:
  265. .. code-block::
  266. state_dict = {...}
  267. def reshape_accordingly(k, v):
  268. return state_dict[k].reshape(v.shape)
  269. model.load_state_dict(reshape_accordingly)
  270. We can also perform inplace re-initialization or pruning:
  271. .. code-block::
  272. def reinit_and_pruning(k, v):
  273. if 'bias' in k:
  274. M.init.zero_(v)
  275. if 'conv' in k:
  276. return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32)
  277. model.load_state_dict(reinit_and_pruning, strict=False)
  278. """
  279. unused = []
  280. if isinstance(state_dict, dict):
  281. unused = state_dict.keys()
  282. def closure(k, _): # var unused
  283. return state_dict[k] if k in state_dict else None
  284. elif callable(state_dict):
  285. closure = state_dict
  286. else:
  287. raise ValueError(
  288. "`state_dict` must load a dict or callable, got {}".format(
  289. type(state_dict)
  290. )
  291. )
  292. loaded, skipped = self._load_state_dict_with_closure(closure)
  293. unused = set(unused) - loaded
  294. if strict and len(unused) != 0:
  295. raise KeyError(
  296. "Unused params violate `strict=True`, unused={}".format(unused)
  297. )
  298. if strict and len(skipped) != 0:
  299. raise KeyError(
  300. "Missing params violate `strict=True`, missing={}".format(skipped)
  301. )
  302. def _load_state_dict_with_closure(self, closure):
  303. """Advance state_dict load through callable `closure` whose signature is
  304. `closure(key: str, var: Tensor) -> Union[np.ndarry, None]`
  305. """
  306. assert callable(closure), "closure must be a function"
  307. loaded = []
  308. skipped = []
  309. local_state_dict = self.state_dict(keep_var=True)
  310. for k, var in local_state_dict.items():
  311. to_be_load = closure(k, var)
  312. if to_be_load is None:
  313. logger.warning("skip loading param `%s`", k)
  314. skipped.append(k)
  315. continue
  316. assert isinstance(
  317. to_be_load, np.ndarray
  318. ), "closure should return a `np.ndarray`, now `{}` get {}".format(
  319. k, to_be_load
  320. )
  321. assert (
  322. var.shape == to_be_load.shape
  323. ), "param `{}` shape mismatch, should be {}, get {}".format(
  324. k, var.shape, to_be_load.shape
  325. )
  326. var.set_value(to_be_load)
  327. loaded.append(k)
  328. return set(loaded), set(skipped)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台

Contributors (1)