Browse Source

feat(mge/module): add `with_parent` argument in `_flatten`

GitOrigin-RevId: 1c88559ece
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
814a427810
3 changed files with 130 additions and 43 deletions
  1. +65
    -33
      python_module/megengine/module/module.py
  2. +7
    -8
      python_module/megengine/optimizer/optimizer.py
  3. +58
    -2
      python_module/test/unit/module/test_module.py

+ 65
- 33
python_module/megengine/module/module.py View File

@@ -68,6 +68,7 @@ class Module(metaclass=ABCMeta):
*,
recursive: bool = True,
with_key: bool = False,
with_parent: bool = False,
prefix: Optional[str] = None,
predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None
@@ -80,6 +81,7 @@ class Module(metaclass=ABCMeta):

:param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects.
:param with_parent: Whether to yield ``self`` along with yielded objects.
:param prefix: The prefix appended to the yielded keys.
:param predicate: The predicate function applied to scanned objects.
:param seen: A dict that records whether a module has been traversed yet.
@@ -88,7 +90,7 @@ class Module(metaclass=ABCMeta):
seen = set([id(self)])

module_dict = vars(self)
_prefix = "" if not prefix else prefix + "."
_prefix = "" if prefix is None else prefix + "."

for key in sorted(module_dict):
for expanded_key, leaf in _expand_structure(key, module_dict[key]):
@@ -98,8 +100,12 @@ class Module(metaclass=ABCMeta):
seen.add(leaf_id)

if predicate(leaf):
if with_key:
if with_key and with_parent:
yield _prefix + expanded_key, leaf, self
elif with_key:
yield _prefix + expanded_key, leaf
elif with_parent:
yield leaf, self
else:
yield leaf

@@ -107,22 +113,22 @@ class Module(metaclass=ABCMeta):
yield from leaf._flatten(
recursive=recursive,
with_key=with_key,
prefix=None if prefix is None else _prefix + expanded_key,
with_parent=with_parent,
prefix=_prefix + expanded_key if with_key else None,
predicate=predicate,
seen=seen,
)

def parameters(
self, requires_grad: Optional[bool] = None, recursive: bool = True
self, requires_grad: Optional[bool] = None, recursive: bool = True, **kwargs
) -> Iterable[Parameter]:
r"""Returns an iterable for the :class:`~.Parameter` of the module.

:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`.Parameter`. ``None`` for
no limitation.
attribute of returned :class:`.Parameter`. ``None`` for no limitation.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct
attributes of this module.
module, else only returns :class:`~.Parameter` that are direct attributes
of this module.
"""

def predicate(obj) -> bool:
@@ -130,24 +136,26 @@ class Module(metaclass=ABCMeta):
requires_grad is None or obj.requires_grad == requires_grad
)

yield from self._flatten(predicate=predicate, recursive=recursive)
yield from self._flatten(
with_key=False, predicate=predicate, recursive=recursive, **kwargs
)

def named_parameters(
self,
requires_grad: Optional[bool] = None,
prefix: str = "",
prefix: Optional[str] = None,
recursive: bool = True,
**kwargs
) -> Iterable[Tuple[str, Parameter]]:
"""Returns an iterable for key :class:`~.Parameter` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Parameter` .

:param requires_grad: Limitation over the :attr:`~.Parameter.requires_grad`
attribute of returned :class:`~.Parameter` . ``None`` for
no limitation.
attribute of returned :class:`~.Parameter` . ``None`` for no limitation.
:param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Parameter` within this
module, else only returns :class:`~.Parameter` that are direct
attributes of this module.
module, else only returns :class:`~.Parameter` that are direct attributes
of this module.
"""

def predicate(obj) -> bool:
@@ -156,17 +164,23 @@ class Module(metaclass=ABCMeta):
)

yield from self._flatten(
with_key=True, prefix=prefix, predicate=predicate, recursive=recursive
with_key=True,
prefix=prefix,
predicate=predicate,
recursive=recursive,
**kwargs,
)

def buffers(self, recursive: bool = True) -> Iterable[Buffer]:
def buffers(self, recursive: bool = True, **kwargs) -> Iterable[Buffer]:
"""Returns an iterable for the :class:`~.Buffer` of the module.

:param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct
attributes of this module.
module, else only returns :class:`~.Buffer` that are direct attributes
of this module.
"""
yield from self._flatten(predicate=_is_buffer, recursive=recursive)
yield from self._flatten(
with_key=False, predicate=_is_buffer, recursive=recursive, **kwargs
)

