
test_module_stats.py 12 kB

import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)
    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )

    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )

class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)

        return x, flops, activations

class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                # only residual blocks have a bn2 to zero-initialize
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)

        return x
    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations

def cal_conv_stats(module, input, output):
    # conv FLOPs: output elements * (input channels per group * kernel area + bias)
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    # batch norm is counted as roughly 7 ops per input element
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    # linear FLOPs: output elements * input features, plus one per output feature for bias
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    # pooling FLOPs: output elements * kernel area
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
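
For quick reference, the module_stats API exercised by this test can be called directly on any M.Module. The sketch below uses a toy Conv2d and a 32x32 input purely for illustration; neither is part of the test file, and only the calls and attributes demonstrated above (input_shapes=, total_stats.flops, stats_details.activations[-1]["act_dim"]) are relied on.

# Minimal usage sketch of module_stats (illustrative only).
import megengine.module as M
from megengine.utils.module_stats import module_stats

toy_net = M.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)  # hypothetical toy module
total_stats, stats_details = module_stats(toy_net, input_shapes=(1, 3, 32, 32))
print(total_stats.flops)                         # total FLOPs, the quantity asserted in the test
print(stats_details.activations[-1]["act_dim"])  # activation count of the last recorded layer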
