import torch
import torch.nn as nn


class ShuffleNetBlock(nn.Module):
    """Two-branch convolutional block with channel shuffle.

    The block always consumes an input with ``inp`` channels and produces an
    output with ``oup`` channels:

    * ``stride == 1``: the input is split by ``_channel_shuffle`` into two
      halves of ``inp // 2`` channels. One half is passed through untouched,
      the other goes through ``branch_main``; the results are concatenated.
      The input channel count must be divisible by 4.
    * ``stride == 2``: the whole input is fed to both ``branch_proj`` and
      ``branch_main`` (each down-samples spatially by 2) and the results are
      concatenated.

    Args:
        inp: Number of channels of the input tensor.
        oup: Number of channels of the output tensor.
        mid_channels: Output channels of the first point-wise conv in the
            main branch (unless that conv is also the last one).
        ksize: Kernel size of the depth-wise convs; one of 3, 5, 7.
        stride: 1 or 2. Applied by the first depth-wise conv of each branch.
        sequence: Layout of the main branch, a string of ``"p"`` (point-wise
            conv + BN + ReLU) and ``"d"`` (depth-wise conv + BN) tokens. The
            last token must be ``"p"``.
        affine: Forwarded to every ``nn.BatchNorm2d`` as ``affine``.

    Raises:
        AssertionError: If ``stride``/``ksize`` are unsupported, if
            ``oup`` is too small to leave channels for the main branch, if
            the last token is not ``"p"`` or if a depth-wise conv would have
            to change the channel count.
        ValueError: If ``sequence`` contains a token other than ``"p"``/``"d"``.
    """

    def __init__(self, inp: int, oup: int, mid_channels: int, ksize: int, stride: int,
                 sequence: str = "pdp", affine: bool = True):
        super().__init__()
        assert stride in [1, 2]
        assert ksize in [3, 5, 7]

        # Channels entering each branch: half of the input when it is split
        # (stride 1), the full input otherwise.
        self.channels = inp // 2 if stride == 1 else inp
        self.inp = inp
        self.oup = oup
        self.mid_channels = mid_channels
        self.ksize = ksize
        self.stride = stride
        self.pad = ksize // 2  # keeps spatial size unchanged at stride 1
        # The pass-through / projection branch contributes ``self.channels``
        # channels, so the main branch must produce the remainder.
        self.oup_main = oup - self.channels
        self._affine = affine
        assert self.oup_main > 0

        self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

        # The projection branch only exists for down-sampling blocks.
        if stride == 2:
            self.branch_proj = nn.Sequential(
                # depth-wise conv (performs the spatial down-sampling)
                nn.Conv2d(self.channels, self.channels, ksize, stride, self.pad,
                          groups=self.channels, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                # point-wise conv
                nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(self.channels, affine=affine),
                nn.ReLU(inplace=True),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both branches and concatenate their outputs along dim 1."""
        if self.stride == 2:
            x_proj = self.branch_proj(x)
        else:
            x_proj, x = self._channel_shuffle(x)
        return torch.cat((x_proj, self.branch_main(x)), 1)

    def _decode_point_depth_conv(self, sequence: str) -> list:
        """Translate a ``"p"``/``"d"`` token string into a list of layers.

        Channel plan: the first point-wise conv maps to ``mid_channels``, the
        last conv (which must be point-wise) maps to ``oup_main``, and every
        other conv keeps the current channel count. Only the first depth-wise
        conv uses ``self.stride``; later ones use stride 1.
        """
        layers = []
        first_depth = first_point = True
        in_ch = out_ch = self.channels
        last_index = len(sequence) - 1

        for index, token in enumerate(sequence):
            # Decide the output channels of this conv.
            if index == last_index:
                assert token == "p", "Last conv must be point-wise conv."
                out_ch = self.oup_main
            elif token == "p" and first_point:
                out_ch = self.mid_channels

            if token == "d":
                # depth-wise conv: one filter per channel, no activation
                assert in_ch == out_ch, "Depth-wise conv must not change channels."
                conv_stride = self.stride if first_depth else 1
                layers.append(nn.Conv2d(in_ch, out_ch, self.ksize, conv_stride, self.pad,
                                        groups=out_ch, bias=False))
                layers.append(nn.BatchNorm2d(out_ch, affine=self._affine))
                first_depth = False
            elif token == "p":
                # point-wise conv followed by ReLU
                layers.append(nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False))
                layers.append(nn.BatchNorm2d(out_ch, affine=self._affine))
                layers.append(nn.ReLU(inplace=True))
                first_point = False
            else:
                raise ValueError("Conv sequence must be d and p.")
            in_ch = out_ch
        return layers

    def _channel_shuffle(self, x: torch.Tensor):
        """Split ``x`` into its even-indexed and odd-indexed channels.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)`` whose
                channel count is divisible by 4.

        Returns:
            A pair ``(even, odd)``, each of shape
            ``(batch, channels // 2, height, width)``.
        """
        bs, num_channels, height, width = x.shape
        assert (num_channels % 4 == 0)
        # Group consecutive channel pairs, then move the "position within the
        # pair" axis to the front so indexing it separates even from odd.
        x = x.reshape(bs * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleXceptionBlock(ShuffleNetBlock):
    """``ShuffleNetBlock`` with a fixed 3x3 kernel and a ``"dpdpdp"`` main branch."""

    def __init__(self, inp: int, oup: int, mid_channels: int, stride: int, affine: bool = True):
        super().__init__(inp, oup, mid_channels, 3, stride, "dpdpdp", affine)