retrainer.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Written by Hao Du and Houwen Peng
# email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com

import sys
sys.path.append('../..')

import os
import json
import time
import timm
import torch
import numpy as np
import torch.nn as nn
from argparse import ArgumentParser
# from torch.utils.tensorboard import SummaryWriter

# import timm packages
from timm.optim import create_optimizer
from timm.models import resume_checkpoint
from timm.scheduler import create_scheduler
from timm.data import Dataset, create_loader
from timm.utils import CheckpointSaver, ModelEma, update_summary
from timm.loss import LabelSmoothingCrossEntropy

# import apex as distributed package
try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    HAS_APEX = True
except ImportError as e:
    print(e)
    from torch.nn.parallel import DistributedDataParallel as DDP
    HAS_APEX = False

# import models and training functions
from pytorch.utils import mkdirs, save_best_checkpoint, str2bool
from lib.core.test import validate
from lib.core.retrain import train_epoch
from lib.models.structures.childnet import gen_childnet
from lib.utils.util import get_logger, get_model_flops_params
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
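
# Note: pytorch.utils and the lib.* modules are repo-local packages, made
# importable by the relative sys.path.append('../..') above; that path is
# resolved against the working directory, so the script is presumably meant
# to be run from its own directory.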


def parse_args():
    """See lib.utils.config"""
    parser = ArgumentParser()
    # path
    parser.add_argument("--best_checkpoint_dir", type=str, default='./output/best_checkpoint/')
    parser.add_argument("--checkpoint_dir", type=str, default='./output/checkpoints/')
    parser.add_argument("--data_dir", type=str, default='./data')
    parser.add_argument("--experiment_dir", type=str, default='./')
    parser.add_argument("--model_name", type=str, default='retrainer')
    parser.add_argument("--log_path", type=str, default='output/log')
    parser.add_argument("--result_path", type=str, default='output/result.json')
    parser.add_argument("--best_selected_space_path", type=str,
                        default='output/selected_space.json')
    # int
    parser.add_argument("--acc_gap", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--cooldown_epochs", type=int, default=10)
    parser.add_argument("--decay_epochs", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--flops_minimum", type=int, default=0)
    parser.add_argument("--flops_maximum", type=int, default=200)
    parser.add_argument("--image_size", type=int, default=224)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--meta_sta_epoch", type=int, default=20)
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--num_gpu", type=int, default=1)
    parser.add_argument("--patience_epochs", type=int, default=10)
    parser.add_argument("--pool_size", type=int, default=10)
    parser.add_argument("--recovery_interval", type=int, default=10)
    parser.add_argument("--trial_id", type=int, default=42)
    parser.add_argument("--selection", type=int, default=-1)
    parser.add_argument("--slice_num", type=int, default=4)
    parser.add_argument("--tta", type=int, default=0)
    parser.add_argument("--update_iter", type=int, default=1300)
    parser.add_argument("--val_batch_mul", type=int, default=4)
    parser.add_argument("--warmup_epochs", type=int, default=3)
    parser.add_argument("--workers", type=int, default=4)
    # float
    parser.add_argument("--color_jitter", type=float, default=0.4)
    parser.add_argument("--decay_rate", type=float, default=0.1)
    parser.add_argument("--dropout_rate", type=float, default=0.0)
    parser.add_argument("--ema_decay", type=float, default=0.998)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--meta_lr", type=float, default=1e-4)
    parser.add_argument("--re_prob", type=float, default=0.2)
    parser.add_argument("--opt_eps", type=float, default=1e-2)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--min_lr", type=float, default=1e-5)
    parser.add_argument("--smoothing", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--warmup_lr", type=float, default=1e-4)
    # bool
    parser.add_argument("--auto_resume", type=str2bool, default='False')
    parser.add_argument("--dil_conv", type=str2bool, default='False')
    parser.add_argument("--ema_cpu", type=str2bool, default='False')
    parser.add_argument("--pin_mem", type=str2bool, default='True')
    parser.add_argument("--resunit", type=str2bool, default='False')
    parser.add_argument("--save_images", type=str2bool, default='False')
    parser.add_argument("--sync_bn", type=str2bool, default='False')
    parser.add_argument("--use_ema", type=str2bool, default='False')
    parser.add_argument("--verbose", type=str2bool, default='False')
    # str
    parser.add_argument("--aa", type=str, default='rand-m9-mstd0.5')
    parser.add_argument("--eval_metrics", type=str, default='prec1')
    # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"]
    parser.add_argument("--gp", type=str, default='avg')
    parser.add_argument("--interpolation", type=str, default='bilinear')
    parser.add_argument("--opt", type=str, default='sgd')
    parser.add_argument("--pick_method", type=str, default='meta')
    parser.add_argument("--re_mode", type=str, default='pixel')
    parser.add_argument("--sched", type=str, default='sgd')
    args = parser.parse_args()
    # hard-coded overrides of the parsed values
    args.sync_bn = False
    args.verbose = False
    args.data_dir = args.data_dir + "/imagenet"
    return args
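
# A hypothetical invocation (flag values are illustrative, not taken from the repo docs):
#   python retrainer.py --selection 14 --data_dir ./data --batch_size 32 --num_gpu 1
# parse_args() appends "/imagenet" to --data_dir, so the loaders in main() expect
# ./data/imagenet/train and ./data/imagenet/val in torchvision ImageFolder layout.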


def main():
    args = parse_args()
    mkdirs(args.checkpoint_dir + "/",
           args.experiment_dir,
           args.best_selected_space_path,
           args.result_path)
    with open(args.result_path, "w") as ss_file:
        ss_file.write('')
    if len(args.checkpoint_dir) > 1:
        mkdirs(args.best_checkpoint_dir + "/")
    args.checkpoint_dir = os.path.join(
        args.checkpoint_dir,
        "{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
    )
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    # resolve logging
    if args.local_rank == 0:
        logger = get_logger(args.log_path)
        writer = None  # SummaryWriter(os.path.join(output_dir, 'runs'))
    else:
        writer, logger = None, None
    # retrain model selection
    if args.selection == -1:
        if os.path.exists(args.best_selected_space_path):
            with open(args.best_selected_space_path, "r") as f:
                arch_list = json.load(f)['selected_space']
        else:
            args.selection = 14
            logger.warning("args.best_selected_space_path does not exist. Falling back to selection 14.")
    if args.selection == 481:
        arch_list = [[0], [3, 4, 3, 1], [3, 2, 3, 0], [3, 3, 3, 1],
                     [3, 3, 3, 3], [3, 3, 3, 3], [0]]
        args.image_size = 224
    elif args.selection == 43:
        arch_list = [[0], [3], [3, 1], [3, 1], [3, 3, 3], [3, 3], [0]]
        args.image_size = 96
    elif args.selection == 14:
        arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]]
        args.image_size = 64
    elif args.selection == 112:
        arch_list = [[0], [3], [3, 3], [3, 3], [3, 3, 3], [3, 3], [0]]
        args.image_size = 160
    elif args.selection == 287:
        arch_list = [[0], [3], [3, 3], [3, 1, 3], [3, 3, 3, 3], [3, 3, 3], [0]]
        args.image_size = 224
    elif args.selection == 604:
        arch_list = [[0], [3, 3, 2, 3, 3], [3, 2, 3, 2, 3], [3, 2, 3, 2, 3],
                     [3, 3, 2, 2, 3, 3], [3, 3, 2, 3, 3, 3], [0]]
        args.image_size = 224
    elif args.selection == -1:
        args.image_size = 224
    else:
        raise ValueError("Model Retrain Selection is not Supported!")
    print(arch_list)
    # define childnet architecture from arch_list
    stem = ['ds_r1_k3_s1_e1_c16_se0.25', 'cn_r1_k1_s1_c320_se0.25']
    # TODO: this param from NNI is different from microsoft/Cream.
    choice_block_pool = ['ir_r1_k3_s2_e4_c24_se0.25',
                         'ir_r1_k5_s2_e4_c40_se0.25',
                         'ir_r1_k3_s2_e6_c80_se0.25',
                         'ir_r1_k3_s1_e6_c96_se0.25',
                         'ir_r1_k5_s2_e6_c192_se0.25']
    arch_def = [[stem[0]]] + [[choice_block_pool[idx]
                               for repeat_times in range(len(arch_list[idx + 1]))]
                              for idx in range(len(choice_block_pool))] + [[stem[1]]]
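    # For example (derived from the construction above), selection 14 with
    # arch_list = [[0], [3], [3, 3], [3, 3], [3], [3], [0]] expands to:
    #   arch_def = [['ds_r1_k3_s1_e1_c16_se0.25'],       # stem
    #               ['ir_r1_k3_s2_e4_c24_se0.25'],       # stage depth 1
    #               ['ir_r1_k5_s2_e4_c40_se0.25'] * 2,   # stage depth 2
    #               ['ir_r1_k3_s2_e6_c80_se0.25'] * 2,
    #               ['ir_r1_k3_s1_e6_c96_se0.25'],
    #               ['ir_r1_k5_s2_e6_c192_se0.25'],
    #               ['cn_r1_k1_s1_c320_se0.25']]         # tail
    # i.e. len(arch_list[idx + 1]) only sets each stage's depth here; the op
    # indices themselves are consumed by gen_childnet below.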
    # generate childnet
    model = gen_childnet(
        arch_list,
        arch_def,
        num_classes=args.num_classes,
        drop_rate=args.dropout_rate,
        global_pool=args.gp)
    # initialize distributed parameters
    distributed = args.num_gpu > 1
    torch.cuda.set_device(args.local_rank)
    if args.local_rank == 0:
        logger.info(
            'Training on Process {} with {} GPUs.'.format(
                args.local_rank, args.num_gpu))
    # fix random seeds
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # get parameters and FLOPs of model
    if args.local_rank == 0:
        macs, params = get_model_flops_params(model, input_size=(
            1, 3, args.image_size, args.image_size))
        logger.info(
            '[Model-{}] Flops: {} Params: {}'.format(args.selection, macs, params))
    # create optimizer
    model = model.cuda()
    optimizer = create_optimizer(args, model)
    # optionally resume from a checkpoint
    resume_epoch = None
    if args.auto_resume:
        if int(timm.__version__[2]) >= 3:  # crude minor-version check of a '0.x.y' string
            resume_epoch = resume_checkpoint(model, args.experiment_dir, optimizer)
        else:
            resume_state, resume_epoch = resume_checkpoint(model, args.experiment_dir)
            optimizer.load_state_dict(resume_state['optimizer'])
            del resume_state
    model_ema = None
    if args.use_ema:
        model_ema = ModelEma(
            model,
            decay=args.ema_decay,
            device='cpu' if args.ema_cpu else '',
            resume=args.experiment_dir if args.auto_resume else None)
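    # Note: when use_ema is set (and ema_cpu is not), the EMA weights are the
    # ones validated in the training loop below -- eval_metrics is overwritten
    # with the EMA results, so checkpoint selection follows the EMA model.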
    # initialize training parameters
    eval_metric = args.eval_metrics
    best_metric, best_epoch, saver = None, None, None
    if args.local_rank == 0:
        decreasing = True if eval_metric == 'loss' else False
        if int(timm.__version__[2]) >= 3:
            saver = CheckpointSaver(model, optimizer,
                                    checkpoint_dir=args.checkpoint_dir,
                                    recovery_dir=args.checkpoint_dir,
                                    model_ema=model_ema,
                                    decreasing=decreasing,
                                    max_history=2)
        else:
            saver = CheckpointSaver(
                checkpoint_dir=args.checkpoint_dir,
                recovery_dir=args.checkpoint_dir,
                decreasing=decreasing,
                max_history=2)
    if distributed:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        if args.sync_bn:
            try:
                if HAS_APEX:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                if args.local_rank == 0:
                    logger.info('Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                if args.local_rank == 0:
                    logger.error(
                        'Failed to enable Synchronized BatchNorm. '
                        'Install Apex or Torch >= 1.1. Exception: {}'.format(e))
        if HAS_APEX:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logger.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
            # can use device str in Torch >= 1.1
            model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)
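    # init_method='env://' makes init_process_group read MASTER_ADDR, MASTER_PORT,
    # RANK and WORLD_SIZE from the environment, so multi-GPU runs are expected to
    # be launched via torch.distributed.launch (which also supplies --local_rank).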
    # imagenet train dataset
    train_dir = os.path.join(args.data_dir, 'train')
    if not os.path.exists(train_dir) and args.local_rank == 0:
        logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.batch_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=0,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=args.workers,
        distributed=distributed,
        collate_fn=None,
        pin_memory=args.pin_mem,
        interpolation='random',
        re_mode=args.re_mode,
        re_prob=args.re_prob
    )
    # imagenet validation dataset
    eval_dir = os.path.join(args.data_dir, 'val')
    if not os.path.exists(eval_dir) and args.local_rank == 0:
        logger.error(
            'Validation folder does not exist at: {}'.format(eval_dir))
        exit(1)
    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.val_batch_mul * args.batch_size,
        is_training=False,
        interpolation=args.interpolation,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=args.workers,
        distributed=distributed,
        pin_memory=args.pin_mem
    )
    # whether to use label smoothing
    if args.smoothing > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
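    # timm's LabelSmoothingCrossEntropy softens the targets: the true class keeps
    # 1 - smoothing of the probability mass and the remainder is spread uniformly
    # over all classes. Validation uses plain CrossEntropyLoss in both branches,
    # so reported eval losses stay comparable across smoothing settings.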
    # create learning rate scheduler
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        best_record, best_ep = 0, 0
        for epoch in range(start_epoch, num_epochs):
            if distributed:
                loader_train.sampler.set_epoch(epoch)
            train_metrics = train_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=args.checkpoint_dir,
                model_ema=model_ema,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank)
            eval_metrics = validate(
                epoch,
                model,
                loader_eval,
                validate_loss_fn,
                args,
                logger=logger,
                writer=writer,
                local_rank=args.local_rank,
                result_path=args.result_path
            )
            if model_ema is not None and not args.ema_cpu:
                ema_eval_metrics = validate(
                    epoch,
                    model_ema.ema,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    log_suffix='_EMA',
                    logger=logger,
                    writer=writer,
                    local_rank=args.local_rank
                )
                eval_metrics = ema_eval_metrics
            if lr_scheduler is not None:
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
            update_summary(epoch, train_metrics, eval_metrics, os.path.join(
                args.checkpoint_dir, 'summary.csv'), write_header=best_metric is None)
            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                if int(timm.__version__[2]) >= 3:
                    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
                else:
                    best_metric, best_epoch = saver.save_checkpoint(
                        model, optimizer, args,
                        epoch=epoch, metric=save_metric)
            if best_record < eval_metrics[eval_metric]:
                best_record = eval_metrics[eval_metric]
                best_ep = epoch
            if args.local_rank == 0:
                logger.info(
                    '*** Best metric: {0} (epoch {1})'.format(best_record, best_ep))
    except KeyboardInterrupt:
        pass

    if best_metric is not None:
        logger.info(
            '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
        save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)


if __name__ == '__main__':
    main()
