Browse Source

test(mge/module): add module reset attribute test

GitOrigin-RevId: 6c9adc4a70
release-1.4
Megvii Engine Team 4 years ago
parent
commit
9be7cb8924
1 changed files with 31 additions and 0 deletions
  1. +31
    -0
      imperative/python/test/unit/module/test_module.py

+ 31
- 0
imperative/python/test/unit/module/test_module.py View File

@@ -625,3 +625,34 @@ def test_repr_module_delete():
del net.softmax
output = net.__repr__()
assert output == ground_truth


def test_repr_module_reset_attr():
class ResetAttrModule(Module):
def __init__(self, flag):
super().__init__()
if flag:
self.a = None
self.a = Linear(3, 5)
else:
self.a = Linear(3, 5)
self.a = None

def forward(self, x):
if self.a:
x = self.a(x)
return x

ground_truth = [
(
"ResetAttrModule(\n"
" (a): Linear(in_features=3, out_features=5, bias=True)\n"
")"
),
("ResetAttrModule()"),
]

m0 = ResetAttrModule(True)
m1 = ResetAttrModule(False)
output = [m0.__repr__(), m1.__repr__()]
assert output == ground_truth

Loading…
Cancel
Save