def replace_param(
self, params: dict, start_pos: int, seen: Optional[Set[int]] = None
@@ -192,48 +206,66 @@ class Module(metaclass=ABCMeta):
return offset

def named_buffers(
self, prefix: str = "", recursive: bool = True
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
) -> Iterable[Tuple[str, Buffer]]:
"""Returns an iterable for key :class:`~.Buffer` pairs of the module, where
``key`` is the dotted path from this module to the :class:`~.Buffer` .

:param prefix: The prefix prepended to the keys.
:param recursive: If ``True``, returns all :class:`~.Buffer` within this
module, else only returns :class:`~.Buffer` that are direct
attributes of this module.
module, else only returns :class:`~.Buffer` that are direct attributes
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_buffer, recursive=recursive
with_key=True,
prefix=prefix,
predicate=_is_buffer,
recursive=recursive,
**kwargs,
)

def children(self) -> "Iterable[Module]":
def children(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the submodules that are direct attributes of this
module.
"""
yield from self._flatten(predicate=_is_module, recursive=False)
yield from self._flatten(
with_key=False, predicate=_is_module, recursive=False, **kwargs
)

def named_children(self) -> "Iterable[Tuple[str, Module]]":
def named_children(self, **kwargs) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-submodule pairs for all the submodules that are
direct attributes of this module, where 'key' is the attribute name of
submodules.
"""
yield from self._flatten(with_key=True, predicate=_is_module, recursive=False)
yield from self._flatten(
with_key=True, predicate=_is_module, recursive=False, **kwargs
)

def modules(self) -> "Iterable[Module]":
def modules(self, **kwargs) -> "Iterable[Module]":
"""Returns an iterable for all the modules within this module, including itself.
"""
yield self
yield from self._flatten(predicate=_is_module)
if "with_parent" in kwargs and kwargs["with_parent"]:
yield self, None
else:
yield self
yield from self._flatten(with_key=False, predicate=_is_module, **kwargs)

def named_modules(self, prefix: str = "") -> "Iterable[Tuple[str, Module]]":
def named_modules(
self, prefix: Optional[str] = None, **kwargs
) -> "Iterable[Tuple[str, Module]]":
"""Returns an iterable of key-module pairs for all the modules within this
module, including itself, where 'key' is the dotted path from this module to the
submodules.

:param prefix: The prefix prepended to the path.
"""
yield prefix, self
yield from self._flatten(with_key=True, prefix=prefix, predicate=_is_module)
if "with_parent" in kwargs and kwargs["with_parent"]:
yield ("" if prefix is None else prefix), self, None
else:
yield ("" if prefix is None else prefix), self
yield from self._flatten(
with_key=True, prefix=prefix, predicate=_is_module, **kwargs
)

def apply(self, fn: "Callable[[Module], Any]") -> None:
"""Apply function ``fn`` to all the modules within this module, including


+ 7
- 8
python_module/megengine/optimizer/optimizer.py View File

@@ -53,9 +53,6 @@ class Optimizer(metaclass=ABCMeta):
if isinstance(params, (Parameter, dict)):
params = [params]
else:
assert isinstance(
params, Iterable
), "params argument given to the optimizer should be Parameter or dict"
if not isinstance(params, Iterable):
raise TypeError(
"params argument given to the optimizer should be "
@@ -65,13 +62,15 @@ class Optimizer(metaclass=ABCMeta):
self.param_groups = [] # type: list

param_groups = list(params)
assert len(param_groups) != 0, "optimizer got an empty parameter list"
if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list")

param_type = type(param_groups[0])
for param in param_groups:
assert isinstance(
param, param_type
), "types of params argument given to the optimizer shoud be same"
if not isinstance(param, param_type):
raise TypeError(
"types of params argument given to the optimizer shoud be same"
)

if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
@@ -150,7 +149,7 @@ class Optimizer(metaclass=ABCMeta):
def backward(self, loss: Tensor):
"""Computes the back-propagation of the network given loss.

:param loss: The obtained loss tensor
:param loss: The obtained loss tensor
"""
rst = []
key = 0


+ 58
- 2
python_module/test/unit/module/test_module.py View File

@@ -15,7 +15,7 @@ from helpers import MLP

import megengine as mge
from megengine.core import Buffer, Parameter, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose


@@ -156,7 +156,7 @@ class MyModule2(Module):
return x


def test_mode_api_expand_structure():
def test_expand_structure():
m = MyModule2()
assert list(m.named_modules()) == [
("", m),
@@ -171,6 +171,62 @@ def test_mode_api_expand_structure():
]


def test_flatten_with_parent():
m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [
("", m, None),
("a.0", m.a[0], m),
("a.1.x", m.a[1]["x"], m),
("a.1.y.0", m.a[1]["y"][0], m),
("a.1.y.1", m.a[1]["y"][1], m),
("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
("a.2.0", m.a[2][0], m),
("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
("bn", m.bn, m),
]
assert list(m.modules(with_parent=True)) == [
(m, None),
(m.a[0], m),
(m.a[1]["x"], m),
(m.a[1]["y"][0], m),
(m.a[1]["y"][1], m),
(m.a[1]["y"][1].bn, m.a[1]["y"][1]),
(m.a[2][0], m),
(m.a[2][0].bn, m.a[2][0]),
(m.bn, m),
]


class MyModule3(Module):
class InnerModule(Module):
def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)

def forward(self, x):
x = self.bn(x)

def __init__(self):
super().__init__()
self.bn = BatchNorm2d(4)
self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)

def forward(self, x):
return x


def test_module_api_with_sequential():
m = MyModule3()
assert list(m.named_modules()) == [
("", m),
("bn", m.bn),
("seq", m.seq),
("seq.0", m.seq[0]),
("seq.1", m.seq[1]),
("seq.1.bn", m.seq[1].bn),
]


def test_state_dict():
data_shape = (2, 28)
data = tensor()


Loading…
Cancel
Save