
test_module.py

# -*- coding: utf-8 -*-
from collections import OrderedDict
from io import BytesIO

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.device import get_device_count
from megengine.module import (
    BatchNorm1d,
    BatchNorm2d,
    Conv1d,
    Conv2d,
    Dropout,
    GroupNorm,
    Linear,
    MaxPool2d,
    Module,
    Sequential,
    Softmax,
)
from megengine.module.module import _access_structure
from megengine.quantization.quantize import quantize, quantize_qat
from megengine.traced_module import TracedModule, trace_module
from megengine.utils.module_utils import get_expand_structure, set_expand_structure


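# A minimal two-layer perceptron, reused below by the state-dict and pickle tests.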
class MLP(Module):
    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        x = self.dense0(x)
        x = F.relu(x)
        x = self.dense1(x)
        return x


class MyModule(Module):
    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            return self.bn(x)

    def __init__(self):
        super().__init__()
        self.i = self.InnerModule()
        self.bn = BatchNorm2d(4)
        self.param = Parameter(np.ones(1, dtype=np.float32))
        self.buff = Tensor(np.ones(1, dtype=np.float32))

    def forward(self, x):
        x = self.i(x)
        x = self.bn(x)
        return x


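# Exercises the module traversal APIs (children, modules, buffers, parameters,
# tensors) and checks that train()/eval() propagate through nested modules.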
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api(test_traced_module):
    m = MyModule()
    if test_traced_module:
        buff = m.buff
        param = m.param
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
        assert "buff" not in m.__dict__
        assert "param" not in m.__dict__
        m.buff = buff
        m.param = param
    assert list(m.children()) == [m.bn, m.i]
    assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("i", m.i),
        ("i.bn", m.i.bn),
    ]
    assert list(m.named_modules(prefix="x")) == [
        ("x", m),
        ("x.bn", m.bn),
        ("x.i", m.i),
        ("x.i.bn", m.i.bn),
    ]
    assert list(m.buffers()) == [
        m.bn.running_mean,
        m.bn.running_var,
        m.buff,
        m.i.bn.running_mean,
        m.i.bn.running_var,
    ]
    assert list(m.buffers(recursive=False)) == [m.buff]
    assert list(m.named_buffers()) == [
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("buff", m.buff),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
    ]
    assert list(m.parameters()) == [
        m.bn.bias,
        m.bn.weight,
        m.i.bn.bias,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.named_parameters()) == [
        ("bn.bias", m.bn.bias),
        ("bn.weight", m.bn.weight),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    assert list(m.tensors()) == [
        m.bn.bias,
        m.bn.running_mean,
        m.bn.running_var,
        m.bn.weight,
        m.buff,
        m.i.bn.bias,
        m.i.bn.running_mean,
        m.i.bn.running_var,
        m.i.bn.weight,
        m.param,
    ]
    assert list(m.named_tensors()) == [
        ("bn.bias", m.bn.bias),
        ("bn.running_mean", m.bn.running_mean),
        ("bn.running_var", m.bn.running_var),
        ("bn.weight", m.bn.weight),
        ("buff", m.buff),
        ("i.bn.bias", m.i.bn.bias),
        ("i.bn.running_mean", m.i.bn.running_mean),
        ("i.bn.running_var", m.i.bn.running_var),
        ("i.bn.weight", m.i.bn.weight),
        ("param", m.param),
    ]
    m.eval()
    assert (
        m.training == False
        and m.bn.training == False
        and m.i.training == False
        and m.i.bn.training == False
    )
    m.bn.train()
    assert m.training == False and m.bn.training == True and m.i.bn.training == False
    m.eval()
    m.i.train()
    assert (
        m.training == False
        and m.bn.training == False
        and m.i.training == True
        and m.i.bn.training == True
    )
    m.eval()
    m.train()
    assert m.training == True and m.bn.training == True and m.i.bn.training == True

    def fn(m):
        m.training = False

    m.apply(fn)
    assert m.bn.training == False and m.i.bn.training == False


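# Assigning an existing submodule under a second attribute must not duplicate it
# in modules(), and named_modules() should report it under the new name.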
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_reuse_submodule(test_traced_module):
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
    m.h = m.i  # pylint: disable=attribute-defined-outside-init
    assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("h", m.i),
        ("h.bn", m.i.bn),
    ]


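# Iteration order of modules() must stay stable across repeated calls.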
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_iterable_stability(test_traced_module):
    m = MyModule()
    if test_traced_module:
        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
    l = list(m.modules())
    for _ in range(100):
        assert list(m.modules()) == l


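# Forward pre-hooks may rewrite inputs and forward hooks may rewrite outputs;
# the expected batch-norm statistics are recomputed by hand below to verify both.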
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_module_api_hooks(test_traced_module):
    net = MyModule()
    if test_traced_module:
        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
    pre_hook_num = 0
    post_hook_num = 0
    hooks = []

    def pre_hook(_, inputs):
        nonlocal pre_hook_num
        pre_hook_num += 1
        modified_inputs = tuple(inp + 1 for inp in inputs)
        return modified_inputs

    def post_hook(_, __, outputs):
        nonlocal post_hook_num
        post_hook_num += 1
        outputs += 1
        return outputs

    net.apply(lambda module: hooks.append(module.register_forward_pre_hook(pre_hook)))
    net.apply(lambda module: hooks.append(module.register_forward_hook(post_hook)))
    shape = (1, 4, 1, 1)
    x = tensor(np.zeros(shape, dtype=np.float32))
    y = net(x)
    assert pre_hook_num == 4
    assert post_hook_num == 4
    mean1 = Parameter(np.zeros(shape), dtype=np.float32)
    bn1 = F.batch_norm(
        x + 3, mean1, Parameter(np.ones(shape), dtype=np.float32), training=True
    )
    np.testing.assert_allclose(
        net.i.bn.running_mean.numpy(), mean1.numpy(),
    )
    mean2 = Parameter(np.zeros(shape), dtype=np.float32)
    bn2 = F.batch_norm(
        bn1 + 3, mean2, Parameter(np.ones(shape), dtype=np.float32), training=True
    )
    np.testing.assert_allclose(
        net.bn.running_mean.numpy(), mean2.numpy(),
    )
    np.testing.assert_allclose((bn2 + 2).numpy(), y.numpy())
    assert len(hooks) == 8
    for handler in hooks:
        handler.remove()
    y = net(x)
    assert pre_hook_num == 4
    assert post_hook_num == 4


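# Modules nested inside plain Python lists, dicts, and tuples should still be
# discovered by the traversal APIs, with container indices/keys in their names.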
class MyModule2(Module):
    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)
            self.test_bool_key = {True: 1, False: 0}

        def forward(self, x):
            x = self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.a = [
            BatchNorm2d(4),
            {"x": BatchNorm2d(4), "y": [BatchNorm2d(4), self.InnerModule()], "z": 0},
            (self.InnerModule(),),
        ]

    def forward(self, x):
        return x


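# get_expand_structure/set_expand_structure address entries by these expanded
# names; entries living inside tuples cannot be mutated, hence are skipped when setting.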
def test_expand_structure():
    m = MyModule2()
    rst = [
        ("", m),
        ("a.0", m.a[0]),
        ("a.1.x", m.a[1]["x"]),
        ("a.1.y.0", m.a[1]["y"][0]),
        ("a.1.y.1", m.a[1]["y"][1]),
        ("a.1.y.1.bn", m.a[1]["y"][1].bn),
        ("a.2.0", m.a[2][0]),
        ("a.2.0.bn", m.a[2][0].bn),
        ("bn", m.bn),
    ]
    assert list(m.named_modules()) == rst
    for item in rst[1:]:
        assert get_expand_structure(m, item[0]) == item[1]
    for item in reversed(rst[1:]):
        if _access_structure(m, item[0], lambda p, k, o: isinstance(p, tuple)):
            continue
        set_expand_structure(m, item[0], "TEST_VALUE")
        assert get_expand_structure(m, item[0]) == "TEST_VALUE"


def test_flatten_others():
    def be_others(obj):
        return not isinstance(obj, (Tensor, Module))

    m = MyModule2()
    assert len(list(m._flatten(with_key=True, predicate=be_others))) == 0


def test_flatten_with_parent():
    m = MyModule2()
    assert list(m.named_modules(with_parent=True)) == [
        ("", m, None),
        ("a.0", m.a[0], m),
        ("a.1.x", m.a[1]["x"], m),
        ("a.1.y.0", m.a[1]["y"][0], m),
        ("a.1.y.1", m.a[1]["y"][1], m),
        ("a.1.y.1.bn", m.a[1]["y"][1].bn, m.a[1]["y"][1]),
        ("a.2.0", m.a[2][0], m),
        ("a.2.0.bn", m.a[2][0].bn, m.a[2][0]),
        ("bn", m.bn, m),
    ]
    assert list(m.modules(with_parent=True)) == [
        (m, None),
        (m.a[0], m),
        (m.a[1]["x"], m),
        (m.a[1]["y"][0], m),
        (m.a[1]["y"][1], m),
        (m.a[1]["y"][1].bn, m.a[1]["y"][1]),
        (m.a[2][0], m),
        (m.a[2][0].bn, m.a[2][0]),
        (m.bn, m),
    ]


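# Sequential children are named by their integer position.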
class MyModule3(Module):
    class InnerModule(Module):
        def __init__(self):
            super().__init__()
            self.bn = BatchNorm2d(4)

        def forward(self, x):
            x = self.bn(x)

    def __init__(self):
        super().__init__()
        self.bn = BatchNorm2d(4)
        self.seq = Sequential(BatchNorm2d(4), self.InnerModule(),)

    def forward(self, x):
        return x


def test_module_api_with_sequential():
    m = MyModule3()
    assert list(m.named_modules()) == [
        ("", m),
        ("bn", m.bn),
        ("seq", m.seq),
        ("seq.0", m.seq[0]),
        ("seq.1", m.seq[1]),
        ("seq.1.bn", m.seq[1].bn),
    ]


def test_sequential_named_children():
    modules = OrderedDict()
    modules["name0"] = Linear(20, 10)
    modules["name1"] = Linear(10, 5)
    modules["name2"] = Linear(5, 1)
    m = Sequential(modules)
    l = list(m.named_children())
    assert l[0][0] == "name0"
    assert l[1][0] == "name1"
    assert l[2][0] == "name2"


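# Round-trips a state dict through mge.save/mge.load; strict loading must reject
# unexpected or missing keys with a KeyError.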
def test_state_dict():
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape))
    mlp = MLP()
    pred0 = mlp(data)

    with BytesIO() as fout:
        mge.save(mlp.state_dict(), fout)
        fout.seek(0)
        state_dict = mge.load(fout)
    state_dict["extra"] = None
    mlp1 = MLP()
    mlp1.load_state_dict(state_dict, strict=False)
    pred1 = mlp1(data)
    np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
    with pytest.raises(KeyError):
        mlp1.load_state_dict(state_dict)
    del state_dict["extra"]
    del state_dict["dense0.bias"]
    with pytest.raises(KeyError):
        mlp1.load_state_dict(state_dict)


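# _flatten should reject non-str keys for Tensor/Module entries with a clear message.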
class AssertModule(Module):
    def __init__(self):
        super().__init__()
        self.error_tensor_key = {True: tensor([]), False: 0}

    def forward(self, x):
        return x


def test_assert_message():
    with pytest.raises(
        AssertionError, match="keys for Tensor and Module must be str, error key: True"
    ):
        m = AssertModule()
        list(m._flatten())


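# Two convs sharing one weight tensor: the sharing must survive whole-module
# serialization, but saving each conv separately duplicates the weight.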
class Simple(Module):
    def __init__(self):
        super().__init__()
        self.conv0 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1 = Conv2d(1, 1, kernel_size=3, bias=False)
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        x = self.conv0(inputs)
        y = self.conv1(inputs)
        return x + y


@pytest.mark.parametrize("test_traced_module", [True, False])
def test_shared_param(test_traced_module):
    net = Simple()
    if test_traced_module:
        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
    assert net.conv0.weight is net.conv1.weight
    data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
    with BytesIO() as f:
        mge.save(net, f)
        f.seek(0)
        net1 = mge.load(f)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
    with BytesIO() as f:
        mge.save(net.conv0, f)
        f.seek(0)
        conv0 = mge.load(f)
    with BytesIO() as f:
        mge.save(net.conv1, f)
        f.seek(0)
        conv1 = mge.load(f)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())


class Simple2(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv1d(1, 1, kernel_size=3, bias=False)
        self.conv0 = Conv1d(1, 1, kernel_size=3, bias=False)
        self.conv1.weight = self.conv0.weight

    def forward(self, inputs):
        pass


def test_shared_param_1d():
    net = Simple2()
    assert net.conv0.weight is net.conv1.weight
    data = tensor(np.random.random((1, 1, 8)).astype(np.float32))
    np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
    with BytesIO() as f:
        mge.save(net, f)
        f.seek(0)
        net1 = mge.load(f)
    assert net1.conv0.weight is net1.conv1.weight
    np.testing.assert_allclose(net1.conv0(data).numpy(), net1.conv1(data).numpy())
    with BytesIO() as f:
        mge.save(net.conv0, f)
        f.seek(0)
        conv0 = mge.load(f)
    with BytesIO() as f:
        mge.save(net.conv1, f)
        f.seek(0)
        conv1 = mge.load(f)
    assert conv0.weight is not conv1.weight
    np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())


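# Pickling must work both before and after the first forward pass and preserve
# the traced-module type.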
@pytest.mark.parametrize("test_traced_module", [True, False])
def test_pickle_module(test_traced_module):
    data_shape = (2, 28)
    data = tensor(np.random.random(data_shape))
    mlp = MLP()
    pred_gt = mlp(data)
    if test_traced_module:
        mlp = trace_module(mlp, data)
    # pickle before forward
    with BytesIO() as fout:
        mge.save(mlp, fout)
        fout.seek(0)
        mlp1 = mge.load(fout)
    if test_traced_module:
        assert type(mlp1) == TracedModule
    pred0 = mlp1(data)
    pred1 = mlp(data)
    # pickle after forward
    with BytesIO() as fout:
        mge.save(mlp, fout)
        fout.seek(0)
        mlp1 = mge.load(fout)
    if test_traced_module:
        assert type(mlp1) == TracedModule
    pred2 = mlp1(data)
    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
    np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)


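# The following tests pin down the exact __repr__ output, covering container
# attributes, reassignment, shared references, deletion, and reset attributes.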
def test_repr_basic():
    # test whether __repr__ can output correct information
    class ConvModel(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, padding=1, bias=False)
            self.conv2 = Conv2d(3, 128, 3, dilation=2, bias=False)
            self.bn1 = BatchNorm1d(128)
            self.bn2 = BatchNorm2d(128)
            self.pooling = MaxPool2d(kernel_size=2, padding=0)
            modules = OrderedDict()
            modules["depthwise"] = Conv2d(256, 256, 3, 1, 0, groups=256, bias=False,)
            modules["pointwise"] = Conv2d(
                256, 256, kernel_size=1, stride=1, padding=0, bias=True,
            )
            self.submodule1 = Sequential(modules)
            self.list1 = [Dropout(drop_prob=0.1), [Softmax(axis=100)]]
            self.tuple1 = (
                Dropout(drop_prob=0.1),
                (Softmax(axis=100), Dropout(drop_prob=0.2)),
            )
            self.dict1 = {"Dropout": Dropout(drop_prob=0.1)}
            self.fc1 = Linear(512, 1024)

        def forward(self, inputs):
            pass

    ground_truth = (
        "ConvModel(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        "  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)\n"
        "  (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0)\n"
        "  (submodule1): Sequential(\n"
        "    (depthwise): Conv2d(256, 256, kernel_size=(3, 3), groups=256, bias=False)\n"
        "    (pointwise): Conv2d(256, 256, kernel_size=(1, 1))\n"
        "  )\n"
        "  (list1.0): Dropout(drop_prob=0.1)\n"
        "  (list1.1.0): Softmax(axis=100)\n"
        "  (tuple1.0): Dropout(drop_prob=0.1)\n"
        "  (tuple1.1.0): Softmax(axis=100)\n"
        "  (tuple1.1.1): Dropout(drop_prob=0.2)\n"
        "  (dict1.Dropout): Dropout(drop_prob=0.1)\n"
        "  (fc1): Linear(in_features=512, out_features=1024, bias=True)\n"
        ")"
    )
    net = ConvModel()
    output = net.__repr__()
    assert output == ground_truth


def test_repr_module_reassign():
    # test whether __repr__ can deal with module reassign
    class ConvModel1(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.conv2 = Conv2d(3, 128, 3, padding=1, bias=False)
            self.conv1 = Conv2d(3, 256, 3, dilation=2, bias=False)

        def forward(self, inputs):
            pass

    ground_truth = (
        "ConvModel1(\n"
        "  (conv1): Conv2d(3, 256, kernel_size=(3, 3), dilation=(2, 2), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), padding=(1, 1), bias=False)\n"
        ")"
    )
    net = ConvModel1()
    output = net.__repr__()
    assert output == ground_truth


def test_repr_module_rereference():
    # test whether __repr__ can deal with module re-reference
    class ConvModel2(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.conv2 = self.conv1
            self.conv3 = self.conv1

        def forward(self, inputs):
            pass

    ground_truth = (
        "ConvModel2(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        "  (conv2): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        "  (conv3): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    net = ConvModel2()
    output = net.__repr__()
    assert output == ground_truth


def test_repr_module_delete():
    # test whether __repr__ can deal with module delete
    class ConvModel3(Module):
        def __init__(self):
            super().__init__()
            self.conv1 = Conv2d(3, 128, 3, bias=False)
            self.softmax = Softmax(100)

        def forward(self, inputs):
            pass

    ground_truth = (
        "ConvModel3(\n"
        "  (conv1): Conv2d(3, 128, kernel_size=(3, 3), bias=False)\n"
        ")"
    )
    net = ConvModel3()
    del net.softmax
    output = net.__repr__()
    assert output == ground_truth


def test_repr_module_reset_attr():
    class ResetAttrModule(Module):
        def __init__(self, flag):
            super().__init__()
            if flag:
                self.a = None
                self.a = Linear(3, 5)
            else:
                self.a = Linear(3, 5)
                self.a = None

        def forward(self, x):
            if self.a:
                x = self.a(x)
            return x

    ground_truth = [
        (
            "ResetAttrModule(\n"
            "  (a): Linear(in_features=3, out_features=5, bias=True)\n"
            ")"
        ),
        ("ResetAttrModule()"),
    ]
    m0 = ResetAttrModule(True)
    m1 = ResetAttrModule(False)
    output = [m0.__repr__(), m1.__repr__()]
    assert output == ground_truth


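# Guards the pickled attribute layout of Module: adding or removing instance
# attributes would break loading of previously serialized models.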
def test_module_compatible():
    class Empty(Module):
        def forward(self):
            pass

    empty_module = Empty()
    old_attributes = set(
        [
            "_modules",
            "name",
            "training",
            "quantize_disabled",
            "_forward_pre_hooks",
            "_forward_hooks",
            "_name",
            "_short_name",
        ]
    )
    current_attributes = set(empty_module.__dict__.keys())
    assert (
        old_attributes == current_attributes
    ), "Adding or deleting attributes of Module may break the compatibility of pickle serialization"


def test_group_norm():
    class OriginGroupNormFunc(Module):
        def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
            super().__init__(**kwargs)
            assert num_channels % num_groups == 0
            self.num_groups = num_groups
            self.num_channels = num_channels
            self.eps = eps
            self.affine = affine
            if self.affine:
                self.weight = Parameter(np.ones(num_channels, dtype=np.float32))
                self.bias = Parameter(np.zeros(num_channels, dtype=np.float32))
            else:
                self.weight = None
                self.bias = None

        def forward(self, x):
            N, C, H, W = x.shape
            x = x.reshape(N, self.num_groups, -1)
            mean = x.mean(axis=2, keepdims=True)
            var = (x * x).mean(axis=2, keepdims=True) - mean * mean
            x = (x - mean) / F.sqrt(var + self.eps)
            x = x.reshape(N, C, H, W)
            if self.affine:
                x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(
                    1, -1, 1, 1
                )
            return x

    inp = np.random.randn(2, 256, 10, 16).astype("float32")
    mge_inp = Tensor(inp)
    mge_m = GroupNorm(32, 256)
    ori_inp = Tensor(inp)
    ori_m = OriginGroupNormFunc(32, 256)

    targets = np.array(2)
    mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters())
    ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters())

    for _ in range(2):
        with mge_gm:
            mge_output = mge_m(mge_inp)
            loss = F.loss.square_loss(
                mge_output.sum(), mge.tensor(targets, dtype=np.float32)
            )
            mge_gm.backward(loss)
        with ori_gm:
            ori_output = ori_m(ori_inp)
            loss = F.loss.square_loss(
                ori_output.sum(), mge.tensor(targets, dtype=np.float32)
            )
            ori_gm.backward(loss)

        np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
        np.testing.assert_allclose(
            mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03
        )
        np.testing.assert_allclose(
            mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03
        )