
test_module_stats.py

import collections
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)

    # stats computed from input shapes only
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)

    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

    # stats computed from a concrete input tensor should match
    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_other_input_module_state():
    # FakeNet.forward checks that list, dict and namedtuple inputs are accepted
    a = [1, 2]
    b = {"1": 1, "2": 2}
    nt = collections.namedtuple("nt", ["n", "t"])
    _nt = nt(n=1, t=2)
    net = FakeNet()
    net(a)
    net(b)
    net(_nt)


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_duplicated_module():
    input_shape = (1, 3, 224, 224)
    net0 = TestNet0()
    net0_stats, _ = module_stats(net0, input_shapes=input_shape)
    net1 = TestNet1()
    net1_stats, _ = module_stats(net1, input_shapes=input_shape)
    net2 = TestNet2()
    net2_stats, _ = module_stats(net2, input_shapes=input_shape)

    # modules that reuse the same parameters must not be double-counted
    assert net0_stats.param_dims == net1_stats.param_dims
    assert net0_stats.param_size == net1_stats.param_size
    assert net0_stats.param_dims == net2_stats.param_dims
    assert net0_stats.param_size == net2_stats.param_size


class TestNet0(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv.bias = mge.Parameter(
            np.random.random(self.conv.bias.shape).astype(np.float32)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class TestNet1(TestNet0):
    """Reuses the whole conv module under a second attribute name."""

    def __init__(self):
        super().__init__()
        self.conv1 = self.conv

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x


class TestNet2(TestNet0):
    """Creates a second conv but shares its weight and bias with the first."""

    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 3, 3, padding=(1, 1))
        self.conv1.weight = self.conv.weight
        self.conv1.bias = self.conv.bias

    def forward(self, x):
        x = self.conv(x)
        x = self.conv1(x)
        return x


class FakeNet(M.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert isinstance(
            x,
            (np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
        ) or (isinstance(x, tuple) and hasattr(x, "_fields"))


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)
        return x, flops, activations


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                # only BasicBlock instances carry a bn2 attribute
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations


def cal_conv_stats(module, input, output):
    # FLOPs per output element: one multiply-add per input-channel/kernel tap
    # (plus one for the bias), times the number of output elements.
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )