GitOrigin-RevId: f82cd48230
tags/v0.5.0
@@ -18,17 +18,25 @@ logger = get_logger(__name__) | |||||
def _expand_structure(key, obj): | def _expand_structure(key, obj): | ||||
if isinstance(obj, (list, tuple, dict)): | |||||
if isinstance(obj, (Tensor, Module)): | |||||
return [(key, obj)] | |||||
elif isinstance(obj, (list, tuple, dict)): | |||||
ret = [] | ret = [] | ||||
if isinstance(obj, dict): | if isinstance(obj, dict): | ||||
targets = ((k, obj[k]) for k in sorted(obj)) | targets = ((k, obj[k]) for k in sorted(obj)) | ||||
else: | else: | ||||
targets = ((str(k), v) for k, v in enumerate(obj)) | targets = ((str(k), v) for k, v in enumerate(obj)) | ||||
for k, o in targets: | for k, o in targets: | ||||
ret.extend(_expand_structure(key + "." + k, o)) | |||||
sub_ret = _expand_structure(k, o) | |||||
if sub_ret and not isinstance(k, str): | |||||
raise AssertionError( | |||||
"keys for Tensor and Module must be str, error key: {}".format(k) | |||||
) | |||||
for kt, vt in sub_ret: | |||||
ret.extend([(key + "." + kt, vt)]) | |||||
return ret | return ret | ||||
else: | else: | ||||
return [(key, obj)] | |||||
return [] | |||||
def _is_parameter(obj): | def _is_parameter(obj): | ||||
@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta): | |||||
predicate: Callable[[Any], bool] = lambda _: True, | predicate: Callable[[Any], bool] = lambda _: True, | ||||
seen: Optional[Set[int]] = None | seen: Optional[Set[int]] = None | ||||
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: | ||||
"""Scans the module object and returns an iterable for the attributes that | |||||
agree with the ``predicate``. For multiple calls of this function with same | |||||
arguments, the order of objects within the returned iterable is guaranteed to be | |||||
identical, as long as all the involved module objects' ``__dict__`` does not | |||||
change thoughout those calls. | |||||
"""Scans the module object and returns an iterable for the :class:`~.Tensor` | |||||
and :class:`~.Module` attributes that agree with the ``predicate``. For multiple | |||||
calls of this function with same arguments, the order of objects within the | |||||
returned iterable is guaranteed to be identical, as long as all the involved | |||||
module objects' ``__dict__`` does not change thoughout those calls. | |||||
:param recursive: Whether to recursively scan all the submodules. | :param recursive: Whether to recursively scan all the submodules. | ||||
:param with_key: Whether to yield keys along with yielded objects. | :param with_key: Whether to yield keys along with yielded objects. | ||||
@@ -14,7 +14,7 @@ import pytest | |||||
from helpers import MLP | from helpers import MLP | ||||
import megengine as mge | import megengine as mge | ||||
from megengine.core import Buffer, Parameter, tensor | |||||
from megengine.core import Buffer, Parameter, Tensor, tensor | |||||
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential | ||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -139,6 +139,7 @@ class MyModule2(Module): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.bn = BatchNorm2d(4) | self.bn = BatchNorm2d(4) | ||||
self.test_bool_key = {True: 1, False: 0} | |||||
def forward(self, x): | def forward(self, x): | ||||
x = self.bn(x) | x = self.bn(x) | ||||
@@ -148,7 +149,7 @@ class MyModule2(Module): | |||||
self.bn = BatchNorm2d(4) | self.bn = BatchNorm2d(4) | ||||
self.a = [ | self.a = [ | ||||
BatchNorm2d(4), | BatchNorm2d(4), | ||||
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()]}, | |||||
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0}, | |||||
(self.InnerModule(),), | (self.InnerModule(),), | ||||
] | ] | ||||
@@ -171,6 +172,14 @@ def test_expand_structure(): | |||||
] | ] | ||||
def test_flatten_others(): | |||||
def be_others(obj): | |||||
return not isinstance(obj, (Tensor, Module)) | |||||
m = MyModule2() | |||||
assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0 | |||||
def test_flatten_with_parent(): | def test_flatten_with_parent(): | ||||
m = MyModule2() | m = MyModule2() | ||||
assert list(m.named_modules(with_parent=True)) == [ | assert list(m.named_modules(with_parent=True)) == [ | ||||
@@ -251,6 +260,23 @@ def test_state_dict(): | |||||
mlp1.load_state_dict(state_dict) | mlp1.load_state_dict(state_dict) | ||||
class AssertModule(Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.error_tensor_key = {True: tensor(), False: 0} | |||||
def forward(self, x): | |||||
return x | |||||
def test_assert_message(): | |||||
m = AssertModule() | |||||
with pytest.raises( | |||||
AssertionError, match="keys for Tensor and Module must be str, error key: True" | |||||
): | |||||
list(m._flatten()) | |||||
class Simple(Module): | class Simple(Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||