You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

build_childnet.py 7.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from ...utils.util import *
  2. from collections import OrderedDict
  3. from timm.models.efficientnet_blocks import *
  4. class ChildNetBuilder:
  5. def __init__(
  6. self,
  7. channel_multiplier=1.0,
  8. channel_divisor=8,
  9. channel_min=None,
  10. output_stride=32,
  11. pad_type='',
  12. act_layer=None,
  13. se_kwargs=None,
  14. norm_layer=nn.BatchNorm2d,
  15. norm_kwargs=None,
  16. drop_path_rate=0.,
  17. feature_location='',
  18. verbose=False,
  19. logger=None):
  20. self.channel_multiplier = channel_multiplier
  21. self.channel_divisor = channel_divisor
  22. self.channel_min = channel_min
  23. self.output_stride = output_stride
  24. self.pad_type = pad_type
  25. self.act_layer = act_layer
  26. self.se_kwargs = se_kwargs
  27. self.norm_layer = norm_layer
  28. self.norm_kwargs = norm_kwargs
  29. self.drop_path_rate = drop_path_rate
  30. self.feature_location = feature_location
  31. assert feature_location in ('pre_pwl', 'post_exp', '')
  32. self.verbose = verbose
  33. self.in_chs = None
  34. self.features = OrderedDict()
  35. self.logger = logger
  36. def _round_channels(self, chs):
  37. return round_channels(
  38. chs,
  39. self.channel_multiplier,
  40. self.channel_divisor,
  41. self.channel_min)
  42. def _make_block(self, ba, block_idx, block_count):
  43. drop_path_rate = self.drop_path_rate * block_idx / block_count
  44. bt = ba.pop('block_type')
  45. ba['in_chs'] = self.in_chs
  46. ba['out_chs'] = self._round_channels(ba['out_chs'])
  47. if 'fake_in_chs' in ba and ba['fake_in_chs']:
  48. ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
  49. ba['norm_layer'] = self.norm_layer
  50. ba['norm_kwargs'] = self.norm_kwargs
  51. ba['pad_type'] = self.pad_type
  52. # block act fn overrides the model default
  53. ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
  54. assert ba['act_layer'] is not None
  55. if bt == 'ir':
  56. ba['drop_path_rate'] = drop_path_rate
  57. ba['se_kwargs'] = self.se_kwargs
  58. if self.verbose:
  59. self.logger.info(
  60. ' InvertedResidual {}, Args: {}'.format(
  61. block_idx, str(ba)))
  62. block = InvertedResidual(**ba)
  63. elif bt == 'ds' or bt == 'dsa':
  64. ba['drop_path_rate'] = drop_path_rate
  65. ba['se_kwargs'] = self.se_kwargs
  66. if self.verbose:
  67. self.logger.info(
  68. ' DepthwiseSeparable {}, Args: {}'.format(
  69. block_idx, str(ba)))
  70. block = DepthwiseSeparableConv(**ba)
  71. elif bt == 'cn':
  72. if self.verbose:
  73. self.logger.info(
  74. ' ConvBnAct {}, Args: {}'.format(
  75. block_idx, str(ba)))
  76. block = ConvBnAct(**ba)
  77. else:
  78. assert False, 'Uknkown block type (%s) while building model.' % bt
  79. self.in_chs = ba['out_chs'] # update in_chs for arg of next block
  80. return block
  81. def __call__(self, in_chs, model_block_args):
  82. """ Build the blocks
  83. Args:
  84. in_chs: Number of input-channels passed to first block
  85. model_block_args: A list of lists, outer list defines stages, inner
  86. list contains strings defining block configuration(s)
  87. Return:
  88. List of block stacks (each stack wrapped in nn.Sequential)
  89. """
  90. if self.verbose:
  91. self.logger.info(
  92. 'Building model trunk with %d stages...' %
  93. len(model_block_args))
  94. self.in_chs = in_chs
  95. total_block_count = sum([len(x) for x in model_block_args])
  96. total_block_idx = 0
  97. current_stride = 2
  98. current_dilation = 1
  99. feature_idx = 0
  100. stages = []
  101. # outer list of block_args defines the stacks ('stages' by some
  102. # conventions)
  103. for stage_idx, stage_block_args in enumerate(model_block_args):
  104. last_stack = stage_idx == (len(model_block_args) - 1)
  105. if self.verbose:
  106. self.logger.info('Stack: {}'.format(stage_idx))
  107. assert isinstance(stage_block_args, list)
  108. blocks = []
  109. # each stack (stage) contains a list of block arguments
  110. for block_idx, block_args in enumerate(stage_block_args):
  111. last_block = block_idx == (len(stage_block_args) - 1)
  112. extract_features = '' # No features extracted
  113. if self.verbose:
  114. self.logger.info(' Block: {}'.format(block_idx))
  115. # Sort out stride, dilation, and feature extraction details
  116. assert block_args['stride'] in (1, 2)
  117. if block_idx >= 1:
  118. # only the first block in any stack can have a stride > 1
  119. block_args['stride'] = 1
  120. do_extract = False
  121. if self.feature_location == 'pre_pwl':
  122. if last_block:
  123. next_stage_idx = stage_idx + 1
  124. if next_stage_idx >= len(model_block_args):
  125. do_extract = True
  126. else:
  127. do_extract = model_block_args[next_stage_idx][0]['stride'] > 1
  128. elif self.feature_location == 'post_exp':
  129. if block_args['stride'] > 1 or (last_stack and last_block):
  130. do_extract = True
  131. if do_extract:
  132. extract_features = self.feature_location
  133. next_dilation = current_dilation
  134. if block_args['stride'] > 1:
  135. next_output_stride = current_stride * block_args['stride']
  136. if next_output_stride > self.output_stride:
  137. next_dilation = current_dilation * block_args['stride']
  138. block_args['stride'] = 1
  139. if self.verbose:
  140. self.logger.info(
  141. ' Converting stride to dilation to maintain output_stride=={}'.format(
  142. self.output_stride))
  143. else:
  144. current_stride = next_output_stride
  145. block_args['dilation'] = current_dilation
  146. if next_dilation != current_dilation:
  147. current_dilation = next_dilation
  148. # create the block
  149. block = self._make_block(
  150. block_args, total_block_idx, total_block_count)
  151. blocks.append(block)
  152. # stash feature module name and channel info for model feature
  153. # extraction
  154. if extract_features:
  155. feature_module = block.feature_module(extract_features)
  156. if feature_module:
  157. feature_module = 'blocks.{}.{}.'.format(
  158. stage_idx, block_idx) + feature_module
  159. feature_channels = block.feature_channels(extract_features)
  160. self.features[feature_idx] = dict(
  161. name=feature_module,
  162. num_chs=feature_channels
  163. )
  164. feature_idx += 1
  165. # incr global block idx (across all stacks)
  166. total_block_idx += 1
  167. stages.append(nn.Sequential(*blocks))
  168. return stages

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能