
test_module_stats.py 15 kB

import collections
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import (
    hook_modules,
    module_stats,
    register_hook_module,
)

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)
    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)
    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_other_input_module_state():
    # FakeNet.forward only asserts that these non-tensor container inputs
    # (list, dict, namedtuple) are accepted.
    a = [1, 2]
    b = {"1": 1, "2": 2}
    nt = collections.namedtuple("nt", ["n", "t"])
    _nt = nt(n=1, t=2)
    net = FakeNet()
    net(a)
    net(b)
    net(_nt)

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_duplicated_module():
    # Parameters shared across submodules should be counted only once,
    # so all three nets must report identical parameter statistics.
    input_shape = (1, 3, 224, 224)
    net0 = TestNet0()
    net0_stats, _ = module_stats(net0, input_shapes=input_shape)
    net1 = TestNet1()
    net1_stats, _ = module_stats(net1, input_shapes=input_shape)
    net2 = TestNet2()
    net2_stats, _ = module_stats(net2, input_shapes=input_shape)
    assert net0_stats.param_dims == net1_stats.param_dims
    assert net0_stats.param_size == net1_stats.param_size
    assert net0_stats.param_dims == net2_stats.param_dims
    assert net0_stats.param_size == net2_stats.param_size

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_getattribute_param():
    class MyConvBn(M.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = 64
            self.conv1 = M.Conv2d(
                3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
            )
            self.conv1.reset_parameters()
            self.bn1 = M.BatchNorm2d(self.in_channels)

        def forward(self, input):
            input = self.conv1.calc_conv(input, self.conv1.weight, self.conv1.bias)
            input = self.bn1(input)
            return input

    model = MyConvBn()
    input_shape = (1, 3, 224, 224)
    total_stats, stats_detail = module_stats(model, input_shapes=input_shape)
    params = stats_detail.params

    def get_name(obj):
        return obj["name"]

    param_names = list(map(get_name, params))
    assert "conv1-w" in param_names and "conv1-b" in param_names
    conv1_b_param = params[param_names.index("conv1-b")]
    assert int(conv1_b_param["mean"]) == 0 and int(conv1_b_param["std"]) == 0

class TestNet0(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv.bias = mge.Parameter(
            np.random.random(self.conv.bias.shape).astype(np.float32)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class TestNet1(TestNet0):
    # Applies the same Conv2d instance twice; parameters are shared by reference.
    def __init__(self):
        super().__init__()
        self.conv1 = self.conv

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x


class TestNet2(TestNet0):
    # A second Conv2d that shares the first one's weight and bias tensors.
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv1.weight = self.conv.weight
        self.conv1.bias = self.conv.bias

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x

class FakeNet(M.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert isinstance(
            x,
            (
                np.ndarray,
                collections.abc.Mapping,
                collections.abc.Sequence,
                mge.Tensor,
            ),
        ) or (isinstance(x, tuple) and hasattr(x, "_fields"))

class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        # Ground-truth FLOPs and activation counts for this block, accumulated
        # with the cal_*_stats helpers defined at the bottom of this file.
        activations, flops = 0, 0
        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations

class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
        if zero_init_residual:
            for m in self.modules():
                # only BasicBlock instances have a bn2 attribute
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

    def get_stats(self, x):
        # Ground-truth FLOPs/activations for the whole network, accumulated
        # layer by layer and compared against module_stats in test_module_stats.
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations

def cal_conv_stats(module, input, output):
    # flops: output elements x (in_channels // groups x kernel-size product + bias term)
    # acts: number of output elements
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )

def test_register_hook_module():
    modules = [TestNet0, TestNet1, TestNet2, FakeNet, BasicBlock, ResNet]
    register_hook_module(modules)
    for module in modules:
        assert module in hook_modules
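

# A minimal usage sketch of module_stats outside pytest, kept as a comment so
# it does not affect test collection. It only relies on calls and attribute
# names already exercised above (flops, param_dims, act_dims); anything else
# would be an assumption:
#
#     net = TestNet0()
#     total_stats, stats_details = module_stats(net, input_shapes=(1, 3, 224, 224))
#     print(total_stats.flops, total_stats.param_dims, total_stats.act_dims)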