GitOrigin-RevId: 5ac525f010
release-1.5
@@ -609,14 +609,6 @@ class Module(metaclass=ABCMeta): | |||||
return set(loaded), set(skipped) | return set(loaded), set(skipped) | ||||
def __getattribute__(self, name: str): | |||||
value = super().__getattribute__(name) | |||||
if name == "__dict__": | |||||
return value | |||||
for prefix, variable in _expand_structure(name, value): | |||||
variable._name = prefix | |||||
return value | |||||
def __setattr__(self, name: str, value): | def __setattr__(self, name: str, value): | ||||
is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) | is_module_like = _is_module(value) or isinstance(value, (list, tuple, dict)) | ||||
if name != "_modules": | if name != "_modules": | ||||
@@ -631,6 +623,15 @@ class Module(metaclass=ABCMeta): | |||||
else: | else: | ||||
if modules is not None and name in modules: | if modules is not None and name in modules: | ||||
modules.remove(name) | modules.remove(name) | ||||
for k, v in _expand_structure(name, value): | |||||
if not v._name: | |||||
v._name = k | |||||
else: | |||||
logger.warning( | |||||
"try setting the submodule `{}` to a new attribute `{}`, its name `{}` will remain unchanged".format( | |||||
v._name, k, v._name | |||||
) | |||||
) | |||||
super().__setattr__(name, value) | super().__setattr__(name, value) | ||||
def __delattr__(self, name: str): | def __delattr__(self, name: str): | ||||
@@ -368,10 +368,10 @@ class AssertModule(Module): | |||||
def test_assert_message(): | def test_assert_message(): | ||||
m = AssertModule() | |||||
with pytest.raises( | with pytest.raises( | ||||
AssertionError, match="keys for Tensor and Module must be str, error key: True" | AssertionError, match="keys for Tensor and Module must be str, error key: True" | ||||
): | ): | ||||
m = AssertModule() | |||||
list(m._flatten()) | list(m._flatten()) | ||||
@@ -155,13 +155,13 @@ def test_with_submodule_in_container(symbolic): | |||||
m = Simple("simple") | m = Simple("simple") | ||||
ops = _dump_and_load(m, symbolic) | ops = _dump_and_load(m, symbolic) | ||||
assert ops[-1].outputs[0].name == "simple.l2.l2-1.ADD" | |||||
assert ops[-1].name == "simple.l2.l2-1.ADD" | |||||
assert ops[-2].name == "simple.l2.l2-1.MatrixMul" | |||||
assert ops[-3].name == "simple.l1.1.ADD" | |||||
assert ops[-4].name == "simple.l1.1.MatrixMul" | |||||
assert ops[-5].name == "simple.l0.1.ADD" | |||||
assert ops[-6].name == "simple.l0.1.MatrixMul" | |||||
assert ops[-1].outputs[0].name == "simple.l0.1.ADD[2]" | |||||
assert ops[-1].name == "simple.l0.1.ADD[2]" | |||||
assert ops[-2].name == "simple.l0.1.MatrixMul[2]" | |||||
assert ops[-3].name == "simple.l0.1.ADD[1]" | |||||
assert ops[-4].name == "simple.l0.1.MatrixMul[1]" | |||||
assert ops[-5].name == "simple.l0.1.ADD[0]" | |||||
assert ops[-6].name == "simple.l0.1.MatrixMul[0]" | |||||
@pytest.mark.parametrize("symbolic", [False, True]) | @pytest.mark.parametrize("symbolic", [False, True]) | ||||