
resnet50_mge.py

#!/usr/bin/env python3
import argparse
import math

import megengine.functional as F
import megengine.hub as hub
import megengine.module as M
import numpy as np
from megengine import jit, tensor


class BasicBlock(M.Module):
    expansion = 1

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.conv1 = M.Conv2d(
            in_channels, channels, 3, stride, padding=dilation, bias=False
        )
        self.bn1 = norm(channels)
        self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
        self.bn2 = norm(channels)
        # project the identity with a 1x1 conv when the residual branch
        # changes the spatial size or the channel count
        self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels, 1, stride, bias=False),
                norm(channels),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x


class Bottleneck(M.Module):
    expansion = 4

    def __init__(
        self,
        in_channels,
        channels,
        stride=1,
        groups=1,
        base_width=64,
        dilation=1,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        width = int(channels * (base_width / 64.0)) * groups
        self.conv1 = M.Conv2d(in_channels, width, 1, 1, bias=False)
        self.bn1 = norm(width)
        self.conv2 = M.Conv2d(
            width,
            width,
            3,
            stride,
            padding=dilation,
            groups=groups,
            dilation=dilation,
            bias=False,
        )
        self.bn2 = norm(width)
        self.conv3 = M.Conv2d(width, channels * self.expansion, 1, 1, bias=False)
        self.bn3 = norm(channels * self.expansion)
        self.downsample = (
            M.Identity()
            if in_channels == channels * self.expansion and stride == 1
            else M.Sequential(
                M.Conv2d(in_channels, channels * self.expansion, 1, stride, bias=False),
                norm(channels * self.expansion),
            )
        )

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        identity = self.downsample(identity)
        x += identity
        x = F.relu(x)
        return x


class ResNet(M.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm=M.BatchNorm2d,
    ):
        super().__init__()
        self.in_channels = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = M.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm(self.in_channels)
        self.maxpool = M.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0], norm=norm)
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            norm=norm,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            norm=norm,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            norm=norm,
        )
        self.fc = M.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, M.Conv2d):
                M.init.msra_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)
            elif isinstance(m, M.BatchNorm2d):
                M.init.ones_(m.weight)
                M.init.zeros_(m.bias)
            elif isinstance(m, M.Linear):
                M.init.msra_uniform_(m.weight, a=math.sqrt(5))
                if m.bias is not None:
                    fan_in, _ = M.init.calculate_fan_in_and_fan_out(m.weight)
                    bound = 1 / math.sqrt(fan_in)
                    M.init.uniform_(m.bias, -bound, bound)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. According to https://arxiv.org/abs/1706.02677
        # this improves the model by 0.2~0.3%.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    M.init.zeros_(m.bn3.weight)
                elif isinstance(m, BasicBlock):
                    M.init.zeros_(m.bn2.weight)

    def _make_layer(
        self, block, channels, blocks, stride=1, dilate=False, norm=M.BatchNorm2d
    ):
        previous_dilation = self.dilation
        if dilate:
            # replace the strided convolution with a dilated one
            self.dilation *= stride
            stride = 1
        layers = []
        # the first block of a stage handles the stride and the channel change
        layers.append(
            block(
                self.in_channels,
                channels,
                stride,
                groups=self.groups,
                base_width=self.base_width,
                dilation=previous_dilation,
                norm=norm,
            )
        )
        self.in_channels = channels * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_channels,
                    channels,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm=norm,
                )
            )
        return M.Sequential(*layers)

    def extract_features(self, x):
        outputs = {}
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.maxpool(x)
        outputs["stem"] = x
        x = self.layer1(x)
        outputs["res2"] = x
        x = self.layer2(x)
        outputs["res3"] = x
        x = self.layer3(x)
        outputs["res4"] = x
        x = self.layer4(x)
        outputs["res5"] = x
        return outputs

    def forward(self, x):
        # keep the batch dimension; the previous hard-coded batch of 1
        # broke dumping with --batch-size > 1
        x = F.reshape(x, (-1, 3, 224, 224))
        x = self.extract_features(x)["res5"]
        x = F.avg_pool2d(x, 7)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x


@hub.pretrained(
    "https://data.megengine.org.cn/models/weights/resnet50_fbaug_76254_4e14b7d1.pkl"
)
def resnet50(**kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="dump a MegEngine model for ResNet-50",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-b", "--batch-size", help="batch size of the model", default=1, type=int
    )
    parser.add_argument(
        "-d",
        "--dtype",
        help="data type of the input, either float32 or uint8",
        default="float32",
        type=str,
    )
    parser.add_argument(
        "--inputs",
        help="input data used to dump a model with an embedded testcase",
        default="",
        type=str,
    )
    parser.add_argument(
        "--dir",
        help="directory to which the model is dumped",
        default=".",
        type=str,
    )
    parser.add_argument(
        "--enable-nchw4",
        help="enable the NCHW4 tensor layout for NVIDIA cuDNN",
        action="store_true",
    )
    parser.add_argument(
        "--enable-chwn4",
        help="enable the CHWN4 tensor layout for NVIDIA cuDNN",
        action="store_true",
    )
    args = parser.parse_args()

    net = resnet50()
    net.eval()

    @jit.trace(symbolic=True, capture_as_const=True)
    def fun(data):
        return net(data)

    if args.dtype in ("float32", "uint8"):
        data_type = np.uint8 if args.dtype == "uint8" else np.float32
        data = tensor(
            (np.random.random([args.batch_size, 3, 224, 224]) * 255).astype(data_type)
        )
        fun(data)  # run the traced function once so the graph can be dumped
        model_path = "{}/resnet50_b{}_{}".format(args.dir, args.batch_size, args.dtype)
        if args.inputs == "":
            fun.dump(
                model_path + "_without_data.mge",
                arg_names=["data"],
                no_assert=True,
                enable_nchw4=args.enable_nchw4,
                enable_chwn4=args.enable_chwn4,
            )
        else:
            fun.dump(
                model_path + "_with_data.mge",
                arg_names=["data"],
                input_data=[args.inputs],
                no_assert=True,
                enable_nchw4=args.enable_nchw4,
                # the original omitted this flag here, so --enable-chwn4 was
                # silently ignored when dumping with input data
                enable_chwn4=args.enable_chwn4,
            )
    else:
        raise ValueError("dtype should be float32 or uint8")
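
Usage: the script traces ResNet-50 with MegEngine's jit, runs one random batch through the traced function, and dumps the serialized graph as a .mge file. For example, `python3 resnet50_mge.py --batch-size 1 --dtype float32 --dir .` writes ./resnet50_b1_float32_without_data.mge, following the naming scheme in the code above.

The network can also be used for inference directly from Python. A minimal sketch, assuming the file is importable as a module named resnet50_mge, that the weight URL is reachable, and that hub.pretrained exposes its usual pretrained= keyword:

    import numpy as np
    from megengine import tensor
    from resnet50_mge import resnet50

    # pretrained=True is the keyword added by the hub.pretrained decorator;
    # it downloads the weights from the URL above and loads them
    net = resnet50(pretrained=True)
    net.eval()
    # random placeholder input; real use would feed a preprocessed 224x224 image
    inp = tensor(np.random.random([1, 3, 224, 224]).astype(np.float32))
    logits = net(inp)
    print(logits.shape)  # (1, 1000) class scores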