# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np

from ..core import Buffer, Parameter, Tensor
from ..logger import get_logger

logger = get_logger(__name__)


def _expand_structure(key, obj):
    """Flattens nested lists, tuples and dicts into ``(dotted_key, leaf)`` pairs."""
    if isinstance(obj, (list, tuple, dict)):
        ret = []
        if isinstance(obj, dict):
            targets = ((k, obj[k]) for k in sorted(obj))
        else:
            targets = ((str(k), v) for k, v in enumerate(obj))
        for k, o in targets:
            ret.extend(_expand_structure(key + "." + k, o))
        return ret
    else:
        return [(key, obj)]


def _is_parameter(obj):
    return isinstance(obj, Parameter)


def _is_buffer(obj):
    return isinstance(obj, Buffer)


def _is_module(obj):
    return isinstance(obj, Module)


class Module(metaclass=ABCMeta):
    """Base Module class.
    """

    def __init__(self):
        self.training = True

    @abstractmethod
    def forward(self, inputs):
        pass

    def __call__(self, *inputs, **kwargs):
        # ToDo: Convert numpy or scalar
        # Maybe ToDo: set training phase
        # Maybe ToDo: set computing graph
        outputs = self.forward(*inputs, **kwargs)
        # Maybe ToDo: set connectivity metadata
        return outputs

    def _flatten(
        self,
        *,
        recursive: bool = True,
        with_key: bool = False,
        prefix: Optional[str] = None,
        predicate: Callable[[Any], bool] = lambda _: True,
        seen: Optional[Set[int]] = None
    ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
        """Scans the module object and returns an iterable for the attributes that
        agree with the ``predicate``. For multiple calls of this function with same
        arguments, the order of objects within the returned iterable is guaranteed to be
        identical, as long as all the involved module objects' ``__dict__`` does not
        change throughout those calls.

        :param recursive: Whether to recursively scan all the submodules.
        :param with_key: Whether to yield keys along with yielded objects.
        :param prefix: The prefix appended to the yielded keys.
        :param predicate: The predicate function applied to scanned objects.
        :param seen: A set of object ids recording which objects have been traversed
            already.
        """
        if seen is None:
            seen = set([id(self)])

        module_dict = vars(self)
        _prefix = "" if not prefix else prefix + "."

        for key in sorted(module_dict):
            for expanded_key, leaf in _expand_structure(key, module_dict[key]):
                leaf_id = id(leaf)
                if leaf_id in seen:
                    continue
                seen.add(leaf_id)

                if predicate(leaf):
                    if with_key:
                        yield _prefix + expanded_key, leaf
                    else:
                        yield leaf

                if recursive and isinstance(leaf, Module):
                    yield from leaf._flatten(
                        recursive=recursive,
                        with_key=with_key,
                        prefix=None if prefix is None else _prefix + expanded_key,
                        predicate=predicate,
                        seen=seen,
                    )

    def parameters(
        self, requires_grad: Optional[bool] = None, recursive: bool = True
    ) -> Iterable[Parameter]:
        r"""Returns an iterable for the :class:`~.Parameter` of the module.

        :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
            attribute of returned :class:`.Parameter`. ``None`` for no limitation.
        :param recursive: If ``True``, returns all :class:`~.Parameter` within this
            module, else only returns :class:`~.Parameter` that are direct attributes
            of this module.
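
        A minimal usage sketch follows (``net`` stands for any concrete
        :class:`Module` instance; the name is illustrative):

        .. code-block::

            # gather only the trainable parameters, e.g. to hand to an optimizer
            trainable = list(net.parameters(requires_grad=True))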
""" def predicate(obj) -> bool: return _is_parameter(obj) and ( requires_grad is None or obj.requires_grad == requires_grad ) yield from self._flatten(predicate=predicate, recursive=recursive) def named_parameters( self, requires_grad: Optional[bool] = None, prefix: str = "", recursive: bool = True, ) -> Iterable[Tuple[str, Parameter]]: """Returns an iterable for key :class:`~.Parameter` pairs of the module, where ``key`` is the dotted path from this module to the :class:`~.Parameter` . :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` attribute of returned :class:`~.Parameter` . ``None`` for no limitation. :param prefix: The prefix prepended to the keys. :param recursive: If ``True``, returns all :class:`~.Parameter` within this module, else only returns :class:`~.Parameter` that are direct attributes of this module. """ def predicate(obj) -> bool: return _is_parameter(obj) and ( requires_grad is None or obj.requires_grad == requires_grad ) yield from self._flatten( with_key=True, prefix=prefix, predicate=predicate, recursive=recursive ) def buffers(self, recursive: bool = True) -> Iterable[Buffer]: """Returns an iterable for the :class:`~.Buffer` of the module. :param recursive: If ``True``, returns all :class:`~.Buffer` within this module, else only returns :class:`~.Buffer` that are direct attributes of this module. """ yield from self._flatten(predicate=_is_buffer, recursive=recursive) def replace_param(self, params: dict, start_pos: int, seen: Optional[Set[int]] = None): offset = 0 if seen is None: seen = set([id(self)]) module_dict = vars(self) for key in sorted(module_dict): hash_id = id(module_dict[key]) if hash_id in seen: continue seen.add(hash_id) if isinstance(module_dict[key], Parameter): if start_pos + offset in params: assert module_dict[key].shape == params[start_pos + offset].shape module_dict[key] = params[start_pos + offset] offset += 1 if isinstance(module_dict[key], Module): offset += module_dict[key].replace_param(params, start_pos + offset, seen) return offset def named_buffers( self, prefix: str = "", recursive: bool = True ) -> Iterable[Tuple[str, Buffer]]: """Returns an iterable for key :class:`~.Buffer` pairs of the module, where ``key`` is the dotted path from this module to the :class:`~.Buffer` . :param prefix: The prefix prepended to the keys. :param recursive: If ``True``, returns all :class:`~.Buffer` within this module, else only returns :class:`~.Buffer` that are direct attributes of this module. """ yield from self._flatten( with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive ) def children(self) -> "Iterable[Module]": """Returns an iterable for all the submodules that are direct attributes of this module. """ yield from self._flatten(predicate=_is_module, recursive=False) def named_children(self) -> "Iterable[Tuple[str, Module]]": """Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where 'key' is the attribute name of submodules. """ yield from self._flatten(with_key=True, predicate=_is_module, recursive=False) def modules(self) -> "Iterable[Module]": """Returns an iterable for all the modules within this module, including itself. """ yield self yield from self._flatten(predicate=_is_module) def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]": """Returns an iterable of key-module pairs for all the modules within this module, including itself, where 'key' is the dotted path from this module to the submodules. 
        """
        yield prefix, self
        yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module)

    def apply(self, fn: "Callable[[Module], Any]") -> None:
        """Apply function ``fn`` to all the modules within this module, including
        itself.

        :param fn: The function to be applied on modules.
        """
        for it in self.modules():
            fn(it)

    def zero_grad(self) -> None:
        """Set all parameters' grads to zero.
        """
        for param in self.parameters():
            if param.grad is not None:
                param.grad.reset_zero()

    def train(self, mode: bool = True) -> None:
        """Set training mode of all the modules within this module (including itself) to
        ``mode``. This effectively sets the ``training`` attributes of those modules
        to ``mode``, but only has effect on certain modules (e.g.
        :class:`~.BatchNorm2d`, :class:`~.Dropout`).

        :param mode: The training mode to be set on modules.
        """
        self.training = mode

        def fn(x) -> None:
            x.training = mode

        self.apply(fn)

    def eval(self) -> None:
        """Set training mode of all the modules within this module (including itself) to
        ``False``. See :meth:`~.Module.train` for details.
        """
        self.train(False)

    def state_dict(self, rst=None, prefix="", keep_var=False):
        r"""Returns a dictionary containing whole states of the module.
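
        For instance, persisting a checkpoint could look like this (a sketch;
        ``net`` and the file name are illustrative):

        .. code-block::

            states = net.state_dict()  # Dict[str, np.ndarray] when keep_var is False
            np.savez("checkpoint.npz", **states)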
        """

        def is_state(obj):
            return _is_parameter(obj) or _is_buffer(obj)

        if rst is None:
            rst = OrderedDict()

        for k, v in self._flatten(recursive=False, with_key=True, predicate=is_state):
            assert prefix + k not in rst, "duplicated state: {}".format(k)
            if keep_var:
                rst[prefix + k] = v
            else:
                rst[prefix + k] = v.numpy()

        for k, submodule in self._flatten(
            recursive=False,
            with_key=True,
            predicate=lambda obj: isinstance(obj, Module),
        ):
            submodule.state_dict(rst, prefix + k + ".", keep_var)

        return rst

    def load_state_dict(
        self,
        state_dict: Union[dict, Callable[[str, Tensor], Optional[np.ndarray]]],
        strict=True,
    ):
        r"""Load a given dictionary created by :func:`state_dict` into this module.
        If ``strict`` is ``True``, the keys of :func:`state_dict` must exactly match
        the keys returned by :func:`state_dict`.

        Users can also pass a closure: `Function[key: str, var: Tensor] -> Optional[np.ndarray]`
        as a `state_dict`, in order to handle complex situations. For example, load
        everything except for the final linear classifier:

        .. code-block::

            state_dict = {...}  # Dict[str, np.ndarray]
            model.load_state_dict({
                k: None if k.startswith('fc') else v
                for k, v in state_dict.items()
            }, strict=False)

        Here returning `None` means skipping parameter `k`.

        To prevent shape mismatch (e.g. load PyTorch weights), we can reshape before
        loading:

        .. code-block::

            state_dict = {...}
            def reshape_accordingly(k, v):
                return state_dict[k].reshape(v.shape)
            model.load_state_dict(reshape_accordingly)

        We can also perform inplace re-initialization or pruning:

        .. code-block::

            def reinit_and_pruning(k, v):
                if 'bias' in k:
                    M.init.zero_(v)
                if 'conv' in k:
                    return v.numpy() * (np.abs(v.numpy()) > 1e-3).astype("float32")
            model.load_state_dict(reinit_and_pruning, strict=False)
        """
        unused = []
        if isinstance(state_dict, dict):
            unused = state_dict.keys()

            def closure(k, _):  # var unused
                return state_dict[k] if k in state_dict else None

        elif callable(state_dict):
            closure = state_dict
        else:
            raise ValueError(
                "`state_dict` must be a dict or callable, got {}".format(
                    type(state_dict)
                )
            )

        loaded, skipped = self._load_state_dict_with_closure(closure)

        unused = set(unused) - loaded
        if strict and len(unused) != 0:
            raise KeyError(
                "Unused params violate `strict=True`, unused={}".format(unused)
            )
        if strict and len(skipped) != 0:
            raise KeyError(
                "Missing params violate `strict=True`, missing={}".format(skipped)
            )

    def _load_state_dict_with_closure(self, closure):
        """Advance state_dict load through callable ``closure`` whose signature is
        ``closure(key: str, var: Tensor) -> Union[np.ndarray, None]``.
        """
        assert callable(closure), "closure must be a function"

        loaded = []
        skipped = []
        local_state_dict = self.state_dict(keep_var=True)
        for k, var in local_state_dict.items():
            to_be_load = closure(k, var)
            if to_be_load is None:
                logger.warning("skip loading param `%s`", k)
                skipped.append(k)
                continue
            assert isinstance(
                to_be_load, np.ndarray
            ), "closure should return a `np.ndarray`, but for `{}` got {}".format(
                k, to_be_load
            )
            assert (
                var.shape == to_be_load.shape
            ), "param `{}` shape mismatch, expected {}, got {}".format(
                k, var.shape, to_be_load.shape
            )
            var.set_value(to_be_load)
            loaded.append(k)

        return set(loaded), set(skipped)
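

# A minimal end-to-end sketch (illustrative only, not part of the library;
# ``Linear`` and ``relu`` stand in for the usual megengine layer and functional ops):
#
#     class TwoLayerNet(Module):
#         def __init__(self):
#             super().__init__()
#             self.fc0 = Linear(784, 100)
#             self.fc1 = Linear(100, 10)
#
#         def forward(self, x):
#             return self.fc1(relu(self.fc0(x)))
#
#     net = TwoLayerNet()
#     net.eval()                                 # switch Dropout/BatchNorm behaviour
#     np.savez("ckpt.npz", **net.state_dict())   # persist states as np.ndarrays
#     net.load_state_dict(dict(np.load("ckpt.npz")))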