resnet.py
  1. """
  2. # ---------------------------------------------------------------------------------
  3. # -*- coding: utf-8 -*-
  4. -----------------------------------------------------------------------------------
  5. # Copyright (c) Microsoft
  6. # Licensed under the MIT License.
  7. # Written by Bin Xiao (Bin.Xiao@microsoft.com)
  8. # Modified by Xingyi Zhou
  9. # Refer from: https://github.com/xingyizhou/CenterNet
  10. # Modifier: Nguyen Mau Dung (2020.08.09)
  11. # ------------------------------------------------------------------------------
  12. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

BN_MOMENTUM = 0.1

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class PoseResNet(nn.Module):

    def __init__(self, block, layers, heads, head_conv, **kwargs):
        self.inplanes = 64
        self.deconv_with_bias = False
        self.heads = heads

        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # used for deconv layers
        self.deconv_layers = self._make_deconv_layer(
            3,
            [256, 256, 256],
            [4, 4, 4],
        )
        # self.final_layer = []

        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(256, head_conv,
                              kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, num_output,
                              kernel_size=1, stride=1, padding=0))
            else:
                fc = nn.Conv2d(
                    in_channels=256,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            self.__setattr__(head, fc)

        # self.final_layer = nn.ModuleList(self.final_layer)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0

        return deconv_kernel, padding, output_padding
    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.deconv_layers(x)
        ret = {}
        for head in self.heads:
            ret[head] = self.__getattr__(head)(x)
        return ret

    def init_weights(self, num_layers, pretrained=True):
        if pretrained:
            # print('=> init resnet deconv weights from normal distribution')
            for _, m in self.deconv_layers.named_modules():
                if isinstance(m, nn.ConvTranspose2d):
                    # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                    # print('=> init {}.bias as 0'.format(name))
                    nn.init.normal_(m.weight, std=0.001)
                    if self.deconv_with_bias:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    # print('=> init {}.weight as 1'.format(name))
                    # print('=> init {}.bias as 0'.format(name))
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
            # print('=> init final conv weights from normal distribution')
            for head in self.heads:
                final_layer = self.__getattr__(head)
                for i, m in enumerate(final_layer.modules()):
                    if isinstance(m, nn.Conv2d):
                        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                        # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                        # print('=> init {}.bias as 0'.format(name))
                        if m.weight.shape[0] == self.heads[head]:
                            if 'hm' in head:
                                nn.init.constant_(m.bias, -2.19)
                            else:
                                nn.init.normal_(m.weight, std=0.001)
                                nn.init.constant_(m.bias, 0)
            # pretrained_state_dict = torch.load(pretrained)
            url = model_urls['resnet{}'.format(num_layers)]
            pretrained_state_dict = model_zoo.load_url(url)
            print('=> loading pretrained model {}'.format(url))
            self.load_state_dict(pretrained_state_dict, strict=False)


resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
               34: (BasicBlock, [3, 4, 6, 3]),
               50: (Bottleneck, [3, 4, 6, 3]),
               101: (Bottleneck, [3, 4, 23, 3]),
               152: (Bottleneck, [3, 8, 36, 3])}


def get_pose_net(num_layers, heads, head_conv, imagenet_pretrained):
    block_class, layers = resnet_spec[num_layers]

    model = PoseResNet(block_class, layers, heads, head_conv=head_conv)
    model.init_weights(num_layers, pretrained=imagenet_pretrained)
    return model
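
For reference, a minimal usage sketch (not part of the original file): the head names and channel counts below are hypothetical, but any head whose name contains 'hm' receives the -2.19 bias initialization from init_weights, matching the code above. Pretrained loading is disabled here to avoid downloading weights.

if __name__ == '__main__':
    # Hypothetical heads: a 1-channel 'hm' heatmap and a 2-channel offset map.
    heads = {'hm': 1, 'offset': 2}
    model = get_pose_net(num_layers=18, heads=heads, head_conv=64,
                         imagenet_pretrained=False)

    # The backbone downsamples by 32 and the three stride-2 deconv layers
    # upsample by 8, so a 224x224 input yields 56x56 head maps (stride 4).
    dummy = torch.randn(1, 3, 224, 224)
    outputs = model(dummy)
    for name, out in outputs.items():
        print(name, tuple(out.shape))  # e.g. hm (1, 1, 56, 56)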
