import sys

# NOTE(review): this path is resolved against the current working directory,
# not this file's directory, so the `pytorch.mutables` import below only works
# when the script is launched from the expected location.
sys.path.append("../../")

import os
import re
import pickle

import torch
import torch.nn as nn

from pytorch.mutables import LayerChoice
from blocks import ShuffleNetBlock, ShuffleXceptionBlock

# Precomputed per-candidate FLOPs for each of the 20 LayerChoice layers,
# presumably what `_parsed_flops` holds for the default-configured model
# (one value per candidate, in `block_keys` order) -- TODO confirm.
PARSED_FLOPS = {'LayerChoice1': [13396992, 15805440, 19418112, 13146112],
                'LayerChoice2': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice3': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice4': [7325696, 8931328, 11339776, 12343296],
                'LayerChoice5': [26304768, 28111104, 30820608, 20296192],
                'LayerChoice6': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice7': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice8': [10599680, 11603200, 13108480, 16746240],
                'LayerChoice9': [30670080, 31673600, 33178880, 21199360],
                'LayerChoice10': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice11': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice12': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice13': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice14': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice15': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice16': [10317440, 10819200, 11571840, 15899520],
                'LayerChoice17': [30387840, 30889600, 31642240, 20634880],
                'LayerChoice18': [10176320, 10427200, 10803520, 15476160],
                'LayerChoice19': [10176320, 10427200, 10803520, 15476160],
                'LayerChoice20': [10176320, 10427200, 10803520, 15476160]}


