Browse Source

fix(imperative/module): remove ``__getattribute__`` method in module

GitOrigin-RevId: 5ac525f010
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
b2944559a8
3 changed files with 17 additions and 16 deletions
  1. +9
    -8
      imperative/python/megengine/module/module.py
  2. +1
    -1
      imperative/python/test/unit/module/test_module.py
  3. +7
    -7
      imperative/python/test/unit/utils/test_dump_naming.py

+ 9
- 8
imperative/python/megengine/module/module.py View File

@@ -609,14 +609,6 @@ class Module(metaclass=ABCMeta):

return set(loaded), set(skipped)

def __getattribute__(self, name: str):
value = super().__getattribute__(name)
if name == "__dict__":
return value
for prefix, variable in _expand_structure(name, value):
variable._name = prefix
return value

def __setattr__(self, name: str, value):
is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict))
if name != "_modules":
@@ -631,6 +623,15 @@ class Module(metaclass=ABCMeta):
else:
if modules is not None and name in modules:
modules.remove(name)
for k, v in _expand_structure(name, value):
if not v._name:
v._name = k
else:
logger.warning(
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format(
v._name, k, v._name
)
)
super().__setattr__(name, value)

def __delattr__(self, name: str):


+ 1
- 1
imperative/python/test/unit/module/test_module.py View File

@@ -368,10 +368,10 @@ class AssertModule(Module):


def test_assert_message():
m = AssertModule()
with pytest.raises(
AssertionError, match="keys for Tensor and Module must be str, error key: True"
):
m = AssertModule()
list(m._flatten())




+ 7
- 7
imperative/python/test/unit/utils/test_dump_naming.py View File

@@ -155,13 +155,13 @@ def test_with_submodule_in_container(symbolic):
m = Simple("simple")

ops = _dump_and_load(m, symbolic)
assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD"
assert ops[-1].name == "simple.l2.l2-1.ADD"
assert ops[-2].name == "simple.l2.l2-1.MatrixMul"
assert ops[-3].name == "simple.l1.1.ADD"
assert ops[-4].name == "simple.l1.1.MatrixMul"
assert ops[-5].name == "simple.l0.1.ADD"
assert ops[-6].name == "simple.l0.1.MatrixMul"
assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]"
assert ops[-1].name == "simple.l0.1.ADD[2]"
assert ops[-2].name == "simple.l0.1.MatrixMul[2]"
assert ops[-3].name == "simple.l0.1.ADD[1]"
assert ops[-4].name == "simple.l0.1.MatrixMul[1]"
assert ops[-5].name == "simple.l0.1.ADD[0]"
assert ops[-6].name == "simple.l0.1.MatrixMul[0]"


@pytest.mark.parametrize("symbolic", [False, True])


Loading…
Cancel
Save