
test_module_stats.py

import collections
import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)

    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))

    # module_stats should match the hand-computed ground truth.
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_other_input_module_state():
    a = [1, 2]
    b = {"1": 1, "2": 2}
    nt = collections.namedtuple("nt", ["n", "t"])
    _nt = nt(n=1, t=2)
    net = FakeNet()
    net(a)
    net(b)
    net(_nt)


class FakeNet(M.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert isinstance(
            x,
            (np.ndarray, collections.abc.Mapping, collections.abc.Sequence, mge.Tensor),
        ) or (isinstance(x, tuple) and hasattr(x, "_fields"))


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        # Hand-computed FLOPs and activation counts for this block, used as
        # ground truth against module_stats.
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)

        return x, flops, activations


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                # Only BasicBlock instances have a bn2 attribute.
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

    def get_stats(self, x):
        # Hand-computed FLOPs and activation counts for the whole network,
        # used as ground truth against module_stats.
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations


def cal_conv_stats(module, input, output):
    # FLOPs per output element: (in_channels / groups) * kernel area, plus one
    # for the bias if present; activations are the number of output elements.
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    # Batch norm is counted as roughly 7 elementwise ops per input element.
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )