You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test.py 6.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT License.
  3. # Written by Hao Du and Houwen Peng
  4. # email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
  5. import os
  6. import warnings
  7. import datetime
  8. import torch
  9. import torch.nn as nn
  10. # from torch.utils.tensorboard import SummaryWriter
  11. # import timm packages
  12. from timm.utils import ModelEma
  13. from timm.models import resume_checkpoint
  14. from timm.data import Dataset, create_loader
  15. # import apex as distributed package
  16. try:
  17. from apex.parallel import convert_syncbn_model
  18. from apex.parallel import DistributedDataParallel as DDP
  19. HAS_APEX = True
  20. except ImportError as e:
  21. print(e)
  22. from torch.nn.parallel import DistributedDataParallel as DDP
  23. HAS_APEX = False
  24. # import models and training functions
  25. from lib.core.test import validate
  26. from lib.models.structures.childnet import gen_childnet
  27. from lib.utils.util import parse_config_args, get_logger, get_model_flops_params
  28. from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  29. def main():
  30. args, cfg = parse_config_args('child net testing')
  31. # resolve logging
  32. output_dir = os.path.join(cfg.SAVE_PATH,
  33. "{}-{}".format(datetime.date.today().strftime('%m%d'),
  34. cfg.MODEL))
  35. if not os.path.exists(output_dir):
  36. os.mkdir(output_dir)
  37. if args.local_rank == 0:
  38. logger = get_logger(os.path.join(output_dir, 'test.log'))
  39. writer = None # SummaryWriter(os.path.join(output_dir, 'runs'))
  40. else:
  41. writer, logger = None, None
  42. # retrain model selection
  43. if cfg.NET.SELECTION == 481:
  44. arch_list = [
  45. [0], [
  46. 3, 4, 3, 1], [
  47. 3, 2, 3, 0], [
  48. 3, 3, 3, 1], [
  49. 3, 3, 3, 3], [
  50. 3, 3, 3, 3], [0]]
  51. cfg.DATASET.IMAGE_SIZE = 224
  52. elif cfg.NET.SELECTION == 43:
  53. arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
  54. cfg.DATASET.IMAGE_SIZE = 96
  55. elif cfg.NET.SELECTION == 14:
  56. arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
  57. cfg.DATASET.IMAGE_SIZE = 64
  58. elif cfg.NET.SELECTION == 112:
  59. arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
  60. cfg.DATASET.IMAGE_SIZE = 160
  61. elif cfg.NET.SELECTION == 287:
  62. arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
  63. cfg.DATASET.IMAGE_SIZE = 224
  64. elif cfg.NET.SELECTION == 604:
  65. arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
  66. [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
  67. cfg.DATASET.IMAGE_SIZE = 224
  68. else:
  69. raise ValueError("Model Test Selection is not Supported!")
  70. # define childnet architecture from arch_list
  71. stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
  72. # TODO: this param from NNI is different from microsoft/Cream.
  73. choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
  74. 'ir_r1_k5_s2_e4_c40_se0.25',
  75. 'ir_r1_k3_s2_e6_c80_se0.25',
  76. 'ir_r1_k3_s1_e6_c96_se0.25',
  77. 'ir_r1_k5_s2_e6_c192_se0.25']
  78. arch_def = [[stem[0]]] + [[choice_block_pool[idx]
  79. for repeat_times in range(len(arch_list[idx + 1]))]
  80. for idx in range(len(choice_block_pool))] + [[stem[1]]]
  81. # generate childnet
  82. model = gen_childnet(
  83. arch_list,
  84. arch_def,
  85. num_classes=cfg.DATASET.NUM_CLASSES,
  86. drop_rate=cfg.NET.DROPOUT_RATE,
  87. global_pool=cfg.NET.GP)
  88. if args.local_rank == 0:
  89. macs, params = get_model_flops_params(model, input_size=(
  90. 1, 3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE))
  91. logger.info(
  92. '[Model-{}] Flops: {} Params: {}'.format(cfg.NET.SELECTION, macs, params))
  93. # initialize distributed parameters
  94. torch.cuda.set_device(args.local_rank)
  95. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  96. if args.local_rank == 0:
  97. logger.info(
  98. "Training on Process {} with {} GPUs.".format(
  99. args.local_rank, cfg.NUM_GPU))
  100. # resume model from checkpoint
  101. assert cfg.AUTO_RESUME is True and os.path.exists(cfg.RESUME_PATH)
  102. resume_checkpoint(model, cfg.RESUME_PATH)
  103. model = model.cuda()
  104. model_ema = None
  105. if cfg.NET.EMA.USE:
  106. # Important to create EMA model after cuda(), DP wrapper, and AMP but
  107. # before SyncBN and DDP wrapper
  108. model_ema = ModelEma(
  109. model,
  110. decay=cfg.NET.EMA.DECAY,
  111. device='cpu' if cfg.NET.EMA.FORCE_CPU else '',
  112. resume=cfg.RESUME_PATH)
  113. # imagenet validation dataset
  114. eval_dir = os.path.join(cfg.DATA_DIR, 'val')
  115. if not os.path.exists(eval_dir) and args.local_rank == 0:
  116. logger.error(
  117. 'Validation folder does not exist at: {}'.format(eval_dir))
  118. exit(1)
  119. dataset_eval = Dataset(eval_dir)
  120. loader_eval = create_loader(
  121. dataset_eval,
  122. input_size=(3, cfg.DATASET.IMAGE_SIZE, cfg.DATASET.IMAGE_SIZE),
  123. batch_size=cfg.DATASET.VAL_BATCH_MUL * cfg.DATASET.BATCH_SIZE,
  124. is_training=False,
  125. num_workers=cfg.WORKERS,
  126. distributed=True,
  127. pin_memory=cfg.DATASET.PIN_MEM,
  128. crop_pct=DEFAULT_CROP_PCT,
  129. mean=IMAGENET_DEFAULT_MEAN,
  130. std=IMAGENET_DEFAULT_STD
  131. )
  132. # only test accuracy of model-EMA
  133. validate_loss_fn = nn.CrossEntropyLoss().cuda()
  134. validate(0, model, loader_eval, validate_loss_fn, cfg,
  135. log_suffix='_EMA', logger=logger,
  136. writer=writer, local_rank=args.local_rank)
  137. if cfg.NET.EMA.USE:
  138. validate(0, model_ema.ema, loader_eval, validate_loss_fn, cfg,
  139. log_suffix='_EMA', logger=logger,
  140. writer=writer, local_rank=args.local_rank)
  141. if __name__ == '__main__':
  142. main()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能