You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_module.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700
  1. # -*- coding: utf-8 -*-
  2. from collections import OrderedDict
  3. from io import BytesIO
  4. import numpy as np
  5. import pytest
  6. import megengine as mge
  7. import megengine.functional as F
  8. from megengine import Parameter, Tensor, tensor
  9. from megengine.module import (
  10. BatchNorm1d,
  11. BatchNorm2d,
  12. Conv1d,
  13. Conv2d,
  14. Dropout,
  15. Linear,
  16. MaxPool2d,
  17. Module,
  18. Sequential,
  19. Softmax,
  20. )
  21. from megengine.module.module import _access_structure
  22. from megengine.quantization.quantize import quantize, quantize_qat
  23. from megengine.traced_module import TracedModule, trace_module
  24. from megengine.utils.module_utils import get_expand_structure, set_expand_structure
  25. class MLP(Module):
  26. def __init__(self):
  27. super().__init__()
  28. self.dense0 = Linear(28, 50)
  29. self.dense1 = Linear(50, 20)
  30. def forward(self, x):
  31. x = self.dense0(x)
  32. x = F.relu(x)
  33. x = self.dense1(x)
  34. return x
  35. class MyModule(Module):
  36. class InnerModule(Module):
  37. def __init__(self):
  38. super().__init__()
  39. self.bn = BatchNorm2d(4)
  40. def forward(self, x):
  41. return self.bn(x)
  42. def __init__(self):
  43. super().__init__()
  44. self.i = self.InnerModule()
  45. self.bn = BatchNorm2d(4)
  46. self.param = Parameter(np.ones(1, dtype=np.float32))
  47. self.buff = Tensor(np.ones(1, dtype=np.float32))
  48. def forward(self, x):
  49. x = self.i(x)
  50. x = self.bn(x)
  51. return x
  52. @pytest.mark.parametrize("test_traced_module", [True, False])
  53. def test_module_api(test_traced_module):
  54. m = MyModule()
  55. if test_traced_module:
  56. buff = m.buff
  57. param = m.param
  58. m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
  59. assert "buff" not in m.__dict__
  60. assert "param" not in m.__dict__
  61. m.buff = buff
  62. m.param = param
  63. assert list(m.children()) == [m.bn, m.i]
  64. assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
  65. assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
  66. assert list(m.named_modules()) == [
  67. ("", m),
  68. ("bn", m.bn),
  69. ("i", m.i),
  70. ("i.bn", m.i.bn),
  71. ]
  72. assert list(m.named_modules(prefix="x")) == [
  73. ("x", m),
  74. ("x.bn", m.bn),
  75. ("x.i", m.i),
  76. ("x.i.bn", m.i.bn),
  77. ]
  78. assert list(m.buffers()) == [
  79. m.bn.running_mean,
  80. m.bn.running_var,
  81. m.buff,
  82. m.i.bn.running_mean,
  83. m.i.bn.running_var,
  84. ]
  85. assert list(m.buffers(recursive=False)) == [m.buff]
  86. assert list(m.named_buffers()) == [
  87. ("bn.running_mean", m.bn.running_mean),
  88. ("bn.running_var", m.bn.running_var),
  89. ("buff", m.buff),
  90. ("i.bn.running_mean", m.i.bn.running_mean),
  91. ("i.bn.running_var", m.i.bn.running_var),
  92. ]
  93. assert list(m.parameters()) == [
  94. m.bn.bias,
  95. m.bn.weight,
  96. m.i.bn.bias,
  97. m.i.bn.weight,
  98. m.param,
  99. ]
  100. assert list(m.named_parameters()) == [
  101. ("bn.bias", m.bn.bias),
  102. ("bn.weight", m.bn.weight),
  103. ("i.bn.bias", m.i.bn.bias),
  104. ("i.bn.weight", m.i.bn.weight),
  105. ("param", m.param),
  106. ]
  107. assert list(m.tensors()) == [
  108. m.bn.bias,
  109. m.bn.running_mean,
  110. m.bn.running_var,
  111. m.bn.weight,
  112. m.buff,
  113. m.i.bn.bias,
  114. m.i.bn.running_mean,
  115. m.i.bn.running_var,
  116. m.i.bn.weight,
  117. m.param,
  118. ]
  119. assert list(m.named_tensors()) == [
  120. ("bn.bias", m.bn.bias),
  121. ("bn.running_mean", m.bn.running_mean),
  122. ("bn.running_var", m.bn.running_var),
  123. ("bn.weight", m.bn.weight),
  124. ("buff", m.buff),
  125. ("i.bn.bias", m.i.bn.bias),
  126. ("i.bn.running_mean", m.i.bn.running_mean),
  127. ("i.bn.running_var", m.i.bn.running_var),
  128. ("i.bn.weight", m.i.bn.weight),
  129. ("param", m.param),
  130. ]
  131. m.eval()
  132. assert (
  133. m.training == False
  134. and m.bn.training == False
  135. and m.i.training == False
  136. and m.i.bn.training == False
  137. )
  138. m.bn.train()
  139. assert m.training == False and m.bn.training == True and m.i.bn.training == False
  140. m.eval()
  141. m.i.train()
  142. assert (
  143. m.training == False
  144. and m.bn.training == False
  145. and m.i.training == True
  146. and m.i.bn.training == True
  147. )
  148. m.eval()
  149. m.train()
  150. assert m.training == True and m.bn.training == True and m.i.bn.training == True
  151. def fn(m):
  152. m.training = False
  153. m.apply(fn)
  154. assert m.bn.training == False and m.i.bn.training == False
  155. @pytest.mark.parametrize("test_traced_module", [True, False])
  156. def test_module_api_reuse_submodule(test_traced_module):
  157. m = MyModule()
  158. if test_traced_module:
  159. m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
  160. m.h = m.i # pylint: disable=attribute-defined-outside-init
  161. assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
  162. assert list(m.named_modules()) == [
  163. ("", m),
  164. ("bn", m.bn),
  165. ("h", m.i),
  166. ("h.bn", m.i.bn),
  167. ]
  168. @pytest.mark.parametrize("test_traced_module", [True, False])
  169. def test_module_api_iterable_stability(test_traced_module):
  170. m = MyModule()
  171. if test_traced_module:
  172. m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
  173. l = list(m.modules())
  174. for _ in range(100):
  175. assert list(m.modules()) == l
  176. @pytest.mark.parametrize("test_traced_module", [True, False])
  177. def test_module_api_hooks(test_traced_module):
  178. net = MyModule()
  179. if test_traced_module:
  180. net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
  181. pre_hook_num = 0
  182. post_hook_num = 0
  183. hooks = []
  184. def pre_hook(_, inputs):
  185. nonlocal pre_hook_num
  186. pre_hook_num += 1
  187. modified_inputs = tuple(inp + 1 for inp in inputs)
  188. return modified_inputs
  189. def post_hook(_, __, outputs):
  190. nonlocal post_hook_num
  191. post_hook_num += 1
  192. outputs += 1
  193. return outputs
  194. net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
  195. net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))
  196. shape = (1, 4, 1, 1)
  197. x = tensor(np.zeros(shape, dtype=np.float32))
  198. y = net(x)
  199. assert pre_hook_num == 4
  200. assert post_hook_num == 4
  201. mean1 = Parameter(np.zeros(shape), dtype=np.float32)
  202. bn1 = F.batch_norm(
  203. x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
  204. )
  205. np.testing.assert_allclose(
  206. net.i.bn.running_mean.numpy(), mean1.numpy(),
  207. )
  208. mean2 = Parameter(np.zeros(shape), dtype=np.float32)
  209. bn2 = F.batch_norm(
  210. bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
  211. )
  212. np.testing.assert_allclose(
  213. net.bn.running_mean.numpy(), mean2.numpy(),
  214. )
  215. np.testing.assert_allclose((bn2 + 2).numpy(), y.numpy())
  216. assert len(hooks) == 8
  217. for handler in hooks:
  218. handler.remove()
  219. y = net(x)
  220. assert pre_hook_num == 4
  221. assert post_hook_num == 4
  222. class MyModule2(Module):
  223. class InnerModule(Module):
  224. def __init__(self):
  225. super().__init__()
  226. self.bn = BatchNorm2d(4)
  227. self.test_bool_key = {True: 1, False: 0}
  228. def forward(self, x):
  229. x = self.bn(x)
  230. def __init__(self):
  231. super().__init__()
  232. self.bn = BatchNorm2d(4)
  233. self.a = [
  234. BatchNorm2d(4),
  235. {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
  236. (self.InnerModule(),),
  237. ]
  238. def forward(self, x):
  239. return x
  240. def test_expand_structure():
  241. m = MyModule2()
  242. rst = [
  243. ("", m),
  244. ("a.0", m.a[0]),
  245. ("a.1.x", m.a[1]["x"]),
  246. ("a.1.y.0", m.a[1]["y"][0]),
  247. ("a.1.y.1", m.a[1]["y"][1]),
  248. ("a.1.y.1.bn", m.a[1]["y"][1].bn),
  249. ("a.2.0", m.a[2][0]),
  250. ("a.2.0.bn", m.a[2][0].bn),
  251. ("bn", m.bn),
  252. ]
  253. assert list(m.named_modules()) == rst
  254. for item in rst[1:]:
  255. assert get_expand_structure(m, item[0]) == item[1]
  256. for item in reversed(rst[1:]):
  257. if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)):
  258. continue
  259. set_expand_structure(m, item[0], "TEST_VALUE")
  260. assert get_expand_structure(m, item[0]) == "TEST_VALUE"
  261. def test_flatten_others():
  262. def be_others(obj):
  263. return not isinstance(obj, (Tensor, Module))
  264. m = MyModule2()
  265. assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0
  266. def test_flatten_with_parent():
  267. m = MyModule2()
  268. assert list(m.named_modules(with_parent=True)) == [
  269. ("", m, None),
  270. ("a.0", m.a[0], m),
  271. ("a.1.x", m.a[1]["x"], m),
  272. ("a.1.y.0", m.a[1]["y"][0], m),
  273. ("a.1.y.1", m.a[1]["y"][1], m),
  274. ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
  275. ("a.2.0", m.a[2][0], m),
  276. ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
  277. ("bn", m.bn, m),
  278. ]
  279. assert list(m.modules(with_parent=True)) == [
  280. (m, None),
  281. (m.a[0], m),
  282. (m.a[1]["x"], m),
  283. (m.a[1]["y"][0], m),
  284. (m.a[1]["y"][1], m),
  285. (m.a[1]["y"][1].bn, m.a[1]["y"][1]),
  286. (m.a[2][0], m),
  287. (m.a[2][0].bn, m.a[2][0]),
  288. (m.bn, m),
  289. ]
  290. class MyModule3(Module):
  291. class InnerModule(Module):
  292. def __init__(self):
  293. super().__init__()
  294. self.bn = BatchNorm2d(4)
  295. def forward(self, x):
  296. x = self.bn(x)
  297. def __init__(self):
  298. super().__init__()
  299. self.bn = BatchNorm2d(4)
  300. self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)
  301. def forward(self, x):
  302. return x
  303. def test_module_api_with_sequential():
  304. m = MyModule3()
  305. assert list(m.named_modules()) == [
  306. ("", m),
  307. ("bn", m.bn),
  308. ("seq", m.seq),
  309. ("seq.0", m.seq[0]),
  310. ("seq.1", m.seq[1]),
  311. ("seq.1.bn", m.seq[1].bn),
  312. ]
  313. def test_sequential_named_children():
  314. modules = OrderedDict()
  315. modules["name0"] = Linear(20, 10)
  316. modules["name1"] = Linear(10, 5)
  317. modules["name2"] = Linear(5, 1)
  318. m = Sequential(modules)
  319. l = list(m.named_children())
  320. assert l[0][0] == "name0"
  321. assert l[1][0] == "name1"
  322. assert l[2][0] == "name2"
  323. def test_state_dict():
  324. data_shape = (2, 28)
  325. data = tensor(np.random.random(data_shape))
  326. mlp = MLP()
  327. pred0 = mlp(data)
  328. with BytesIO() as fout:
  329. mge.save(mlp.state_dict(), fout)
  330. fout.seek(0)
  331. state_dict = mge.load(fout)
  332. state_dict["extra"] = None
  333. mlp1 = MLP()
  334. mlp1.load_state_dict(state_dict, strict=False)
  335. pred1 = mlp1(data)
  336. np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
  337. with pytest.raises(KeyError):
  338. mlp1.load_state_dict(state_dict)
  339. del state_dict["extra"]
  340. del state_dict["dense0.bias"]
  341. with pytest.raises(KeyError):
  342. mlp1.load_state_dict(state_dict)
  343. class AssertModule(Module):
  344. def __init__(self):
  345. super().__init__()
  346. self.error_tensor_key = {True: tensor([]), False: 0}
  347. def forward(self, x):
  348. return x
  349. def test_assert_message():
  350. with pytest.raises(
  351. AssertionError, match="keys for Tensor and Module must be str, error key: True"
  352. ):
  353. m = AssertModule()
  354. list(m._flatten())
  355. class Simple(Module):
  356. def __init__(self):
  357. super().__init__()
  358. self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
  359. self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
  360. self.conv1.weight = self.conv0.weight
  361. def forward(self, inputs):
  362. x = self.conv0(inputs)
  363. y = self.conv1(inputs)
  364. return x + y
  365. @pytest.mark.parametrize("test_traced_module", [True, False])
  366. def test_shared_param(test_traced_module):
  367. net = Simple()
  368. if test_traced_module:
  369. net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
  370. assert net.conv0.weight is net.conv1.weight
  371. data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
  372. np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
  373. with BytesIO() as f:
  374. mge.save(net, f)
  375. f.seek(0)
  376. net1 = mge.load(f)
  377. assert net1.conv0.weight is net1.conv1.weight
  378. np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
  379. with BytesIO() as f:
  380. mge.save(net.conv0, f)
  381. f.seek(0)
  382. conv0 = mge.load(f)
  383. with BytesIO() as f:
  384. mge.save(net.conv1, f)
  385. f.seek(0)
  386. conv1 = mge.load(f)
  387. assert conv0.weight is not conv1.weight
  388. np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
  389. class Simple2(Module):
  390. def __init__(self):
  391. super().__init__()
  392. self.conv1 = Conv1d(1, 1, kernel_size=3, bias=False)
  393. self.conv0 = Conv1d(1, 1, kernel_size=3, bias=False)
  394. self.conv1.weight = self.conv0.weight
  395. def forward(self, inputs):
  396. pass
  397. def test_shared_param_1d():
  398. net = Simple2()
  399. assert net.conv0.weight is net.conv1.weight
  400. data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
  401. np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
  402. with BytesIO() as f:
  403. mge.save(net, f)
  404. f.seek(0)
  405. net1 = mge.load(f)
  406. assert net1.conv0.weight is net1.conv1.weight
  407. np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
  408. with BytesIO() as f:
  409. mge.save(net.conv0, f)
  410. f.seek(0)
  411. conv0 = mge.load(f)
  412. with BytesIO() as f:
  413. mge.save(net.conv1, f)
  414. f.seek(0)
  415. conv1 = mge.load(f)
  416. assert conv0.weight is not conv1.weight
  417. np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
  418. @pytest.mark.parametrize("test_traced_module", [True, False])
  419. def test_pickle_module(test_traced_module):
  420. data_shape = (2, 28)
  421. data = tensor(np.random.random(data_shape))
  422. mlp = MLP()
  423. pred_gt = mlp(data)
  424. if test_traced_module:
  425. mlp = trace_module(mlp, data)
  426. # pickle before forward
  427. with BytesIO() as fout:
  428. mge.save(mlp, fout)
  429. fout.seek(0)
  430. mlp1 = mge.load(fout)
  431. if test_traced_module:
  432. assert type(mlp1) == TracedModule
  433. pred0 = mlp1(data)
  434. pred1 = mlp(data)
  435. # pickle after forward
  436. with BytesIO() as fout:
  437. mge.save(mlp, fout)
  438. fout.seek(0)
  439. mlp1 = mge.load(fout)
  440. if test_traced_module:
  441. assert type(mlp1) == TracedModule
  442. pred2 = mlp1(data)
  443. np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
  444. np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
  445. np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
  446. def test_repr_basic():
  447. # test whether __repr__ can output correct information
  448. class ConvModel(Module):
  449. def __init__(self):
  450. super().__init__()
  451. self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False)
  452. self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False)
  453. self.bn1 = BatchNorm1d(128)
  454. self.bn2 = BatchNorm2d(128)
  455. self.pooling = MaxPool2d(kernel_size=2, padding=0)
  456. modules = OrderedDict()
  457. modules["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,)
  458. modules["pointwise"] = Conv2d(
  459. 256, 256, kernel_size=1, stride=1, padding=0, bias=True,
  460. )
  461. self.submodule1 = Sequential(modules)
  462. self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]]
  463. self.tuple1 = (
  464. Dropout(drop_prob=0.1),
  465. (Softmax(axis=100), Dropout(drop_prob=0.2)),
  466. )
  467. self.dict1 = {"Dropout": Dropout(drop_prob=0.1)}
  468. self.fc1 = Linear(512, 1024)
  469. def forward(self, inputs):
  470. pass
  471. ground_truth = (
  472. "ConvModel(\n"
  473. " (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
  474. " (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
  475. " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
  476. " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
  477. " (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
  478. " (submodule1): Sequential(\n"
  479. " (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n"
  480. " (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n"
  481. " )\n"
  482. " (list1.0): Dropout(drop_prob=0.1)\n"
  483. " (list1.1.0): Softmax(axis=100)\n"
  484. " (tuple1.0): Dropout(drop_prob=0.1)\n"
  485. " (tuple1.1.0): Softmax(axis=100)\n"
  486. " (tuple1.1.1): Dropout(drop_prob=0.2)\n"
  487. " (dict1.Dropout): Dropout(drop_prob=0.1)\n"
  488. " (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
  489. ")"
  490. )
  491. net = ConvModel()
  492. output = net.__repr__()
  493. assert output == ground_truth
  494. def test_repr_module_reassign():
  495. # test whether __repr__ can deal with module reassign
  496. class ConvModel1(Module):
  497. def __init__(self):
  498. super().__init__()
  499. self.conv1 = Conv2d(3, 128, 3, bias=False)
  500. self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
  501. self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False)
  502. def forward(self, inputs):
  503. pass
  504. ground_truth = (
  505. "ConvModel1(\n"
  506. " (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
  507. " (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
  508. ")"
  509. )
  510. net = ConvModel1()
  511. output = net.__repr__()
  512. assert output == ground_truth
  513. def test_repr_module_rereference():
  514. # test whether __repr__ can deal with module re-reference
  515. class ConvModel2(Module):
  516. def __init__(self):
  517. super().__init__()
  518. self.conv1 = Conv2d(3, 128, 3, bias=False)
  519. self.conv2 = self.conv1
  520. self.conv3 = self.conv1
  521. def forward(self, inputs):
  522. pass
  523. ground_truth = (
  524. "ConvModel2(\n"
  525. " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
  526. " (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
  527. " (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
  528. ")"
  529. )
  530. net = ConvModel2()
  531. output = net.__repr__()
  532. assert output == ground_truth
  533. def test_repr_module_delete():
  534. # test whether __repr__ can deal with module delete
  535. class ConvModel3(Module):
  536. def __init__(self):
  537. super().__init__()
  538. self.conv1 = Conv2d(3, 128, 3, bias=False)
  539. self.softmax = Softmax(100)
  540. def forward(self, inputs):
  541. pass
  542. ground_truth = (
  543. "ConvModel3(\n"
  544. " (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
  545. ")"
  546. )
  547. net = ConvModel3()
  548. del net.softmax
  549. output = net.__repr__()
  550. assert output == ground_truth
  551. def test_repr_module_reset_attr():
  552. class ResetAttrModule(Module):
  553. def __init__(self, flag):
  554. super().__init__()
  555. if flag:
  556. self.a = None
  557. self.a = Linear(3, 5)
  558. else:
  559. self.a = Linear(3, 5)
  560. self.a = None
  561. def forward(self, x):
  562. if self.a:
  563. x = self.a(x)
  564. return x
  565. ground_truth = [
  566. (
  567. "ResetAttrModule(\n"
  568. " (a): Linear(in_features=3, out_features=5, bias=True)\n"
  569. ")"
  570. ),
  571. ("ResetAttrModule()"),
  572. ]
  573. m0 = ResetAttrModule(True)
  574. m1 = ResetAttrModule(False)
  575. output = [m0.__repr__(), m1.__repr__()]
  576. assert output == ground_truth
  577. def test_module_compatible():
  578. class Empty(Module):
  579. def forward(self):
  580. pass
  581. empty_module = Empty()
  582. old_attributes = set(
  583. [
  584. "_modules",
  585. "name",
  586. "training",
  587. "quantize_disabled",
  588. "_forward_pre_hooks",
  589. "_forward_hooks",
  590. "_name",
  591. "_short_name",
  592. ]
  593. )
  594. current_attributes = set(empty_module.__dict__.keys())
  595. assert (
  596. old_attributes == current_attributes
  597. ), "Add or delete attributes in Module class may break compatibility of pickle serialization"