
test_module_stats.py

import collections
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    # Stats computed from input shapes should match the hand-computed ground truth.
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)
    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

    # The same should hold when passing concrete inputs instead of shapes.
    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_other_input_module_state():
    # module_stats should accept non-tensor inputs: lists, dicts and namedtuples.
    a = [1, 2]
    b = {"1": 1, "2": 2}
    nt = collections.namedtuple("nt", ["n", "t"])
    _nt = nt(n=1, t=2)
    net = FakeNet()
    net(a)
    net(b)
    net(_nt)

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_duplicated_module():
    input_shape = (1, 3, 224, 224)
    net0 = TestNet0()
    net0_stats, _ = module_stats(net0, input_shapes=input_shape)
    net1 = TestNet1()
    net1_stats, _ = module_stats(net1, input_shapes=input_shape)
    net2 = TestNet2()
    net2_stats, _ = module_stats(net2, input_shapes=input_shape)

    # Shared or duplicated parameters must only be counted once.
    assert net0_stats.param_dims == net1_stats.param_dims
    assert net0_stats.param_size == net1_stats.param_size
    assert net0_stats.param_dims == net2_stats.param_dims
    assert net0_stats.param_size == net2_stats.param_size

@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_getattribute_param():
    class MyConvBn(M.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = 64
            self.conv1 = M.Conv2d(
                3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=True
            )
            self.bn1 = M.BatchNorm2d(self.in_channels)

        def forward(self, input):
            # Call calc_conv directly so the parameters are only reached through
            # attribute access, not through the module's __call__.
            input = self.conv1.calc_conv(input, self.conv1.weight, self.conv1.bias)
            input = self.bn1(input)
            return input

    model = MyConvBn()
    input_shape = (1, 3, 224, 224)
    total_stats, stats_detail = module_stats(model, input_shapes=input_shape)
    params = stats_detail.params

    def get_name(obj):
        return obj["name"]

    param_name = list(map(get_name, params))
    assert "conv1-w" in param_name and "conv1-b" in param_name

class TestNet0(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv.bias = mge.Parameter(
            np.random.random(self.conv.bias.shape).astype(np.float32)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class TestNet1(TestNet0):
    def __init__(self):
        super().__init__()
        # The same module object is registered under a second attribute name.
        self.conv1 = self.conv

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x

class TestNet2(TestNet0):
    def __init__(self):
        super().__init__()
        # A second conv module that shares the parameters of the first one.
        self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv1.weight = self.conv.weight
        self.conv1.bias = self.conv.bias

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x

class FakeNet(M.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert isinstance(
            x,
            (
                np.ndarray,
                collections.abc.Mapping,
                collections.abc.Sequence,
                mge.Tensor,
            ),
        ) or (isinstance(x, tuple) and hasattr(x, "_fields"))

class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        # Both shortcut variants are registered; forward() picks one of them
        # depending on whether the block changes the feature-map shape.
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        # Ground-truth pass: run the same ops as forward() while accumulating
        # flops and activations with the cal_* helpers below.
        activations, flops = 0, 0
        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations

class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            # Only BasicBlock instances have a bn2 attribute to zero out.
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

    def get_stats(self, x):
        # Ground-truth pass using the manually unrolled layerX_Y blocks, which
        # mirror the Sequential layers built by _make_layer and therefore
        # produce the same flop/activation counts as the traced forward().
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations

def cal_conv_stats(module, input, output):
    # flops = output elements * (input channels per group * kernel area + bias)
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    # batch norm is counted as 7 flops per input element
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    # flops = output elements * kernel_size^2 comparisons
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
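
For reference, a minimal hand check of the cal_conv_stats formula, assuming the stem convolution defined above (7x7 kernel, stride 2, padding 3, 3 -> 64 channels, bias=False) applied to a (1, 3, 224, 224) input; the numbers below follow directly from that formula and are not asserted anywhere in the test itself.

# Hand check (sketch): the per-sample output of the stem conv is (64, 112, 112), so
#   flops = 64 * 112 * 112 * (3 // 1 * 7 * 7 + 0) = 118013952
#   acts  = 64 * 112 * 112 = 802816
import numpy as np

out_shape = (64, 112, 112)  # output[0].shape for a (1, 3, 224, 224) input
flops = np.prod(out_shape) * (3 // 1 * 7 * 7 + 0)  # no bias term since bias=False
acts = np.prod(out_shape)
assert flops == 118013952 and acts == 802816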