# This file is downloaded from
# https://github.com/rwightman/pytorch-image-models
import torch.nn as nn

from timm.models.layers import create_conv2d
from timm.models.efficientnet_blocks import make_divisible, resolve_se_args, \
    SqueezeExcite, drop_path


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing"""

    def __init__(
            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1,
            pad_type='', act_layer=nn.ReLU, noskip=False, exp_ratio=1.0,
            exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_kwargs=None,
            norm_layer=nn.BatchNorm2d, norm_kwargs=None, conv_kwargs=None,
            drop_path_rate=0.):
        super(InvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(
            in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if has_se:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PWL
            info = dict(
                module='conv_pwl', hook_type='forward_pre',
                num_chs=self.conv_pwl.in_channels)
        else:  # location == 'bottleneck', block output
            info = dict(
                module='', hook_type='',
                num_chs=self.conv_pwl.out_channels)
        return info

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x
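

# --- Usage sketch (not part of the upstream file) ---
# A minimal example of instantiating the block, assuming a timm version
# where the helpers imported above still live at those paths (e.g. the
# 0.4.x series; later releases moved or renamed resolve_se_args and
# changed the SqueezeExcite signature). Channel counts and input size
# below are illustrative only.
if __name__ == '__main__':
    import torch

    # 32 -> 32 channels with stride 1 keeps the residual shortcut active;
    # exp_ratio=6.0 expands to 192 mid channels, se_ratio=0.25 enables SE,
    # and drop_path_rate=0.1 applies stochastic depth during training.
    block = InvertedResidual(
        in_chs=32, out_chs=32, dw_kernel_size=3, stride=1,
        exp_ratio=6.0, se_ratio=0.25, drop_path_rate=0.1)
    x = torch.randn(2, 32, 56, 56)
    out = block(x)
    print(out.shape)  # torch.Size([2, 32, 56, 56]): shape preserved, residual added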