# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import torch import torch.nn as nn import torch.nn.functional as F from pytorch import mutables from ops import FactorizedReduce, StdConv, SepConvBN, Pool class AuxiliaryHead(nn.Module): def __init__(self, in_channels, num_classes): super().__init__() self.in_channels = in_channels self.num_classes = num_classes self.pooling = nn.Sequential( nn.ReLU(), nn.AvgPool2d(5, 3, 2) ) self.proj = nn.Sequential( StdConv(in_channels, 128), StdConv(128, 768) ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Linear(768, 10, bias=False) def forward(self, x): bs = x.size(0) x = self.pooling(x) x = self.proj(x) x = self.avg_pool(x).view(bs, -1) x = self.fc(x) return x class Cell(nn.Module): def __init__(self, cell_name, prev_labels, channels): super().__init__() self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True, key=cell_name + "_input") self.op_choice = mutables.LayerChoice([ SepConvBN(channels, channels, 3, 1), SepConvBN(channels, channels, 5, 2), Pool("avg", 3, 1, 1), Pool("max", 3, 1, 1), nn.Identity() ], key=cell_name + "_op") def forward(self, prev_layers): chosen_input, chosen_mask = self.input_choice(prev_layers) cell_out = self.op_choice(chosen_input) return cell_out, chosen_mask class Node(mutables.MutableScope): def __init__(self, node_name, prev_node_names, channels): super().__init__(node_name) self.cell_x = Cell(node_name + "_x", prev_node_names, channels) self.cell_y = Cell(node_name + "_y", prev_node_names, channels) def forward(self, prev_layers): out_x, mask_x = self.cell_x(prev_layers) out_y, mask_y = self.cell_y(prev_layers) return out_x + out_y, mask_x | mask_y class Calibration(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.process = None if in_channels != out_channels: self.process = StdConv(in_channels, out_channels) def forward(self, x): if self.process is None: return x return self.process(x) class ReductionLayer(nn.Module): def __init__(self, in_channels_pp, in_channels_p, out_channels): super().__init__() self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False) self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False) def forward(self, pprev, prev): return self.reduce0(pprev), self.reduce1(prev) class ENASLayer(nn.Module): def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction): super().__init__() self.preproc0 = Calibration(in_channels_pp, out_channels) self.preproc1 = Calibration(in_channels_p, out_channels) self.num_nodes = num_nodes name_prefix = "reduce" if reduction else "normal" self.nodes = nn.ModuleList() node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY] for i in range(num_nodes): node_labels.append("{}_node_{}".format(name_prefix, i)) self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels)) self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) self.bn = nn.BatchNorm2d(out_channels, affine=False) self.reset_parameters() def reset_parameters(self): nn.init.kaiming_normal_(self.final_conv_w) def forward(self, pprev, prev): pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) prev_nodes_out = [pprev_, prev_] nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) for i in range(self.num_nodes): node_out, mask = self.nodes[i](prev_nodes_out) nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device) prev_nodes_out.append(node_out) unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) unused_nodes = F.relu(unused_nodes) conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1) out = F.conv2d(unused_nodes, conv_weight) return prev, self.bn(out) class MicroNetwork(nn.Module): def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10, dropout_rate=0.0, use_aux_heads=False): super().__init__() self.num_layers = num_layers self.use_aux_heads = use_aux_heads self.stem = nn.Sequential( nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels * 3) ) pool_distance = self.num_layers // 3 pool_layers = [pool_distance, 2 * pool_distance + 1] self.dropout = nn.Dropout(dropout_rate) self.layers = nn.ModuleList() c_pp = c_p = out_channels * 3 c_cur = out_channels for layer_id in range(self.num_layers + 2): reduction = False if layer_id in pool_layers: c_cur, reduction = c_p * 2, True self.layers.append(ReductionLayer(c_pp, c_p, c_cur)) c_pp = c_p = c_cur self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction)) if self.use_aux_heads and layer_id == pool_layers[-1] + 1: self.layers.append(AuxiliaryHead(c_cur, num_classes)) c_pp, c_p = c_p, c_cur self.gap = nn.AdaptiveAvgPool2d(1) self.dense = nn.Linear(c_cur, num_classes) self.reset_parameters() def reset_parameters(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight) def forward(self, x): bs = x.size(0) prev = cur = self.stem(x) aux_logits = None for layer in self.layers: if isinstance(layer, AuxiliaryHead): if self.training: aux_logits = layer(cur) else: prev, cur = layer(prev, cur) cur = self.gap(F.relu(cur)).view(bs, -1) cur = self.dropout(cur) logits = self.dense(cur) if aux_logits is not None: return logits, aux_logits return logits