Browse Source

fix(mge/module): fix named_tensors

GitOrigin-RevId: bb5aa1f41d
release-1.7
Megvii Engine Team 3 years ago
parent
commit
f7d05db7d6
2 changed files with 36 additions and 3 deletions
  1. +12
    -3
      imperative/python/megengine/module/module.py
  2. +24
    -0
      imperative/python/test/unit/module/test_module.py

+ 12
- 3
imperative/python/megengine/module/module.py View File

@@ -67,6 +67,10 @@ def _is_parameter(obj):
return isinstance(obj, Parameter)


def _is_tensor(obj):
return isinstance(obj, Tensor)


def _is_buffer(obj):
return isinstance(obj, Tensor) and not isinstance(obj, Parameter)

@@ -309,8 +313,9 @@ class Module(metaclass=ABCMeta):
module, else only returns :class:`~.Tensor` that are direct attributes
of this module.
"""

yield from self._flatten(with_key=False, recursive=recursive, **kwargs)
yield from self._flatten(
with_key=False, predicate=_is_tensor, recursive=recursive, **kwargs
)

def named_tensors(
self, prefix: Optional[str] = None, recursive: bool = True, **kwargs
@@ -325,7 +330,11 @@ class Module(metaclass=ABCMeta):
of this module.
"""
yield from self._flatten(
with_key=True, prefix=prefix, recursive=recursive, **kwargs,
with_key=True,
prefix=prefix,
predicate=_is_tensor,
recursive=recursive,
**kwargs,
)

def children(self, **kwargs) -> "Iterable[Module]":


+ 24
- 0
imperative/python/test/unit/module/test_module.py View File

@@ -124,6 +124,30 @@ def test_module_api(test_traced_module):
("i.bn.weight", m.i.bn.weight),
("param", m.param),
]
assert list(m.tensors()) == [
m.bn.bias,
m.bn.running_mean,
m.bn.running_var,
m.bn.weight,
m.buff,
m.i.bn.bias,
m.i.bn.running_mean,
m.i.bn.running_var,
m.i.bn.weight,
m.param,
]
assert list(m.named_tensors()) == [
("bn.bias", m.bn.bias),
("bn.running_mean", m.bn.running_mean),
("bn.running_var", m.bn.running_var),
("bn.weight", m.bn.weight),
("buff", m.buff),
("i.bn.bias", m.i.bn.bias),
("i.bn.running_mean", m.i.bn.running_mean),
("i.bn.running_var", m.i.bn.running_var),
("i.bn.weight", m.i.bn.weight),
("param", m.param),
]
m.eval()
assert (
m.training == False


Loading…
Cancel
Save