# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com

import sys
import logging
import argparse

import torch
import torch.nn as nn

from copy import deepcopy
from torch import optim as optim
from thop import profile, clever_format
from timm.utils import *  # provides AverageMeter and accuracy

from ..config import cfg


def get_path_acc(model, path, val_loader, args, val_iters=50):
    """Evaluate a sampled supernet path on up to `val_iters` validation batches."""
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(val_loader):
            if batch_idx >= val_iters:
                break
            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()
            output = model(input, path)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # test-time augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(
                    0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            torch.cuda.synchronize()

            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))
    return (prec1_m.avg, prec5_m.avg)


def get_logger(file_path):
    """Make a python logger that writes to stdout and to `file_path`."""
    log_format = '%(asctime)s | %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    return logger


def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
    """Split parameters into decay / no-decay groups, with a separate
    learning rate (args.meta_lr) for meta-layer parameters."""
    decay = []
    no_decay = []
    meta_layer_no_decay = []
    meta_layer_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        # 1-D tensors (BN/bias) and explicitly skipped names get no weight decay
        if len(param.shape) == 1 or name.endswith(
                ".bias") or name in skip_list:
            if 'meta_layer' in name:
                meta_layer_no_decay.append(param)
            else:
                no_decay.append(param)
        else:
            if 'meta_layer' in name:
                meta_layer_decay.append(param)
            else:
                decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0., 'lr': args.lr},
        {'params': decay, 'weight_decay': weight_decay, 'lr': args.lr},
        {'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr},
        {'params': meta_layer_decay, 'weight_decay': 0., 'lr': args.meta_lr},
    ]


def create_optimizer_supernet(args, model, has_apex=False,
                              filter_bias_and_bn=True):
    """Create the supernet optimizer selected by args.opt."""
    weight_decay = args.weight_decay
    if args.opt == 'adamw' or args.opt == 'radam':
        # compensate for how AdamW/RAdam apply the LR to the weight decay
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if args.opt == 'fused':
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    if args.opt == 'sgd' or args.opt == 'nesterov':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=True)
    elif args.opt == 'momentum':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=False)
    elif args.opt == 'adam':
        optimizer = optim.Adam(
            parameters, weight_decay=weight_decay, eps=args.opt_eps)
    else:
        raise ValueError('Invalid optimizer: {}'.format(args.opt))
    return optimizer


def convert_lowercase(cfg):
    """Add lowercase aliases for the existing (uppercase) config keys."""
    keys = cfg.keys()
    lowercase_keys = [key.lower() for key in keys]
    values = [cfg.get(key) for key in keys]
    for lowercase_key, value in zip(lowercase_keys, values):
        cfg.setdefault(lowercase_key, value)
    return cfg


# def parse_config_args(exp_name):
#     parser = argparse.ArgumentParser(description=exp_name)
#     parser.add_argument(
#         '--cfg',
#         type=str,
#         default='../experiments/workspace/retrain/retrain.yaml',
#         help='configuration of cream')
#     parser.add_argument('--local_rank', type=int, default=0,
#                         help='local_rank')
#     args = parser.parse_args()
#
#     cfg.merge_from_file(args.cfg)
#     converted_cfg = convert_lowercase(cfg)
#
#     return args, converted_cfg


def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
    """Return (MACs, params) of `model` as human-readable strings via thop."""
    dummy_input = torch.randn(input_size)
    macs, params = profile(deepcopy(model), inputs=(dummy_input,),
                           verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    return macs, params


def cross_entropy_loss_with_soft_target(pred, soft_target):
    """Cross-entropy against soft targets (e.g. a teacher's softened outputs)."""
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(-soft_target * logsoftmax(pred), 1))


def create_supernet_scheduler(optimizer, epochs, num_gpu, batch_size, lr):
    """Linear LambdaLR schedule over the estimated total number of optimizer
    steps (1,280,000 approximates the ImageNet-1k training set size)."""
    ITERS = epochs * (1280000 / (num_gpu * batch_size))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda step: (lr - step / ITERS) if step <= ITERS else 0,
        last_epoch=-1)
    return lr_scheduler, epochs
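# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original training pipeline).
# It shows how the helpers above fit together; the hyper-parameter values in
# `args` below are assumptions for the example only - in the repo they come
# from the YAML config. Run this as a module (e.g. with `python -m`) so the
# relative `..config` import at the top of the file resolves.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    # Hypothetical hyper-parameters; adjust to match the experiment config.
    args = SimpleNamespace(opt='sgd', lr=0.5, meta_lr=1e-4, momentum=0.9,
                           weight_decay=1e-4, opt_eps=1e-8)

    # Tiny stand-in model; the real code passes the supernet here.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    optimizer = create_optimizer_supernet(args, model)
    scheduler, epochs = create_supernet_scheduler(
        optimizer, epochs=120, num_gpu=8, batch_size=128, lr=args.lr)

    print(get_model_flops_params(model, input_size=(1, 3, 32, 32)))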