
build_supernet.py
from copy import deepcopy

import torch.nn as nn
# wildcard import supplies round_channels, DepthwiseSeparableConv and
# ConvBnAct (as in the older timm versions this code targets)
from timm.models.efficientnet_blocks import *

from ...utils.builder_util import modify_block_args
# imported after the timm wildcard so the local InvertedResidual is not shadowed
from ..blocks import get_Bottleneck, InvertedResidual
# LayerChoice mutable from NNI's classic NAS API (assumed import path)
from nni.nas.pytorch.mutables import LayerChoice
class SuperNetBuilder:
    """Build the trunk blocks of a supernet."""

    def __init__(
            self,
            choices,
            channel_multiplier=1.0,
            channel_divisor=8,
            channel_min=None,
            output_stride=32,
            pad_type='',
            act_layer=None,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            drop_path_rate=0.,
            feature_location='',
            verbose=False,
            resunit=False,
            dil_conv=False,
            logger=None):
        # expand the choice dict, e.g.
        #   choices = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}
        # into its cartesian product [[3, 4], [3, 6], [5, 4], ...]
        self.choices = [[x, y] for x in choices['kernel_size']
                        for y in choices['exp_ratio']]
        # candidate ops per searchable layer; recomputed per stage in __call__
        self.choice_num = len(self.choices)
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_path_rate = drop_path_rate
        self.feature_location = feature_location
        assert feature_location in ('pre_pwl', 'post_exp', '')
        self.verbose = verbose
        self.resunit = resunit
        self.dil_conv = dil_conv
        self.logger = logger

        # state updated during build, consumed by model
        self.in_chs = None
    def _round_channels(self, chs):
        return round_channels(
            chs,
            self.channel_multiplier,
            self.channel_divisor,
            self.channel_min)
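    # Worked example (illustrative values): with channel_multiplier=1.2 and
    # channel_divisor=8, _round_channels(24) scales 24 to 28.8 and rounds it
    # up to 32, the nearest multiple of 8 within ~90% of the scaled target.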
    def _make_block(
            self,
            ba,
            choice_idx,
            block_idx,
            block_count,
            resunit=False,
            dil_conv=False):
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # FIXME this is a hack to work around mismatch in origin impl
            # input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            if self.verbose:
                self.logger.info(
                    '  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
            block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            if self.verbose:
                self.logger.info(
                    '  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'cn':
            if self.verbose:
                self.logger.info(
                    '  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
            block = ConvBnAct(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        if choice_idx == self.choice_num - 1:
            self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block
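    # Illustrative only: a typical 'ir' entry in model_block_args, as produced
    # by a timm-style arch decoder, might look like
    #   {'block_type': 'ir', 'dw_kernel_size': 3, 'exp_ratio': 4, 'stride': 2,
    #    'out_chs': 24, 'se_ratio': 0.25, 'act_layer': None}
    # (key names assumed from timm, not taken from this file); _make_block pops
    # 'block_type' and fills in the in_chs/norm/pad/act fields before
    # instantiating the candidate op.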
    def __call__(self, in_chs, model_block_args):
        """Build the blocks.

        Args:
            in_chs: Number of input channels passed to the first block.
            model_block_args: A list of lists; the outer list defines stages,
                each inner list holds the argument dicts for that stage's blocks.

        Returns:
            A list of LayerChoice modules, one per block position, each holding
            the candidate operations for that position.
        """
        if self.verbose:
            self.logger.info('Building model trunk with %d stages...' % len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        current_stride = 2
        current_dilation = 1
        stages = []

        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stage_idx, stage_block_args in enumerate(model_block_args):
            if self.verbose:
                self.logger.info('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            # each stack (stage) contains a list of block arguments
            for block_idx, block_args in enumerate(stage_block_args):
                if self.verbose:
                    self.logger.info(' Block: {}'.format(block_idx))

                # sort out stride and dilation details
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1
                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # trade the stride for dilation to cap the output stride
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                    else:
                        current_stride = next_output_stride
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # stages 0 and 6 are kept fixed (single candidate);
                # all other stages are searchable
                if stage_idx == 0 or stage_idx == 6:
                    self.choice_num = 1
                else:
                    self.choice_num = len(self.choices)
                    if self.dil_conv:
                        self.choice_num += 2

                choice_blocks = []
                block_args_copy = deepcopy(block_args)
                if self.choice_num == 1:
                    # fixed stage: create the single block
                    block = self._make_block(block_args, 0, total_block_idx, total_block_count)
                    choice_blocks.append(block)
                else:
                    # one candidate block per (kernel_size, exp_ratio) choice
                    for choice_idx, choice in enumerate(self.choices):
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, choice[0], choice[1])
                        block = self._make_block(block_args, choice_idx, total_block_idx, total_block_count)
                        choice_blocks.append(block)
                    if self.dil_conv:
                        # two extra dilated-conv candidates (kernel sizes 3 and 5)
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 3, 0)
                        block = self._make_block(block_args, self.choice_num - 2, total_block_idx, total_block_count,
                                                 resunit=self.resunit, dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 5, 0)
                        block = self._make_block(block_args, self.choice_num - 1, total_block_idx, total_block_count,
                                                 resunit=self.resunit, dil_conv=self.dil_conv)
                        choice_blocks.append(block)
                    if self.resunit:
                        # optional bottleneck residual candidate
                        block = get_Bottleneck(block.conv_pw.in_channels,
                                               block.conv_pwl.out_channels,
                                               block.conv_dw.stride[0])
                        choice_blocks.append(block)

                choice_block = LayerChoice(choice_blocks)
                stages.append(choice_block)
                total_block_idx += 1  # incr global block idx (across all stacks)

        return stages
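
The builder is not exercised anywhere in this file, so the following is a minimal usage sketch. Everything beyond SuperNetBuilder's own signature (the logger setup, the argument values, and the commented-out call) is illustrative, and model_block_args is assumed to come from a timm-style architecture decoder.

# Minimal usage sketch (illustrative values, not from the original repository).
import logging

import torch.nn as nn

search_space = {'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}

builder = SuperNetBuilder(
    choices=search_space,
    channel_multiplier=1.0,
    act_layer=nn.ReLU,        # required: _make_block asserts act_layer is not None
    norm_layer=nn.BatchNorm2d,
    drop_path_rate=0.1,
    logger=logging.getLogger('supernet'))

# model_block_args is a list of stages, each a list of per-block arg dicts,
# typically produced by a timm-style decode_arch_def(); schematic call:
# blocks = builder(in_chs=16, model_block_args=model_block_args)
# 'blocks' is a list of LayerChoice modules, one per block position, ready to
# be assembled into a supernet module and searched over.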
