|
from copy import deepcopy

import torch.nn as nn
from timm.models.efficientnet_blocks import *

from pytorch.mutables import LayerChoice

from ...utils.builder_util import modify_block_args
from ..blocks import get_Bottleneck, InvertedResidual
-
class SuperNetBuilder:
    """Build the trunk blocks of a supernet.

    For every block position described by ``model_block_args`` this builder
    instantiates all candidate operations -- one per entry of the
    kernel-size x expansion-ratio search space, plus optional dilated-conv
    and Bottleneck candidates -- and wraps them in a ``LayerChoice`` so a
    NAS algorithm can select among them.
    """

    def __init__(
            self,
            choices,
            channel_multiplier=1.0,
            channel_divisor=8,
            channel_min=None,
            output_stride=32,
            pad_type='',
            act_layer=None,
            se_kwargs=None,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            drop_path_rate=0.,
            feature_location='',
            verbose=False,
            resunit=False,
            dil_conv=False,
            logger=None):
        """
        Args:
            choices: Search-space dict, e.g.
                ``{'kernel_size': [3, 5, 7], 'exp_ratio': [4, 6]}``; the
                cartesian product of the two lists defines the candidate
                (kernel, expansion) pairs per block position.
            channel_multiplier: Multiplier forwarded to ``round_channels``.
            channel_divisor: Divisor forwarded to ``round_channels``.
            channel_min: Minimum channels forwarded to ``round_channels``.
            output_stride: Maximum total stride of the trunk; further
                down-sampling is converted into dilation instead.
            pad_type: Padding mode passed to every block.
            act_layer: Default activation layer; a block's own
                ``act_layer`` entry takes precedence when not None.
            se_kwargs: Squeeze-excite kwargs for 'ir'/'ds'/'dsa' blocks.
            norm_layer: Normalization layer class for every block.
            norm_kwargs: Kwargs for ``norm_layer``.
            drop_path_rate: Maximum stochastic-depth rate; scaled linearly
                with global block index in :meth:`_make_block`.
            feature_location: One of ``('pre_pwl', 'post_exp', '')``.
            verbose: If True, log each block's arguments via ``logger``.
            resunit: If True, add an extra Bottleneck candidate per position.
            dil_conv: If True, add two dilated-conv candidates per position.
            logger: Logger used when ``verbose`` is set.

        Raises:
            ValueError: If ``feature_location`` is not a recognized value.
        """
        # Cartesian product: one [kernel_size, exp_ratio] pair per candidate.
        self.choices = [[k, e] for k in choices['kernel_size']
                        for e in choices['exp_ratio']]
        # NOTE(review): kept for backward compatibility -- this off-by-one
        # (len - 1) value is never read inside this class; ``choice_num``
        # below is the attribute actually consumed by _make_block.
        self.choices_num = len(self.choices) - 1
        # Candidates per position; refined per stage inside __call__.
        # Initialized here so _make_block works even before __call__ runs
        # (previously this attribute only existed after __call__).
        self.choice_num = len(self.choices)
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_path_rate = drop_path_rate
        self.feature_location = feature_location
        # Input validation: raise rather than assert (asserts vanish under -O).
        if feature_location not in ('pre_pwl', 'post_exp', ''):
            raise ValueError(
                "feature_location must be one of ('pre_pwl', 'post_exp', ''),"
                ' got %r' % (feature_location,))
        self.verbose = verbose
        self.resunit = resunit
        self.dil_conv = dil_conv
        self.logger = logger

        # Running input-channel count: set by __call__, advanced by
        # _make_block after the last candidate of each position.
        self.in_chs = None

    def _round_channels(self, chs):
        """Round ``chs`` according to the model's channel constraints."""
        return round_channels(
            chs,
            self.channel_multiplier,
            self.channel_divisor,
            self.channel_min)

    def _make_block(
            self,
            ba,
            choice_idx,
            block_idx,
            block_count,
            resunit=False,
            dil_conv=False):
        """Instantiate one candidate block from block-args dict ``ba``.

        ``ba`` is consumed in place (``block_type`` is popped, shared model
        settings are filled in).  ``self.in_chs`` is advanced only after the
        LAST candidate of a position (``choice_idx == self.choice_num - 1``)
        so every candidate sees the same input width.

        Args:
            ba: Block-args dict (mutated in place).
            choice_idx: Index of this candidate within the current position.
            block_idx: Global block index; scales the drop-path rate.
            block_count: Total number of blocks in the model.
            resunit: Accepted for interface compatibility; unused here.
            dil_conv: Accepted for interface compatibility; unused here.

        Returns:
            The instantiated block module.

        Raises:
            ValueError: If ``ba['block_type']`` is not 'ir', 'ds', 'dsa'
                or 'cn'.
        """
        # Stochastic-depth rate scales linearly with depth.
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if ba.get('fake_in_chs'):
            # FIXME this is a hack to work around mismatch in origin impl
            # input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # A block-specific act fn overrides the model default.
        if ba['act_layer'] is None:
            ba['act_layer'] = self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            if self.verbose:
                self.logger.info(
                    ' InvertedResidual {}, Args: {}'.format(
                        block_idx, str(ba)))
            block = InvertedResidual(**ba)
        elif bt in ('ds', 'dsa'):
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            if self.verbose:
                self.logger.info(
                    ' DepthwiseSeparable {}, Args: {}'.format(
                        block_idx, str(ba)))
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'cn':
            if self.verbose:
                self.logger.info(
                    ' ConvBnAct {}, Args: {}'.format(
                        block_idx, str(ba)))
            block = ConvBnAct(**ba)
        else:
            # fix: was `assert False` with a typo ('Uknkown'); raise a real
            # exception so the error survives `python -O`.
            raise ValueError(
                'Unknown block type (%s) while building model.' % bt)
        if choice_idx == self.choice_num - 1:
            # Advance in_chs only once per position, for the next block.
            self.in_chs = ba['out_chs']

        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains block-argument dicts, one per block
        Return:
            List of LayerChoice modules, one per block position, each holding
            every candidate block for that position.
        """
        if self.verbose:
            self.logger.info('Building model trunk with %d stages...' % len(model_block_args))
        self.in_chs = in_chs
        total_block_count = sum(len(x) for x in model_block_args)
        total_block_idx = 0
        current_stride = 2
        current_dilation = 1
        stages = []
        # The outer list of block_args defines the stacks ('stages').
        for stage_idx, stage_block_args in enumerate(model_block_args):
            if self.verbose:
                self.logger.info('Stack: {}'.format(stage_idx))
            assert isinstance(stage_block_args, list)

            # Each stack (stage) contains a list of block arguments.
            for block_idx, block_args in enumerate(stage_block_args):
                if self.verbose:
                    self.logger.info(' Block: {}'.format(block_idx))

                # Sort out stride and dilation details.
                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:
                    # Only the first block in any stack may have stride > 1.
                    block_args['stride'] = 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        # Would exceed the requested output stride: keep the
                        # resolution and dilate instead of down-sampling.
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                    else:
                        current_stride = next_output_stride
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # First and last stages are fixed (a single candidate);
                # every other stage searches the full choice space.
                if stage_idx == 0 or stage_idx == 6:
                    self.choice_num = 1
                else:
                    self.choice_num = len(self.choices)

                if self.dil_conv:
                    # Two extra dilated-conv candidates are appended below.
                    self.choice_num += 2

                choice_blocks = []
                block_args_copy = deepcopy(block_args)
                if self.choice_num == 1:
                    # Fixed position: build the single block directly.
                    block = self._make_block(
                        block_args, 0, total_block_idx, total_block_count)
                    choice_blocks.append(block)
                else:
                    for choice_idx, choice in enumerate(self.choices):
                        # Fresh copy each time: _make_block mutates its arg.
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(
                            block_args, choice[0], choice[1])
                        block = self._make_block(
                            block_args, choice_idx, total_block_idx,
                            total_block_count)
                        choice_blocks.append(block)
                    if self.dil_conv:
                        # Dilated 3x3 candidate.
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 3, 0)
                        block = self._make_block(
                            block_args, self.choice_num - 2, total_block_idx,
                            total_block_count, resunit=self.resunit,
                            dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                        # Dilated 5x5 candidate.
                        block_args = deepcopy(block_args_copy)
                        block_args = modify_block_args(block_args, 5, 0)
                        block = self._make_block(
                            block_args, self.choice_num - 1, total_block_idx,
                            total_block_count, resunit=self.resunit,
                            dil_conv=self.dil_conv)
                        choice_blocks.append(block)

                    if self.resunit:
                        # Extra Bottleneck candidate mirroring the previous
                        # block's channel configuration.
                        block = get_Bottleneck(block.conv_pw.in_channels,
                                               block.conv_pwl.out_channels,
                                               block.conv_dw.stride[0])
                        choice_blocks.append(block)

                stages.append(LayerChoice(choice_blocks))
                total_block_idx += 1  # incr global block idx (across stacks)

        return stages
|