Browse Source

fix(mge/module): fix non-str key error of dict in module

GitOrigin-RevId: f82cd48230
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
d2f5874a52
2 changed files with 44 additions and 10 deletions
  1. +16
    -8
      python_module/megengine/module/module.py
  2. +28
    -2
      python_module/test/unit/module/test_module.py

+ 16
- 8
python_module/megengine/module/module.py View File

@@ -18,17 +18,25 @@ logger = get_logger(__name__)




def _expand_structure(key, obj): def _expand_structure(key, obj):
if isinstance(obj, (list, tuple, dict)):
if isinstance(obj, (Tensor, Module)):
return [(key, obj)]
elif isinstance(obj, (list, tuple, dict)):
ret = [] ret = []
if isinstance(obj, dict): if isinstance(obj, dict):
targets = ((k, obj[k]) for k in sorted(obj)) targets = ((k, obj[k]) for k in sorted(obj))
else: else:
targets = ((str(k), v) for k, v in enumerate(obj)) targets = ((str(k), v) for k, v in enumerate(obj))
for k, o in targets: for k, o in targets:
ret.extend(_expand_structure(key + "." + k, o))
sub_ret = _expand_structure(k, o)
if sub_ret and not isinstance(k, str):
raise AssertionError(
"keys for Tensor and Module must be str, error key: {}".format(k)
)
for kt, vt in sub_ret:
ret.extend([(key + "." + kt, vt)])
return ret return ret
else: else:
return [(key, obj)]
return []




def _is_parameter(obj): def _is_parameter(obj):
@@ -72,11 +80,11 @@ class Module(metaclass=ABCMeta):
predicate: Callable[[Any], bool] = lambda _: True, predicate: Callable[[Any], bool] = lambda _: True,
seen: Optional[Set[int]] = None seen: Optional[Set[int]] = None
) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]: ) -> Union[Iterable[Any], Iterable[Tuple[str, Any]]]:
"""Scans the module object and returns an iterable for the attributes that
agree with the ``predicate``. For multiple calls of this function with same
arguments, the order of objects within the returned iterable is guaranteed to be
identical, as long as all the involved module objects' ``__dict__`` does not
change thoughout those calls.
"""Scans the module object and returns an iterable for the :class:`~.Tensor`
and :class:`~.Module` attributes that agree with the ``predicate``. For multiple
calls of this function with same arguments, the order of objects within the
returned iterable is guaranteed to be identical, as long as all the involved
module objects' ``__dict__`` does not change thoughout those calls.


:param recursive: Whether to recursively scan all the submodules. :param recursive: Whether to recursively scan all the submodules.
:param with_key: Whether to yield keys along with yielded objects. :param with_key: Whether to yield keys along with yielded objects.


+ 28
- 2
python_module/test/unit/module/test_module.py View File

@@ -14,7 +14,7 @@ import pytest
from helpers import MLP from helpers import MLP


import megengine as mge import megengine as mge
from megengine.core import Buffer, Parameter, tensor
from megengine.core import Buffer, Parameter, Tensor, tensor
from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential from megengine.module import BatchNorm1d, BatchNorm2d, Conv2d, Module, Sequential
from megengine.test import assertTensorClose from megengine.test import assertTensorClose


@@ -139,6 +139,7 @@ class MyModule2(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.bn = BatchNorm2d(4) self.bn = BatchNorm2d(4)
self.test_bool_key = {True: 1, False: 0}


def forward(self, x): def forward(self, x):
x = self.bn(x) x = self.bn(x)
@@ -148,7 +149,7 @@ class MyModule2(Module):
self.bn = BatchNorm2d(4) self.bn = BatchNorm2d(4)
self.a = [ self.a = [
BatchNorm2d(4), BatchNorm2d(4),
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()]},
{"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
(self.InnerModule(),), (self.InnerModule(),),
] ]


@@ -171,6 +172,14 @@ def test_expand_structure():
] ]




def test_flatten_others():
def be_others(obj):
return not isinstance(obj, (Tensor, Module))

m = MyModule2()
assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0


def test_flatten_with_parent(): def test_flatten_with_parent():
m = MyModule2() m = MyModule2()
assert list(m.named_modules(with_parent=True)) == [ assert list(m.named_modules(with_parent=True)) == [
@@ -251,6 +260,23 @@ def test_state_dict():
mlp1.load_state_dict(state_dict) mlp1.load_state_dict(state_dict)




class AssertModule(Module):
def __init__(self):
super().__init__()
self.error_tensor_key = {True: tensor(), False: 0}

def forward(self, x):
return x


def test_assert_message():
m = AssertModule()
with pytest.raises(
AssertionError, match="keys for Tensor and Module must be str, error key: True"
):
list(m._flatten())


class Simple(Module): class Simple(Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()


Loading…
Cancel
Save