
network.py
import sys
sys.path.append("../../")
import os
import re
import pickle

import torch
import torch.nn as nn

from pytorch.mutables import LayerChoice
from blocks import ShuffleNetBlock, ShuffleXceptionBlock

PARSED_FLOPS = {'LayerChoice1': [13396992, 15805440, 19418112, 13146112],
                'LayerChoice2': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice3': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice4': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice5': [26304768, 28111104, 30820608, 20296192],
                'LayerChoice6': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice7': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice8': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice9': [30670080, 31673600, 33178880, 21199360],
                'LayerChoice10': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice11': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice12': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice13': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice14': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice15': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice16': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice17': [30387840, 30889600, 31642240, 20634880],
                'LayerChoice18': [10176320, 10427200, 10803520, 15476160],
                'LayerChoice19': [10176320, 10427200, 10803520, 15476160],
                'LayerChoice20': [10176320, 10427200, 10803520, 15476160]}
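
# Note: PARSED_FLOPS above lists the FLOPs of the four candidate ops for each of the
# 20 LayerChoice positions, presumably in `block_keys` order (shufflenet 3x3 / 5x5 /
# 7x7, then xception 3x3). It is kept as a precomputed reference table and is not read
# anywhere else in this file.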


class ShuffleNetV2OneShot(nn.Module):
    block_keys = [
        'shufflenet_3x3',
        'shufflenet_5x5',
        'shufflenet_7x7',
        'xception_3x3',
    ]

    def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000,
                 op_flops_path="./data/op_flops_dict.pkl", affine=False):
        super().__init__()

        assert input_size % 32 == 0
        with open(os.path.join(os.path.dirname(__file__), op_flops_path), "rb") as fp:
            self._op_flops_dict = pickle.load(fp)

        self.stage_blocks = [4, 4, 8, 4]
        self.stage_channels = [64, 160, 320, 640]
        self._parsed_flops = dict()
        self._input_size = input_size
        self._feature_map_size = input_size
        self._first_conv_channels = first_conv_channels
        self._last_conv_channels = last_conv_channels
        self._n_classes = n_classes
        self._affine = affine

        # building first layer
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(first_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self._feature_map_size //= 2

        p_channels = first_conv_channels
        features = []
        for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
            features.extend(self._make_blocks(num_blocks, p_channels, channels))
            p_channels = channels
        self.features = nn.Sequential(*features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(last_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self.globalpool = nn.AvgPool2d(self._feature_map_size)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(last_conv_channels, n_classes, bias=False),
        )
        self._initialize_weights()

    def _make_blocks(self, blocks, in_channels, channels):
        result = []
        for i in range(blocks):
            stride = 2 if i == 0 else 1
            inp = in_channels if i == 0 else channels
            oup = channels

            base_mid_channels = channels // 2
            mid_channels = int(base_mid_channels)  # prepare for scale
            choice_block = LayerChoice([
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
            ])
            result.append(choice_block)

            # find the corresponding flops
            flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride)
            self._parsed_flops[choice_block.key] = [
                self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key] for k in self.block_keys
            ]
            if stride == 2:
                self._feature_map_size //= 2

        # ##### mended by han ###################
        # Every choice_block created through mutables.LayerChoice draws its key from a
        # global counter that increments by 1 each time, so the key numbers stored in
        # self._parsed_flops can grow beyond 20, and such keys do not exist in this model.
        # Because all algorithms share one mutable implementation, we do not call or modify
        # global_mutable_counting()
        # _reset_global_mutable_counting()
        # inside it; instead, the keys of self._parsed_flops are renamed here.
        _d = dict()
        for key, value in self._parsed_flops.items():
            _head = key[:11]  # "LayerChoice"
            _index = int(key[11:]) % 20  # modulo 20: there are 20 choice blocks in total, so the index must fall within 1-20
            if _index == 0:
                _index = 20  # an index that is 0 modulo 20 should in fact be 20
            _d.update({_head + str(_index): value})
        self._parsed_flops = _d
        # #######################################
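        # Worked example of the remapping above (hypothetical key values, for
        # illustration only): a key such as 'LayerChoice41' is renamed to
        # 'LayerChoice1', and 'LayerChoice60' is renamed to 'LayerChoice20'.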
        return result

    def forward(self, x):
        bs = x.size(0)
        x = self.first_conv(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)

        x = self.dropout(x)
        x = x.contiguous().view(bs, -1)
        x = self.classifier(x)
        return x

    def get_candidate_flops(self, candidate):
        conv1_flops = self._op_flops_dict["conv1"][(3, self._first_conv_channels,
                                                    self._input_size, self._input_size, 2)]
        # Should use `last_conv_channels` here, but megvii insists that it's `n_classes`. Keeping it.
        # https://github.com/megvii-model/SinglePathOneShot/blob/36eed6cf083497ffa9cfe7b8da25bb0b6ba5a452/src/Supernet/flops.py#L313
        rest_flops = self._op_flops_dict["rest_operation"][(self.stage_channels[-1], self._n_classes,
                                                            self._feature_map_size, self._feature_map_size, 1)]
        total_flops = conv1_flops + rest_flops
        for k, m in candidate.items():
            parsed_flops_dict = self._parsed_flops[k]
            if isinstance(m, dict):  # to be compatible with classical nas format
                total_flops += parsed_flops_dict[m["_idx"]]
            else:
                total_flops += parsed_flops_dict[torch.max(m, 0)[1]]
        return total_flops
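        # A minimal usage sketch for get_candidate_flops (hedged: the candidate dict
        # below is a hypothetical example, not produced by any particular tuner).
        # Each LayerChoice key maps either to a classical-NAS style dict,
        #     candidate = {key: {"_idx": 0} for key in model._parsed_flops}
        # or to a one-hot tensor over the four candidate ops,
        #     candidate = {key: torch.tensor([1., 0., 0., 0.]) for key in model._parsed_flops}
        # and model.get_candidate_flops(candidate) sums the selected op's precomputed
        # FLOPs per position plus the fixed conv1 and "rest_operation" costs.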

    def _initialize_weights(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
    checkpoint = torch.load(filepath, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    result = dict()
    for k, v in checkpoint.items():
        if k.startswith("module."):
            k = k[len("module."):]
        result[k] = v
    return result
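
# Typical usage (a sketch; assumes the checkpoint at the default path was saved for this
# supernet, so its keys match the module names once the "module." prefix is stripped):
#     model.load_state_dict(load_and_parse_state_dict())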


if __name__ == "__main__":
    model = ShuffleNetV2OneShot()
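    # Hedged smoke-test sketch. Whether a plain forward pass works here depends on the
    # local `pytorch.mutables` LayerChoice (some implementations require a mutator to be
    # applied first), so the forward call is left as a comment rather than executed:
    #     x = torch.randn(1, 3, 224, 224)
    #     print(model(x).shape)  # expected torch.Size([1, 1000]) once choices are resolved
    # FLOPs lookup for an all-first-choice candidate (hypothetical selection, built only
    # from this file's own _parsed_flops keys):
    print(model.get_candidate_flops({key: {"_idx": 0} for key in model._parsed_flops}))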
