
test_module_stats.py 12 kB

import math
from copy import deepcopy

import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
from megengine.core._trace_option import use_symbolic_shape
from megengine.utils.module_stats import module_stats


@pytest.mark.skipif(
    use_symbolic_shape(), reason="This test does not support symbolic shape.",
)
def test_module_stats():
    # Build a ResNet-18-like network and compare module_stats against the
    # hand-computed FLOPs/activations returned by net.get_stats.
    net = ResNet(BasicBlock, [2, 2, 2, 2])

    input_shape = (1, 3, 224, 224)
    total_stats, stats_details = module_stats(net, input_shapes=input_shape)

    x1 = np.random.random((1, 3, 224, 224)).astype("float32")
    gt_flops, gt_acts = net.get_stats(mge.tensor(x1))
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)

    total_stats, stats_details = module_stats(net, inputs=x1)
    assert (total_stats.flops, total_stats.act_dims) == (gt_flops, gt_acts)


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.tmp_in_channels = in_channels
        self.tmp_channels = channels
        self.stride = stride
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        self.downsample_id = M.Identity()
        self.downsample_conv = M.Conv2d(in_channels, channels, 1, stride, bias=False)
        self.downsample_norm = norm(channels)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            identity = self.downsample_conv(identity)
            identity = self.downsample_norm(identity)
        x += identity
        x = F.relu(x)
        return x

    def get_stats(self, x):
        activations, flops = 0, 0

        identity = x

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.conv2(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn2(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn2, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        if self.tmp_in_channels == self.tmp_channels and self.stride == 1:
            identity = self.downsample_id(identity)
        else:
            in_x = deepcopy(identity)
            identity = self.downsample_conv(identity)
            tmp_flops, tmp_acts = cal_conv_stats(self.downsample_conv, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

            in_x = deepcopy(identity)
            identity = self.downsample_norm(identity)
            tmp_flops, tmp_acts = cal_norm_stats(self.downsample_norm, in_x, identity)
            activations += tmp_acts
            flops += tmp_flops

        x += identity
        x = F.relu(x)

        return x, flops, activations


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers=[2, 2, 2, 2],
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1_0 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer1_1 = BasicBlock(
            self.in_channels,
            64,
            stride=1,
            groups=self.groups,
            base_width=self.base_width,
            dilation=self.dilation,
            norm=M.BatchNorm2d,
        )
        self.layer2_0 = BasicBlock(64, 128, stride=2)
        self.layer2_1 = BasicBlock(128, 128)
        self.layer3_0 = BasicBlock(128, 256, stride=2)
        self.layer3_1 = BasicBlock(256, 256)
        self.layer4_0 = BasicBlock(256, 512, stride=2)
        self.layer4_1 = BasicBlock(512, 512)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block, 128, 2, stride=2, dilate=replace_stride_with_dilation[0], norm=norm
        )
        self.layer3 = self._make_layer(
            block, 256, 2, stride=2, dilate=replace_stride_with_dilation[1], norm=norm
        )
        self.layer4 = self._make_layer(
            block, 512, 2, stride=2, dilate=replace_stride_with_dilation[2], norm=norm
        )
        self.fc = M.Linear(512, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):  # only BasicBlock defines bn2
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        layers = []
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x

    def get_stats(self, x):
        flops, activations = 0, 0

        in_x = deepcopy(x)
        x = self.conv1(x)
        tmp_flops, tmp_acts = cal_conv_stats(self.conv1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        in_x = deepcopy(x)
        x = self.bn1(x)
        tmp_flops, tmp_acts = cal_norm_stats(self.bn1, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.relu(x)

        in_x = deepcopy(x)
        x = self.maxpool(x)
        tmp_flops, tmp_acts = cal_pool_stats(self.maxpool, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer1_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer2_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer3_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_0.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x, tmp_flops, tmp_acts = self.layer4_1.get_stats(x)
        activations += tmp_acts
        flops += tmp_flops

        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)

        in_x = deepcopy(x)
        x = self.fc(x)
        tmp_flops, tmp_acts = cal_linear_stats(self.fc, in_x, x)
        activations += tmp_acts
        flops += tmp_flops

        return flops, activations


def cal_conv_stats(module, input, output):
    # FLOPs per output element: (in_channels / groups) * prod(kernel_size), plus one
    # per element if a bias is present; activations are the number of output elements.
    bias = 1 if module.bias is not None else 0
    flops = np.prod(output[0].shape) * (
        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
    )
    acts = np.prod(output[0].shape)
    return flops, acts


def cal_norm_stats(module, input, output):
    # BatchNorm is counted as 7 ops per input element.
    return np.prod(input[0].shape) * 7, np.prod(output[0].shape)


def cal_linear_stats(module, inputs, outputs):
    bias = module.out_features if module.bias is not None else 0
    return (
        np.prod(outputs[0].shape) * module.in_features + bias,
        np.prod(outputs[0].shape),
    )


def cal_pool_stats(module, inputs, outputs):
    # Each pooled output element reads a kernel_size x kernel_size window.
    return (
        np.prod(outputs[0].shape) * (module.kernel_size ** 2),
        np.prod(outputs[0].shape),
    )
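For reference, the `module_stats` entry point exercised by the test above can also be called standalone. The sketch below is illustrative only: the toy `Sequential` network is made up for the example, while the `input_shapes` argument and the `flops` / `act_dims` fields mirror exactly what the test uses.

import megengine.module as M
from megengine.utils.module_stats import module_stats

# Toy network, purely illustrative.
net = M.Sequential(
    M.Conv2d(3, 16, 3, padding=1),
    M.BatchNorm2d(16),
    M.ReLU(),
)

# Same calling convention as in test_module_stats above.
total_stats, stats_details = module_stats(net, input_shapes=(1, 3, 32, 32))
print("FLOPs:", total_stats.flops, "activations:", total_stats.act_dims)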

The MegEngine package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU device and its driver installed. If you would like to try deep-learning development on a cloud GPU compute platform, you are welcome to visit the MegStudio platform.
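To confirm from Python that the bundled CUDA runtime can actually see a GPU before running anything, a minimal sketch along the following lines can be used; it assumes the `is_cuda_available`, `get_device_count`, and `set_default_device` helpers exposed by recent MegEngine releases.

import megengine as mge

# Sketch: pick a GPU if the bundled CUDA runtime can see one, otherwise stay on CPU.
# Assumes mge.is_cuda_available(), mge.get_device_count("gpu") and
# mge.set_default_device() as exposed by recent MegEngine releases.
if mge.is_cuda_available():
    print("visible GPUs:", mge.get_device_count("gpu"))
    mge.set_default_device("gpu0")  # subsequent tensors/ops run on the first GPU
else:
    print("no GPU detected, falling back to CPU")
    mge.set_default_device("cpu0")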