class ShuffleNetV2OneShot(nn.Module):
    """ShuffleNetV2 one-shot supernet.

    Every searchable layer is a ``LayerChoice`` over the four candidate blocks
    listed in ``block_keys``. The FLOPs of each candidate are looked up in a
    pickled table at construction time, so ``get_candidate_flops`` can cost a
    sampled architecture without running it.
    """

    # Candidate operations, in the same order as the LayerChoice candidates
    # built in `_make_blocks`; also used as key prefixes into the FLOPs table.
    block_keys = [
        'shufflenet_3x3',
        'shufflenet_5x5',
        'shufflenet_7x7',
        'xception_3x3',
    ]

    def __init__(self, input_size=224, first_conv_channels=16, last_conv_channels=1024, n_classes=1000,
                 op_flops_path="./data/op_flops_dict.pkl", affine=False):
        """Build the supernet.

        Args:
            input_size: Spatial size of the square input; must be a multiple of
                32 because the network halves the feature map five times (the
                stem plus the first layer of each of the four stages).
            first_conv_channels: Output channels of the stem convolution.
            last_conv_channels: Output channels of the final 1x1 convolution.
            n_classes: Number of classifier outputs.
            op_flops_path: Path to a pickled per-operation FLOPs dict, joined
                onto this file's directory (an absolute path is used as-is).
            affine: Forwarded to every BatchNorm layer.
        """
        super().__init__()
        assert input_size % 32 == 0
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # point op_flops_path at trusted files.
        with open(os.path.join(os.path.dirname(__file__), op_flops_path), "rb") as fp:
            self._op_flops_dict = pickle.load(fp)

        self.stage_blocks = [4, 4, 8, 4]
        self.stage_channels = [64, 160, 320, 640]
        # Filled by `_make_blocks`: "LayerChoice<i>" -> FLOPs per candidate.
        self._parsed_flops = dict()
        self._input_size = input_size
        # Tracks the current spatial size while layers are being built.
        self._feature_map_size = input_size
        self._first_conv_channels = first_conv_channels
        self._last_conv_channels = last_conv_channels
        self._n_classes = n_classes
        self._affine = affine

        # building first layer
        self.first_conv = nn.Sequential(
            nn.Conv2d(3, first_conv_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(first_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        self._feature_map_size //= 2

        p_channels = first_conv_channels
        features = []
        for num_blocks, channels in zip(self.stage_blocks, self.stage_channels):
            features.extend(self._make_blocks(num_blocks, p_channels, channels))
            p_channels = channels
        self.features = nn.Sequential(*features)

        self.conv_last = nn.Sequential(
            nn.Conv2d(p_channels, last_conv_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(last_conv_channels, affine=affine),
            nn.ReLU(inplace=True),
        )
        # After all stages the feature map is input_size / 32 on each side.
        self.globalpool = nn.AvgPool2d(self._feature_map_size)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(last_conv_channels, n_classes, bias=False),
        )

        self._initialize_weights()

    def _make_blocks(self, blocks, in_channels, channels):
        """Build one stage of ``blocks`` LayerChoice layers and record their FLOPs.

        The first layer of the stage uses stride 2 and maps ``in_channels`` to
        ``channels``; the others use stride 1 and keep ``channels``. Side
        effects: ``self._feature_map_size`` is halved after each stride-2
        layer, and ``self._parsed_flops`` gains one entry per layer.

        Returns:
            list of the ``LayerChoice`` modules of this stage.
        """
        result = []
        for i in range(blocks):
            stride = 2 if i == 0 else 1
            inp = in_channels if i == 0 else channels
            oup = channels

            base_mid_channels = channels // 2
            mid_channels = int(base_mid_channels)  # prepare for scale
            choice_block = LayerChoice([
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=3, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=5, stride=stride, affine=self._affine),
                ShuffleNetBlock(inp, oup, mid_channels=mid_channels, ksize=7, stride=stride, affine=self._affine),
                ShuffleXceptionBlock(inp, oup, mid_channels=mid_channels, stride=stride, affine=self._affine)
            ])
            result.append(choice_block)

            # Look up the FLOPs of every candidate at this layer's exact
            # configuration (channels, feature-map size, stride).
            flop_key = (inp, oup, mid_channels, self._feature_map_size, self._feature_map_size, stride)
            # Key the entry by this layer's 1-based position in the network,
            # not by choice_block.key. LayerChoice keys come from a global
            # counter that keeps incrementing, e.g. across model instances, so
            # they can exceed the number of layers. The counter is shared by all
            # algorithms and is therefore deliberately not reset here.
            # (The previous "key number mod 20" remap only worked when the
            # counter offset was a multiple of 20 and there were exactly 20
            # choice layers.)
            layer_index = len(self._parsed_flops) + 1
            self._parsed_flops["LayerChoice{}".format(layer_index)] = [
                self._op_flops_dict["{}_stride_{}".format(k, stride)][flop_key]
                for k in self.block_keys
            ]
            if stride == 2:
                self._feature_map_size //= 2
        return result

    def forward(self, x):
        """Run the supernet and return logits of shape (batch, n_classes)."""
        bs = x.size(0)
        x = self.first_conv(x)
        x = self.features(x)
        x = self.conv_last(x)
        x = self.globalpool(x)

        x = self.dropout(x)
        x = x.contiguous().view(bs, -1)
        x = self.classifier(x)
        return x

    def get_candidate_flops(self, candidate):
        """Return the total FLOPs of the architecture described by ``candidate``.

        Args:
            candidate: mapping from layer key ("LayerChoice1", ...) to either a
                dict carrying the chosen index under ``"_idx"`` (classical NAS
                format) or a tensor whose argmax along dim 0 is the chosen
                index (presumably a one-hot choice mask -- TODO confirm).

        Returns:
            Stem FLOPs + tail ("rest_operation") FLOPs + the chosen candidate's
            FLOPs for every layer listed in ``candidate``.
        """
        conv1_flops = self._op_flops_dict["conv1"][(3, self._first_conv_channels,
                                                    self._input_size, self._input_size, 2)]
        # Should use `last_conv_channels` here, but megvii insists that it's `n_classes`. Keeping it.
        # https://github.com/megvii-model/SinglePathOneShot/blob/36eed6cf083497ffa9cfe7b8da25bb0b6ba5a452/src/Supernet/flops.py#L313
        rest_flops = self._op_flops_dict["rest_operation"][(self.stage_channels[-1], self._n_classes,
                                                            self._feature_map_size, self._feature_map_size, 1)]
        total_flops = conv1_flops + rest_flops
        for k, m in candidate.items():
            parsed_flops_dict = self._parsed_flops[k]
            if isinstance(m, dict):  # to be compatible with classical nas format
                total_flops += parsed_flops_dict[m["_idx"]]
            else:
                total_flops += parsed_flops_dict[int(torch.max(m, 0)[1])]
        return total_flops

    def _initialize_weights(self):
        """Initialise conv, BatchNorm and linear parameters in place."""
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'first' in name:
                    nn.init.normal_(m.weight, 0, 0.01)
                else:
                    nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # weight/bias are None when the layer is built with affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0001)
                nn.init.constant_(m.running_mean, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
    """Load a checkpoint onto the CPU and return its state dict.

    If the checkpoint wraps the weights under ``"state_dict"``, that entry is
    used. A leading ``"module."`` (left by ``nn.DataParallel`` wrappers) is
    stripped from every key.
    """
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filepath, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    prefix = "module."
    result = dict()
    for k, v in checkpoint.items():
        if k.startswith(prefix):
            k = k[len(prefix):]
        result[k] = v
    return result


if __name__ == "__main__":
    model = ShuffleNetV2OneShot()