|
|
@@ -625,3 +625,34 @@ def test_repr_module_delete(): |
|
|
|
del net.softmax |
|
|
|
output = net.__repr__() |
|
|
|
assert output == ground_truth |
|
|
|
|
|
|
|
|
|
|
|
def test_repr_module_reset_attr(): |
|
|
|
class ResetAttrModule(Module): |
|
|
|
def __init__(self, flag): |
|
|
|
super().__init__() |
|
|
|
if flag: |
|
|
|
self.a = None |
|
|
|
self.a = Linear(3, 5) |
|
|
|
else: |
|
|
|
self.a = Linear(3, 5) |
|
|
|
self.a = None |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
if self.a: |
|
|
|
x = self.a(x) |
|
|
|
return x |
|
|
|
|
|
|
|
ground_truth = [ |
|
|
|
( |
|
|
|
"ResetAttrModule(\n" |
|
|
|
" (a): Linear(in_features=3, out_features=5, bias=True)\n" |
|
|
|
")" |
|
|
|
), |
|
|
|
("ResetAttrModule()"), |
|
|
|
] |
|
|
|
|
|
|
|
m0 = ResetAttrModule(True) |
|
|
|
m1 = ResetAttrModule(False) |
|
|
|
output = [m0.__repr__(), m1.__repr__()] |
|
|
|
assert output == ground_truth |