
util.py 5.8 kB

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com

import sys
import logging
import argparse

import torch
import torch.nn as nn

from copy import deepcopy
from torch import optim as optim
from thop import profile, clever_format
from timm.utils import *

from ..config import cfg


def get_path_acc(model, path, val_loader, args, val_iters=50):
    """Evaluate top-1/top-5 accuracy of one architecture path on up to val_iters batches."""
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(val_loader):
            if batch_idx >= val_iters:
                break

            if not args.prefetcher:
                input = input.cuda()
                target = target.cuda()

            output = model(input, path)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = args.tta
            if reduce_factor > 1:
                output = output.unfold(
                    0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            torch.cuda.synchronize()

            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))
    return (prec1_m.avg, prec5_m.avg)


def get_logger(file_path):
    """ Make python logger """
    log_format = '%(asctime)s | %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
    """Build optimizer parameter groups: biases and 1-d params get no weight decay,
    and meta-layer parameters use args.meta_lr instead of args.lr."""
    decay = []
    no_decay = []
    meta_layer_no_decay = []
    meta_layer_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(
                ".bias") or name in skip_list:
            if 'meta_layer' in name:
                meta_layer_no_decay.append(param)
            else:
                no_decay.append(param)
        else:
            if 'meta_layer' in name:
                meta_layer_decay.append(param)
            else:
                decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0., 'lr': args.lr},
        {'params': decay, 'weight_decay': weight_decay, 'lr': args.lr},
        {'params': meta_layer_no_decay, 'weight_decay': 0., 'lr': args.meta_lr},
        {'params': meta_layer_decay, 'weight_decay': 0., 'lr': args.meta_lr},
    ]


def create_optimizer_supernet(args, model, has_apex=False, filter_bias_and_bn=True):
    weight_decay = args.weight_decay
    if 'adamw' == args.opt or 'radam' == args.opt:
        # AdamW/RAdam multiply the weight-decay term by lr internally,
        # so pre-divide here to compensate.
        weight_decay /= args.lr
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay_supernet(model, args, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if 'fused' == args.opt:
        assert has_apex and torch.cuda.is_available(), \
            'APEX and CUDA required for fused optimizers'

    if args.opt == 'sgd' or args.opt == 'nesterov':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=True)
    elif args.opt == 'momentum':
        optimizer = optim.SGD(
            parameters,
            momentum=args.momentum,
            weight_decay=weight_decay,
            nesterov=False)
    elif args.opt == 'adam':
        optimizer = optim.Adam(
            parameters, weight_decay=weight_decay, eps=args.opt_eps)
    else:
        raise ValueError('Invalid optimizer: {}'.format(args.opt))

    return optimizer


def convert_lowercase(cfg):
    keys = cfg.keys()
    lowercase_keys = [key.lower() for key in keys]
    values = [cfg.get(key) for key in keys]
    for lowercase_key, value in zip(lowercase_keys, values):
        cfg.setdefault(lowercase_key, value)
    return cfg


#
# def parse_config_args(exp_name):
#     parser = argparse.ArgumentParser(description=exp_name)
#     parser.add_argument(
#         '--cfg',
#         type=str,
#         default='../experiments/workspace/retrain/retrain.yaml',
#         help='configuration of cream')
#     parser.add_argument('--local_rank', type=int, default=0,
#                         help='local_rank')
#     args = parser.parse_args()
#
#     cfg.merge_from_file(args.cfg)
#     converted_cfg = convert_lowercase(cfg)
#
#     return args, converted_cfg


def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
    input = torch.randn(input_size)
    macs, params = profile(deepcopy(model), inputs=(input,), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    return macs, params


def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


def create_supernet_scheduler(optimizer, epochs, num_gpu, batch_size, lr):
    # 1280000 is roughly the size of the ImageNet training set, so ITERS is
    # the total number of optimizer steps across all epochs.
    ITERS = epochs * (1280000 / (num_gpu * batch_size))
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (
        lr - step / ITERS) if step <= ITERS else 0, last_epoch=-1)
    return lr_scheduler, epochs
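
For orientation, below is a minimal, illustrative sketch of how these helpers might be wired together in a training script. The SimpleNamespace stand-in for the parsed config, the hyperparameter values, and the torchvision model are assumptions for demonstration only; they are not part of this file, which in practice receives its settings from the YAML config.

# Illustrative usage only -- assumes the functions above are importable or in scope,
# and that the values below stand in for the real training config.
from types import SimpleNamespace

import torchvision

args = SimpleNamespace(opt='sgd', lr=0.5, meta_lr=1e-4, momentum=0.9,
                       weight_decay=1e-4, opt_eps=1e-8)

model = torchvision.models.resnet18()

# Report MACs/params, then build the supernet optimizer and LR schedule.
macs, params = get_model_flops_params(model, input_size=(1, 3, 224, 224))
print('MACs: {}, Params: {}'.format(macs, params))

optimizer = create_optimizer_supernet(args, model)
scheduler, num_epochs = create_supernet_scheduler(
    optimizer, epochs=120, num_gpu=8, batch_size=128, lr=args.lr)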
