diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 0838553f..8fb4c2dd 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -609,14 +609,6 @@ class Module(metaclass=ABCMeta): return set(loaded), set(skipped) - def __getattribute__(self, name: str): - value = super().__getattribute__(name) - if name == "__dict__": - return value - for prefix, variable in _expand_structure(name, value): - variable._name = prefix - return value - def __setattr__(self, name: str, value): is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) if name != "_modules": @@ -631,6 +623,15 @@ class Module(metaclass=ABCMeta): else: if modules is not None and name in modules: modules.remove(name) + for k, v in _expand_structure(name, value): + if not v._name: + v._name = k + else: + logger.warning( + "try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( + v._name, k, v._name + ) + ) super().__setattr__(name, value) def __delattr__(self, name: str): diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index 0ff93a36..235a54d1 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -368,10 +368,10 @@ class AssertModule(Module): def test_assert_message(): - m = AssertModule() with pytest.raises( AssertionError, match="keys for Tensor and Module must be str, error key: True" ): + m = AssertModule() list(m._flatten()) diff --git a/imperative/python/test/unit/utils/test_dump_naming.py b/imperative/python/test/unit/utils/test_dump_naming.py index f154fc82..333a4ddd 100644 --- a/imperative/python/test/unit/utils/test_dump_naming.py +++ b/imperative/python/test/unit/utils/test_dump_naming.py @@ -155,13 +155,13 @@ def test_with_submodule_in_container(symbolic): m = Simple("simple") ops = _dump_and_load(m, symbolic) - assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD" - assert ops[-1].name == "simple.l2.l2-1.ADD" - assert ops[-2].name == "simple.l2.l2-1.MatrixMul" - assert ops[-3].name == "simple.l1.1.ADD" - assert ops[-4].name == "simple.l1.1.MatrixMul" - assert ops[-5].name == "simple.l0.1.ADD" - assert ops[-6].name == "simple.l0.1.MatrixMul" + assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]" + assert ops[-1].name == "simple.l0.1.ADD[2]" + assert ops[-2].name == "simple.l0.1.MatrixMul[2]" + assert ops[-3].name == "simple.l0.1.ADD[1]" + assert ops[-4].name == "simple.l0.1.MatrixMul[1]" + assert ops[-5].name == "simple.l0.1.ADD[0]" + assert ops[-6].name == "simple.l0.1.MatrixMul[0]" @pytest.mark.parametrize("symbolic", [False, True])