# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Small convolution, pooling and reduction building blocks built on ``torch.nn``."""

import torch
import torch.nn as nn


class StdConv(nn.Module):
    """1x1 convolution followed by batch normalization and ReLU.

    The convolution has stride 1, no padding and no bias; the batch norm has
    no learnable affine parameters.

    Args:
        C_in: Number of input channels.
        C_out: Number of output channels.
    """

    def __init__(self, C_in: int, C_out: int) -> None:
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> batch norm -> ReLU to ``x``."""
        return self.conv(x)

    def __str__(self) -> str:
        return 'StdConv'


class PoolBranch(nn.Module):
    """``StdConv`` preprocessing, then pooling, then batch normalization.

    Args:
        pool_type: ``'max'`` or ``'avg'`` (case-insensitive); forwarded to ``Pool``.
        C_in: Number of input channels.
        C_out: Number of output channels produced by the preprocessing conv.
        kernel_size: Pooling window size.
        stride: Pooling stride.
        padding: Pooling padding.
        affine: Whether the final batch norm has learnable affine parameters.

    Raises:
        ValueError: If ``pool_type`` is not a supported pooling type.
    """

    def __init__(self, pool_type: str, C_in: int, C_out: int, kernel_size: int,
                 stride: int, padding: int, affine: bool = False) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.pool_type = pool_type
        self.preproc = StdConv(C_in, C_out)
        self.pool = Pool(pool_type, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply preprocessing conv -> pooling -> batch norm to ``x``."""
        out = self.preproc(x)
        out = self.pool(out)
        out = self.bn(out)
        return out

    def __str__(self) -> str:
        return '{}PoolBranch_{}'.format(self.pool_type, self.kernel_size)


class SeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Neither convolution has a bias term.

    Args:
        C_in: Number of input channels (also the number of depthwise groups).
        C_out: Number of output channels of the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, C_in: int, C_out: int, kernel_size: int, stride: int,
                 padding: int) -> None:
        # Initialize nn.Module before assigning any attribute, as every other
        # class in this file does; the original assigned kernel_size first.
        super().__init__()
        self.kernel_size = kernel_size
        self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding,
                                   stride=stride, groups=C_in, bias=False)
        self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise then the pointwise convolution to ``x``."""
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

    def __str__(self) -> str:
        return 'SeparableConv_{}'.format(self.kernel_size)


class ConvBranch(nn.Module):
    """``StdConv`` preprocessing, a convolution, then batch norm and ReLU.

    Args:
        C_in: Number of input channels.
        C_out: Number of output channels (used by every stage after preprocessing).
        kernel_size: Kernel size of the main convolution.
        stride: Stride of the main convolution.
        padding: Padding of the main convolution.
        separable: If true, the main convolution is a ``SeparableConv``;
            otherwise it is a regular ``nn.Conv2d`` (with its default bias).
    """

    def __init__(self, C_in: int, C_out: int, kernel_size: int, stride: int,
                 padding: int, separable: bool) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply preprocessing -> convolution -> batch norm + ReLU to ``x``."""
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out

    def __str__(self) -> str:
        return 'ConvBranch_{}'.format(self.kernel_size)


class FactorizedReduce(nn.Module):
    """Halve the spatial resolution with two offset stride-2 1x1 convolutions.

    The first convolution sees the input as is; the second sees the input
    shifted by one pixel along both spatial axes. Their outputs are
    concatenated along the channel axis and batch-normalized.

    Args:
        C_in: Number of input channels.
        C_out: Total number of output channels. The first branch produces
            ``C_out // 2`` channels and the second the remainder, so odd
            values are supported as well.
        affine: Whether the batch norm has learnable affine parameters.

    Note:
        The input height and width should be even; with an odd size the two
        branches yield different spatial sizes and cannot be concatenated.
    """

    def __init__(self, C_in: int, C_out: int, affine: bool = False) -> None:
        super().__init__()
        first_half = C_out // 2
        # Give the remainder to the second branch so the concatenation always
        # has exactly C_out channels (identical to C_out // 2 when C_out is even).
        second_half = C_out - first_half
        self.conv1 = nn.Conv2d(C_in, first_half, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, second_half, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the batch-normalized concatenation of both reduced branches."""
        # x[:, :, 1:, 1:] drops the first row and column so the second branch
        # samples the positions skipped by the first one.
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out

    def __str__(self) -> str:
        return 'FactorizedReduce'


class Pool(nn.Module):
    """Max or average 2D pooling selected by name.

    Args:
        pool_type: ``'max'`` or ``'avg'`` (case-insensitive).
        kernel_size: Pooling window size.
        stride: Pooling stride.
        padding: Pooling padding. Average pooling excludes the padded values
            from the average (``count_include_pad=False``).

    Raises:
        ValueError: If ``pool_type`` is neither ``'max'`` nor ``'avg'``.
    """

    def __init__(self, pool_type: str, kernel_size: int, stride: int,
                 padding: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.pool_type = pool_type
        kind = pool_type.lower()
        if kind == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif kind == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError(
                "Unsupported pool_type {!r}; expected 'max' or 'avg'.".format(pool_type)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the selected pooling operation to ``x``."""
        return self.pool(x)

    def __str__(self) -> str:
        return '{}Pool_{}'.format(self.pool_type, self.kernel_size)


class SepConvBN(nn.Module):
    """ReLU, then a stride-1 ``SeparableConv``, then affine batch normalization.

    Args:
        C_in: Number of input channels.
        C_out: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, C_in: int, C_out: int, kernel_size: int, padding: int) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.relu = nn.ReLU()
        self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding)
        self.bn = nn.BatchNorm2d(C_out, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ReLU -> separable convolution -> batch norm to ``x``."""
        x = self.relu(x)
        x = self.conv(x)
        x = self.bn(x)
        return x

    def __str__(self) -> str:
        return 'SepConvBN_{}'.format(self.kernel_size)