trainer.py

# https://github.com/microsoft/nni/blob/v2.0/examples/nas/cream/train.py
import sys
sys.path.append('../..')

import os
import time
import json

import torch
import numpy as np
import torch.nn as nn
from argparse import ArgumentParser

# import timm packages
from timm.loss import LabelSmoothingCrossEntropy
from timm.data import Dataset, create_loader
from timm.models import resume_checkpoint

# import apex as distributed package
# try:
#     from apex.parallel import DistributedDataParallel as DDP
#     from apex.parallel import convert_syncbn_model
#
#     USE_APEX = True
# except ImportError as e:
#     print(e)
#     from torch.nn.parallel import DistributedDataParallel as DDP
#
#     USE_APEX = False

# import models and training functions
from lib.utils.flops_table import FlopsEst
from lib.models.structures.supernet import gen_supernet
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from lib.utils.util import get_logger, \
    create_optimizer_supernet, create_supernet_scheduler

from pytorch.utils import mkdirs, str2bool
from pytorch.callbacks import LRSchedulerCallback
from pytorch.callbacks import ModelCheckpoint
from algorithms import CreamSupernetTrainer
from algorithms import RandomMutator


def parse_args():
    """See lib.utils.config"""
    parser = ArgumentParser()

    # path
    parser.add_argument("--checkpoint_dir", type=str, default='')
    parser.add_argument("--data_dir", type=str, default='./data')
    parser.add_argument("--experiment_dir", type=str, default='./')
    parser.add_argument("--model_name", type=str, default='trainer')
    parser.add_argument("--log_path", type=str, default='output/log')
    parser.add_argument("--result_path", type=str, default='output/result.json')
    parser.add_argument("--search_space_path", type=str, default='output/search_space.json')
    parser.add_argument("--best_selected_space_path", type=str,
                        default='output/selected_space.json')

    # int
    parser.add_argument("--acc_gap", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--flops_minimum", type=int, default=0)
    parser.add_argument("--flops_maximum", type=int, default=200)
    parser.add_argument("--image_size", type=int, default=224)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--meta_sta_epoch", type=int, default=20)
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--num_gpu", type=int, default=1)
    parser.add_argument("--pool_size", type=int, default=10)
    parser.add_argument("--trial_id", type=int, default=42)
    parser.add_argument("--slice_num", type=int, default=4)
    parser.add_argument("--tta", type=int, default=0)
    parser.add_argument("--update_iter", type=int, default=1300)
    parser.add_argument("--workers", type=int, default=4)

    # float
    parser.add_argument("--color_jitter", type=float, default=0.4)
    parser.add_argument("--dropout_rate", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--meta_lr", type=float, default=1e-4)
    parser.add_argument("--opt_eps", type=float, default=1e-2)
    parser.add_argument("--re_prob", type=float, default=0.2)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--smoothing", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=1e-4)

    # bool
    parser.add_argument("--auto_resume", type=str2bool, default='False')
    parser.add_argument("--dil_conv", type=str2bool, default='False')
    parser.add_argument("--resunit", type=str2bool, default='False')
    parser.add_argument("--sync_bn", type=str2bool, default='False')
    parser.add_argument("--verbose", type=str2bool, default='False')

    # str
    # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"]
    parser.add_argument("--gp", type=str, default='avg')
    parser.add_argument("--interpolation", type=str, default='bilinear')
    parser.add_argument("--opt", type=str, default='sgd')
    parser.add_argument("--pick_method", type=str, default='meta')
    parser.add_argument("--re_mode", type=str, default='pixel')

    args = parser.parse_args()
    args.sync_bn = False
    args.verbose = False
    args.data_dir = args.data_dir + "/imagenet"
    return args
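

# Note on the bool flags above: argparse also applies the `type` callable to
# string defaults, so default='False' is converted through str2bool at parse
# time. The real helper is imported from pytorch.utils; a minimal sketch of the
# assumed behaviour (hypothetical, for illustration only) would be roughly:
#
#     def str2bool(value):
#         return str(value).strip().lower() in ('true', '1', 'yes', 'y')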


def main():
    args = parse_args()
    mkdirs(args.experiment_dir,
           args.best_selected_space_path,
           args.search_space_path,
           args.result_path,
           args.log_path)
    with open(args.result_path, "w") as ss_file:
        ss_file.write('')

    # resolve logging
    if len(args.checkpoint_dir) > 1:
        mkdirs(args.checkpoint_dir + "/")
        args.checkpoint_dir = os.path.join(
            args.checkpoint_dir,
            "{}_{}".format(args.model_name, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
        )
        if not os.path.exists(args.checkpoint_dir):
            os.mkdir(args.checkpoint_dir)

    if args.local_rank == 0:
        logger = get_logger(args.log_path)
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    # torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info(
            'Training on Process %d with %d GPUs.',
            args.local_rank, args.num_gpu)

    # fix random seeds
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet and optimizer
    model, sta_num, resolution, search_space = gen_supernet(
        flops_minimum=args.flops_minimum,
        flops_maximum=args.flops_maximum,
        num_classes=args.num_classes,
        drop_rate=args.dropout_rate,
        global_pool=args.gp,
        resunit=args.resunit,
        dil_conv=args.dil_conv,
        slice=args.slice_num,
        verbose=args.verbose,
        logger=logger)
    optimizer = create_optimizer_supernet(args, model)

    # number of choice blocks in supernet
    choice_num = len(model.blocks[7])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    sum([m.numel() for m in model.parameters()]))
        logger.info('resolution: %d', resolution)
        logger.info('choice number: %d', choice_num)
        with open(args.search_space_path, "w") as f:
            print("dump search space.")
            json.dump({'search_space': search_space}, f)

    # initialize flops look-up table
    model_est = FlopsEst(model)
    flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed
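
    # The FLOPs lookup built just above (flops_dict / flops_fixed) is passed to
    # CreamSupernetTrainer below, presumably so the trainer can estimate the
    # cost of each sampled path without re-profiling the supernet.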

    model = model.cuda()

    # convert model to distributed mode
    if args.sync_bn:
        try:
            # if USE_APEX:
            #     model = convert_syncbn_model(model)
            # else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or Torch >= 1.1 with Exception %s', exception)

    # if USE_APEX:
    #     model = DDP(model, delay_allreduce=True)
    # else:
    #     if args.local_rank == 0:
    #         logger.info(
    #             "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
    #     # can use device str in Torch >= 1.1
    #     model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)

    # optionally resume from a checkpoint
    resume_epoch = None
    if False:  # args.auto_resume:
        checkpoint = torch.load(args.experiment_dir)
        model.load_state_dict(checkpoint['child_model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        resume_epoch = checkpoint['epoch']

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(optimizer, args.epochs, args.num_gpu,
                                                         args.batch_size, args.lr)
    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(args.data_dir, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        re_prob=args.re_prob,
        re_mode=args.re_mode,
        color_jitter=args.color_jitter,
        interpolation='random',
        num_workers=args.workers,
        distributed=False,
        collate_fn=None,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD
    )

    # imagenet validation dataset
    eval_dir = os.path.join(args.data_dir, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, args.image_size, args.image_size),
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        num_workers=args.workers,
        distributed=False,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        interpolation=args.interpolation
    )

    # whether to use label smoothing
    if args.smoothing > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    mutator = RandomMutator(model)

    _callbacks = [LRSchedulerCallback(lr_scheduler)]
    if len(args.checkpoint_dir) > 1:
        _callbacks.append(ModelCheckpoint(args.checkpoint_dir))

    trainer = CreamSupernetTrainer(args.best_selected_space_path, model, train_loss_fn,
                                   validate_loss_fn,
                                   optimizer, num_epochs, loader_train, loader_eval,
                                   result_path=args.result_path,
                                   mutator=mutator,
                                   batch_size=args.batch_size,
                                   log_frequency=args.log_interval,
                                   meta_sta_epoch=args.meta_sta_epoch,
                                   update_iter=args.update_iter,
                                   slices=args.slice_num,
                                   pool_size=args.pool_size,
                                   pick_method=args.pick_method,
                                   choice_num=choice_num,
                                   sta_num=sta_num,
                                   acc_gap=args.acc_gap,
                                   flops_dict=flops_dict,
                                   flops_fixed=flops_fixed,
                                   local_rank=args.local_rank,
                                   callbacks=_callbacks)
    trainer.train()


if __name__ == '__main__':
    main()
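
# Example launch (illustrative values; all flags are defined in parse_args above,
# and the script expects ImageNet under <data_dir>/imagenet/{train,val} because
# parse_args appends "/imagenet" to --data_dir):
#
#     python trainer.py --data_dir ./data --batch_size 64 --epochs 120 --num_gpu 1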
