
supernet.py 6.6 kB

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com

import torch.nn as nn
import torch.nn.functional as F

from ...utils.builder_util import *
from ...utils.search_structure_supernet import *
from ...utils.op_by_layer_dict import flops_op_dict
from ..builders.build_supernet import *

from timm.models.layers import SelectAdaptivePool2d
from timm.models.layers.activations import hard_sigmoid, Swish


class SuperNet(nn.Module):

    def __init__(
            self,
            block_args,
            choices,
            num_classes=1000,
            in_chans=3,
            stem_size=16,
            num_features=1280,
            head_bias=True,
            channel_multiplier=1.0,
            pad_type='',
            act_layer=nn.ReLU,
            drop_rate=0.,
            drop_path_rate=0.,
            slice=4,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            logger=None,
            norm_kwargs=None,
            global_pool='avg',
            resunit=False,
            dil_conv=False,
            verbose=False):
        super(SuperNet, self).__init__()

        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate
        self._in_chs = in_chans
        self.logger = logger

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = create_conv2d(
            self._in_chs, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        self._in_chs = stem_size

        # Middle stages (IR/ER/DS Blocks)
        builder = SuperNetBuilder(
            choices,
            channel_multiplier,
            8,
            None,
            32,
            pad_type,
            act_layer,
            se_kwargs,
            norm_layer,
            norm_kwargs,
            drop_path_rate,
            verbose=verbose,
            resunit=resunit,
            dil_conv=dil_conv,
            logger=self.logger)
        blocks = builder(self._in_chs, block_args)
        self.blocks = nn.Sequential(*blocks)
        self._in_chs = builder.in_chs

        # Head + Pooling
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = create_conv2d(
            self._in_chs,
            self.num_features,
            1,
            padding=pad_type,
            bias=head_bias)
        self.act2 = act_layer(inplace=True)

        # Classifier
        self.classifier = nn.Linear(
            self.num_features * self.global_pool.feat_mult(),
            self.num_classes)

        # Meta layer: maps `slice` concatenated class-logit vectors to a
        # single matching score (see forward_meta below).
        self.meta_layer = nn.Linear(self.num_classes * slice, 1)

        efficientnet_init_weights(self)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        self.classifier = nn.Linear(
            self.num_features * self.global_pool.feat_mult(),
            num_classes) if self.num_classes else None

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.classifier(x)

    def forward_meta(self, features):
        return self.meta_layer(features.view(1, -1))

    def rand_parameters(self, architecture, meta=False):
        # Yield the stem/head parameters (or, with meta=True, only the
        # meta-layer parameters), then the parameters of the candidate op
        # selected by `architecture` in every searchable layer.
        for name, param in self.named_parameters(recurse=True):
            if 'meta' in name and meta:
                yield param
            elif 'blocks' not in name and 'meta' not in name and (not meta):
                yield param

        if not meta:
            for layer, layer_arch in zip(self.blocks, architecture):
                for blocks, arch in zip(layer, layer_arch):
                    if arch == -1:
                        continue
                    for name, param in blocks[arch].named_parameters(
                            recurse=True):
                        yield param


class Classifier(nn.Module):
    def __init__(self, num_classes=1000):
        super(Classifier, self).__init__()
        self.classifier = nn.Linear(num_classes, num_classes)

    def forward(self, x):
        return self.classifier(x)


def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs):
    choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}

    num_features = 1280

    # act_layer = HardSwish
    act_layer = Swish

    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        # stage 1, 112x112 in
        ['ir_r1_k3_s2_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25',
         'ir_r1_k3_s1_e4_c24_se0.25', 'ir_r1_k3_s1_e4_c24_se0.25'],
        # stage 2, 56x56 in
        ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s1_e4_c40_se0.25',
         'ir_r1_k5_s2_e4_c40_se0.25', 'ir_r1_k5_s2_e4_c40_se0.25'],
        # stage 3, 28x28 in
        ['ir_r1_k3_s2_e6_c80_se0.25', 'ir_r1_k3_s1_e4_c80_se0.25',
         'ir_r1_k3_s1_e4_c80_se0.25', 'ir_r2_k3_s1_e4_c80_se0.25'],
        # stage 4, 14x14 in
        ['ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25',
         'ir_r1_k3_s1_e6_c96_se0.25', 'ir_r1_k3_s1_e6_c96_se0.25'],
        # stage 5, 14x14 in
        ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
         'ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s2_e6_c192_se0.25'],
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c320_se0.25'],
    ]

    sta_num, arch_def, resolution = search_for_layer(
        flops_op_dict, arch_def, flops_minimum, flops_maximum)
    if sta_num is None or arch_def is None or resolution is None:
        raise ValueError('Invalid FLOPs Settings')

    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        choices=choices,
        num_features=num_features,
        stem_size=16,
        norm_kwargs=resolve_bn_args(kwargs),
        act_layer=act_layer,
        se_kwargs=dict(
            act_layer=nn.ReLU,
            gate_fn=hard_sigmoid,
            reduce_mid=True,
            divisor=8),
        **kwargs,
    )
    model = SuperNet(**model_kwargs)
    return model, sta_num, resolution, arch_def
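
Below is a minimal usage sketch, not part of supernet.py. It shows how the returned supernet and `rand_parameters` fit together when training one randomly sampled path. The six-way choice count per layer is inferred from `choices` above (3 kernel sizes x 2 expansion ratios), and reading `sta_num` as the number of searchable layers per stage is an assumption; treat this as illustrative rather than the repository's actual training loop.

import random
import torch

# Build the supernet for a 0-600 MFLOPs budget.
model, sta_num, resolution, arch_def = gen_supernet(
    flops_minimum=0, flops_maximum=600, num_classes=1000)

# Sample one candidate op per searchable layer: an index in [0, 6),
# grouped by stage (assumption: sta_num holds per-stage layer counts).
architecture = [[random.randrange(6) for _ in range(n)] for n in sta_num]

# Optimize only the stem/head parameters plus the ops on the sampled
# path; rand_parameters() yields exactly that subset.
optimizer = torch.optim.SGD(
    model.rand_parameters(architecture), lr=0.1, momentum=0.9)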
