
test_module_stats.py 12 kB

import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats

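# The test below builds a ResNet-18-style network (defined further down) and
# checks that module_stats reports the same FLOPs and activation numbers as
# the network's hand-written get_stats() reference computation.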
@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    net = ResNet(BasicBlock, [2, 2, 2, 2])
    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shape)

    x1 = mge.tensor(np.zeros((1, 3, 224, 224)))
    gt_flops, gt_acts = net.get_stats(x1)

    assert (total_stats.flops, stats_details.activations[-1]["act_dim"]) == (
        gt_flops,
        gt_acts,
    )

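# Reference network: a ResNet BasicBlock that can also replay its forward pass
# while accumulating FLOPs and activation counts by hand (see get_stats below).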
class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

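    # get_stats mirrors forward(), but wraps every conv/norm call with the
    # cal_*_stats helpers at the bottom of the file, so the block returns its
    # output together with hand-computed FLOPs and activation counts.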
    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)

        return x, flops, activations

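# ResNet-18-style reference model. The blocks are built twice on purpose: once
# through _make_layer (self.layer1 ... self.layer4, used by forward) and once as
# explicit layerX_Y attributes, whose get_stats methods drive the hand-computed
# reference in ResNet.get_stats below.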
class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            # zero-init the last BN in each residual block so it starts as identity;
            # guarded with isinstance so non-block modules are skipped
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )

        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x

        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

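    # Hand-computed reference: replays the same computation as forward(), but
    # through the explicit layerX_Y blocks, summing per-op FLOPs/activations.
    # The functional avg_pool2d/flatten calls are not counted here.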
    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations

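# Per-operator reference cost model used above:
#   conv:   flops = out_elems * (in_channels // groups * prod(kernel_size) + bias_term)
#   norm:   flops = 7 * in_elems
#   linear: flops = out_elems * in_features + bias_term
#   pool:   flops = out_elems * kernel_size ** 2
# Activations are counted as the number of output elements in every case.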
def cal_conv_stats(module, input, output):
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
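
For reference, a minimal sketch of how module_stats can be invoked on a model of
your own (TinyNet below is illustrative and not part of the test file; the call
signature and result fields are taken from test_module_stats above):

import megengine.module as M
from megengine.utils.module_stats import module_stats

class TinyNet(M.Module):
    def __init__(self):
        super().__init__()
        # a single 3x3 conv, just enough to produce non-trivial stats
        self.conv = M.Conv2d(3, 8, 3, padding=1, bias=False)

    def forward(self, x):
        return self.conv(x)

total_stats, stats_details = module_stats(TinyNet(), (1, 3, 32, 32))
print(total_stats.flops, stats_details.activations[-1]["act_dim"])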
