# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn

# `mutables` (LayerChoice / InputChoice) and `ops` (DARTS primitives such as
# PoolBN, StdConv, FactorizedReduce, SepConv, DilConv and DropPath) are
# project-local modules.
from pytorch import mutables
from pytorch.darts import ops


def random_channel_shuffle(x):
    """Permute the channel dimension of ``x`` with a random permutation."""
    num_channels = x.size(1)
    indices = torch.randperm(num_channels)
    return x[:, indices]


def channel_shuffle(x, groups):
    """Interleave the channels of ``x`` across ``groups`` (ShuffleNet-style)."""
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape: (N, C, H, W) -> (N, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten back to (N, C, H, W)
    x = x.view(batchsize, -1, height, width)
    return x


class AuxiliaryHead(nn.Module):
    """ Auxiliary head at the 2/3 position of the network, added to improve gradient flow """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits


class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, k, num_downsample_connect, search):
        super().__init__()
        if search:
            # partial channel connections: during search, only 1/k of the
            # channels are fed through the candidate operations
            self.k = k
            partial_channels = channels // k
        else:
            partial_channels = channels
        self.search = search
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(OrderedDict([
                    ("maxpool", ops.PoolBN('max', partial_channels, 3, stride, 1, affine=False)),
                    ("avgpool", ops.PoolBN('avg', partial_channels, 3, stride, 1, affine=False)),
                    ("skipconnect", nn.Identity() if stride == 1
                     else ops.FactorizedReduce(partial_channels, partial_channels, affine=False)),
                    ("sepconv3x3", ops.SepConv(partial_channels, partial_channels, 3, stride, 1, affine=False)),
                    ("sepconv5x5", ops.SepConv(partial_channels, partial_channels, 5, stride, 2, affine=False)),
                    ("dilconv3x3", ops.DilConv(partial_channels, partial_channels, 3, stride, 2, 2, affine=False)),
                    ("dilconv5x5", ops.DilConv(partial_channels, partial_channels, 5, stride, 4, 2, affine=False))
                ]), key=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2,
                                                 key="{}_switch".format(node_id))
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes), "len(self.ops) != len(prev_nodes) in Node"
        # apply each candidate predecessor's operation to this intermediate node
        if self.search:
            # in search: route only the first 1/k of the channels through the
            # candidate op, bypass the rest, then shuffle the channels so a
            # different subset is sampled at the next step
            results = []
            for op, x in zip(self.ops, prev_nodes):
                channels = x.shape[1]
                # channel proportion k=4 by default
                temp0 = x[:, :channels // self.k, :, :]
                temp1 = x[:, channels // self.k:, :, :]
                out = op(temp0)
                if out.shape[2] == x.shape[2]:
                    # normal cell: bypassed channels keep their resolution
                    result = torch.cat([out, temp1], dim=1)
                else:
                    # reduction cell: downsample the bypassed channels to match
                    result = torch.cat([out, self.pool(temp1)], dim=1)
                results.append(channel_shuffle(result, self.k))

                # # alternative: random channel shuffle before the split
                # x = random_channel_shuffle(x)
                # channels = x.shape[1]
                # temp0 = x[:, :channels // self.k, :, :]
                # temp1 = x[:, channels // self.k:, :, :]
                # out = op(temp0)
                # if out.shape[2] == x.shape[2]:  # normal
                #     result = torch.cat([out, temp1], dim=1)
                # else:  # reduction
                #     result = torch.cat([out, self.pool(temp1)], dim=1)
                # results.append(result)
        else:
            # in retrain: no channel split or shuffle, all channels go through the op
            results = [op(node) for op, node in zip(self.ops, prev_nodes)]
        output = [self.drop_path(result) if result is not None else None for result in results]
        return self.input_switch(output)


class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, k, search):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If the previous cell is a reduction cell, the current input size does not
        # match the output size of cell[k-2], so output[k-2] is downsampled by the
        # preprocessing step.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, k,
                                         2 if reduction else 0, search))

    def forward(self, s0, s1):
        # s0 and s1 are the outputs of the previous-previous cell and the previous cell
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output


class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, k=4, n_nodes=4,
                 stem_multiplier=3, auxiliary=False, search=True):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # reduce the feature map size and double the channels at the 1/3 and 2/3 layers
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, k, search)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        # update the drop probability of every DropPath module in the network
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p

    def _loss(self, input, target):
        # [!] self._criterion is expected to be attached by the training script
        # (e.g. model._criterion = nn.CrossEntropyLoss()) before _loss is called.
        # Note that forward() returns a (logits, aux_logits) tuple while training
        # with the auxiliary head enabled.
        logits = self(input)
        return self._criterion(logits, target)
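

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original file).
    # Assumptions: a CIFAR-10-like 32x32 RGB input, and that the project-local
    # `mutables`/`ops` modules resolve. In search mode the LayerChoice and
    # InputChoice modules still have to be bound by a NAS mutator/trainer
    # before forward() can run, so this only builds the model and reports its
    # parameter count.
    model = CNN(input_size=32, in_channels=3, channels=16, n_classes=10,
                n_layers=8, k=4, auxiliary=True, search=True)
    n_params = sum(p.numel() for p in model.parameters())
    print("model built with {} parameters".format(n_params))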