diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index f770fa16..89a8edff 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta): *, recursive: bool = True, with_key: bool = False, + with_parent: bool = False, prefix: Optional[str] = None, predicate: Callable[[Any], bool] = lambda _: True, seen: Optional[Set[int]] = None @@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta): :param recursive: Whether to recursively scan all the submodules. :param with_key: Whether to yield keys along with yielded objects. + :param with_parent: Whether to yield ``self`` along with yielded objects. :param prefix: The prefix appended to the yielded keys. :param predicate: The predicate function applied to scanned objects. :param seen: A dict that records whether a module has been traversed yet. @@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta): seen = set([id(self)]) module_dict = vars(self) - _prefix = "" if not prefix else prefix + "." + _prefix = "" if prefix is None else prefix + "." for key in sorted(module_dict): for expanded_key, leaf in _expand_structure(key, module_dict[key]): @@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta): seen.add(leaf_id) if predicate(leaf): - if with_key: + if with_key and with_parent: + yield _prefix + expanded_key, leaf, self + elif with_key: yield _prefix + expanded_key, leaf + elif with_parent: + yield leaf, self else: yield leaf @@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta): yield from leaf._flatten( recursive=recursive, with_key=with_key, - prefix=None if prefix is None else _prefix + expanded_key, + with_parent=with_parent, + prefix=_prefix + expanded_key if with_key else None, predicate=predicate, seen=seen, ) def parameters( - self, requires_grad: Optional[bool] = None, recursive: bool = True + self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs ) -> Iterable[Parameter]: r"""Returns an iterable for the :class:`~.Parameter` of the module. :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` - attribute of returned :class:`.Parameter`. ``None`` for - no limitation. + attribute of returned :class:`.Parameter`. ``None`` for no limitation. :param recursive: If ``True``, returns all :class:`~.Parameter` within this - module, else only returns :class:`~.Parameter` that are direct - attributes of this module. + module, else only returns :class:`~.Parameter` that are direct attributes + of this module. """ def predicate(obj) -> bool: @@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta): requires_grad is None or obj.requires_grad == requires_grad ) - yield from self._flatten(predicate=predicate, recursive=recursive) + yield from self._flatten( + with_key=False, predicate=predicate, recursive=recursive, **kwargs + ) def named_parameters( self, requires_grad: Optional[bool] = None, - prefix: str = "", + prefix: Optional[str] = None, recursive: bool = True, + **kwargs ) -> Iterable[Tuple[str, Parameter]]: """Returns an iterable for key :class:`~.Parameter` pairs of the module, where ``key`` is the dotted path from this module to the :class:`~.Parameter` . :param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad` - attribute of returned :class:`~.Parameter` . ``None`` for - no limitation. + attribute of returned :class:`~.Parameter` . ``None`` for no limitation. :param prefix: The prefix prepended to the keys. :param recursive: If ``True``, returns all :class:`~.Parameter` within this - module, else only returns :class:`~.Parameter` that are direct - attributes of this module. + module, else only returns :class:`~.Parameter` that are direct attributes + of this module. """ def predicate(obj) -> bool: @@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta): ) yield from self._flatten( - with_key=True, prefix=prefix, predicate=predicate, recursive=recursive + with_key=True, + prefix=prefix, + predicate=predicate, + recursive=recursive, + **kwargs, ) - def buffers(self, recursive: bool = True) -> Iterable[Buffer]: + def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]: """Returns an iterable for the :class:`~.Buffer` of the module. :param recursive: If ``True``, returns all :class:`~.Buffer` within this - module, else only returns :class:`~.Buffer` that are direct - attributes of this module. + module, else only returns :class:`~.Buffer` that are direct attributes + of this module. """ - yield from self._flatten(predicate=_is_buffer, recursive=recursive) + yield from self._flatten( + with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs + ) def replace_param( self, params: dict, start_pos: int, seen: Optional[Set[int]] = None @@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta): return offset def named_buffers( - self, prefix: str = "", recursive: bool = True + self, prefix: Optional[str] = None, recursive: bool = True, **kwargs ) -> Iterable[Tuple[str, Buffer]]: """Returns an iterable for key :class:`~.Buffer` pairs of the module, where ``key`` is the dotted path from this module to the :class:`~.Buffer` . :param prefix: The prefix prepended to the keys. :param recursive: If ``True``, returns all :class:`~.Buffer` within this - module, else only returns :class:`~.Buffer` that are direct - attributes of this module. + module, else only returns :class:`~.Buffer` that are direct attributes + of this module. """ yield from self._flatten( - with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive + with_key=True, + prefix=prefix, + predicate=_is_buffer, + recursive=recursive, + **kwargs, ) - def children(self) -> "Iterable[Module]": + def children(self, **kwargs) -> "Iterable[Module]": """Returns an iterable for all the submodules that are direct attributes of this module. """ - yield from self._flatten(predicate=_is_module, recursive=False) + yield from self._flatten( + with_key=False, predicate=_is_module, recursive=False, **kwargs + ) - def named_children(self) -> "Iterable[Tuple[str, Module]]": + def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]": """Returns an iterable of key-submodule pairs for all the submodules that are direct attributes of this module, where 'key' is the attribute name of submodules. """ - yield from self._flatten(with_key=True, predicate=_is_module, recursive=False) + yield from self._flatten( + with_key=True, predicate=_is_module, recursive=False, **kwargs + ) - def modules(self) -> "Iterable[Module]": + def modules(self, **kwargs) -> "Iterable[Module]": """Returns an iterable for all the modules within this module, including itself. """ - yield self - yield from self._flatten(predicate=_is_module) + if "with_parent" in kwargs and kwargs["with_parent"]: + yield self, None + else: + yield self + yield from self._flatten(with_key=False, predicate=_is_module, **kwargs) - def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]": + def named_modules( + self, prefix: Optional[str] = None, **kwargs + ) -> "Iterable[Tuple[str, Module]]": """Returns an iterable of key-module pairs for all the modules within this module, including itself, where 'key' is the dotted path from this module to the submodules. :param prefix: The prefix prepended to the path. """ - yield prefix, self - yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module) + if "with_parent" in kwargs and kwargs["with_parent"]: + yield ("" if prefix is None else prefix), self, None + else: + yield ("" if prefix is None else prefix), self + yield from self._flatten( + with_key=True, prefix=prefix, predicate=_is_module, **kwargs + ) def apply(self, fn: "Callable[[Module], Any]") -> None: """Apply function ``fn`` to all the modules within this module, including diff --git a/python_module/megengine/optimizer/optimizer.py b/python_module/megengine/optimizer/optimizer.py index d559783a..cfbbad97 100644 --- a/python_module/megengine/optimizer/optimizer.py +++ b/python_module/megengine/optimizer/optimizer.py @@ -53,9 +53,6 @@ class Optimizer(metaclass=ABCMeta): if isinstance(params, (Parameter, dict)): params = [params] else: - assert isinstance( - params, Iterable - ), "params argument given to the optimizer should be Parameter or dict" if not isinstance(params, Iterable): raise TypeError( "params argument given to the optimizer should be " @@ -65,13 +62,15 @@ class Optimizer(metaclass=ABCMeta): self.param_groups = [] # type: list param_groups = list(params) - assert len(param_groups) != 0, "optimizer got an empty parameter list" + if len(param_groups) == 0: + raise ValueError("optimizer got an empty parameter list") param_type = type(param_groups[0]) for param in param_groups: - assert isinstance( - param, param_type - ), "types of params argument given to the optimizer shoud be same" + if not isinstance(param, param_type): + raise TypeError( + "types of params argument given to the optimizer shoud be same" + ) if not isinstance(param_groups[0], dict): param_groups = [{"params": param_groups}] @@ -150,7 +149,7 @@ class Optimizer(metaclass=ABCMeta): def backward(self, loss: Tensor): """Computes the back-propagation of the network given loss. - :param loss: The obtained loss tensor + :param loss: The obtained loss tensor """ rst = [] key = 0 diff --git a/python_module/test/unit/module/test_module.py b/python_module/test/unit/module/test_module.py index 3dc2a567..1c72e2dd 100644 --- a/python_module/test/unit/module/test_module.py +++ b/python_module/test/unit/module/test_module.py @@ -15,7 +15,7 @@ from helpers import MLP import megengine as mge from megengine.core import Buffer, Parameter, tensor -from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module +from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.test import assertTensorClose @@ -156,7 +156,7 @@ class MyModule2(Module): return x -def test_mode_api_expand_structure(): +def test_expand_structure(): m = MyModule2() assert list(m.named_modules()) == [ ("", m), @@ -171,6 +171,62 @@ def test_mode_api_expand_structure(): ] +def test_flatten_with_parent(): + m = MyModule2() + assert list(m.named_modules(with_parent=True)) == [ + ("", m, None), + ("a.0", m.a[0], m), + ("a.1.x", m.a[1]["x"], m), + ("a.1.y.0", m.a[1]["y"][0], m), + ("a.1.y.1", m.a[1]["y"][1], m), + ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]), + ("a.2.0", m.a[2][0], m), + ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]), + ("bn", m.bn, m), + ] + assert list(m.modules(with_parent=True)) == [ + (m, None), + (m.a[0], m), + (m.a[1]["x"], m), + (m.a[1]["y"][0], m), + (m.a[1]["y"][1], m), + (m.a[1]["y"][1].bn, m.a[1]["y"][1]), + (m.a[2][0], m), + (m.a[2][0].bn, m.a[2][0]), + (m.bn, m), + ] + + +class MyModule3(Module): + class InnerModule(Module): + def __init__(self): + super().__init__() + self.bn = BatchNorm2d(4) + + def forward(self, x): + x = self.bn(x) + + def __init__(self): + super().__init__() + self.bn = BatchNorm2d(4) + self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),) + + def forward(self, x): + return x + + +def test_module_api_with_sequential(): + m = MyModule3() + assert list(m.named_modules()) == [ + ("", m), + ("bn", m.bn), + ("seq", m.seq), + ("seq.0", m.seq[0]), + ("seq.1", m.seq[1]), + ("seq.1.bn", m.seq[1].bn), + ] + + def test_state_dict(): data_shape = (2, 28) data = tensor()