retrain.py

import sys
sys.path.append('../..')
import os
import logging
import pickle
import shutil
import random
import math
import time
import datetime
import argparse
import distutils.util
import json

import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as Func

from macro import GeneralNetwork
from micro import MicroNetwork
import datasets
from utils import accuracy, reward_accuracy
from pytorch.fixed import apply_fixed_architecture
from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint

logger = logging.getLogger("enas-retrain")
def set_random_seed(seed):
    logger.info("set random seed for data reading: {}".format(seed))
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
# parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data",
        help="Directory containing the dataset and embedding file. (default: %(default)s)")
    parser.add_argument(
        "--search_space_path",
        type=str,
        default='./search_space.json',
        help="Search space file path. (default: %(default)s)")
    parser.add_argument(
        "--selected_space_path",
        type=str,
        default="./selected_space.json",
        help="Architecture json file. (default: %(default)s)")
    parser.add_argument(
        "--result_path",
        type=str,
        default='./result.json',
        help="Result file path. (default: %(default)s)")
    parser.add_argument(
        "--trial_id",
        type=int,
        default=0,
        metavar='N',
        help="Trial id, starting from 0. (default: %(default)s)")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="The output directory. (default: %(default)s)")
    parser.add_argument(
        "--best_checkpoint_dir",
        type=str,
        default="best_checkpoint",
        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument(
        "--search_for",
        choices=["macro", "micro"],
        default="micro")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for training. (default: %(default)s)")
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for evaluation. (default: %(default)s)")
    parser.add_argument(
        "--class_num",
        type=int,
        default=10,
        help="The number of categories. (default: %(default)s)")
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="The number of training epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_lr",
        type=float,
        default=0.02,
        help="The initial learning rate. (default: %(default)s)")
    parser.add_argument(
        "--is_cuda",
        type=distutils.util.strtobool,
        default=True,
        help="Specify the device type. (default: %(default)s)")
    parser.add_argument(
        "--load_checkpoint",
        type=distutils.util.strtobool,
        default=False,
        help="Whether to load a checkpoint. (default: %(default)s)")
    parser.add_argument(
        "--log_every",
        type=int,
        default=50,
        help="How many steps between logs. (default: %(default)s)")
    parser.add_argument(
        "--eval_every_epochs",
        type=int,
        default=1,
        help="How many epochs between evaluations. (default: %(default)s)")
    parser.add_argument(
        "--child_grad_bound",
        type=float,
        default=5.0,
        help="The threshold for gradient clipping. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_decay_scheme",
        type=str,
        default="cosine",
        help="Learning rate annealing strategy; only 'cosine' is supported. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_0",
        type=int,
        default=10,
        help="The length of one cycle. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_T_mul",
        type=int,
        default=2,
        help="The multiplication factor per cycle. (default: %(default)s)")
    parser.add_argument(
        "--child_l2_reg",
        type=float,
        default=3e-6,
        help="Weight decay factor. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_max",
        type=float,
        default=0.002,
        help="The max learning rate. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_min",
        type=float,
        default=0.001,
        help="The min learning rate. (default: %(default)s)")
    parser.add_argument(
        "--multi_path",
        type=distutils.util.strtobool,
        default=False,
        help="Search for multiple paths in the architecture. (default: %(default)s)")
    parser.add_argument(
        "--is_mask",
        type=distutils.util.strtobool,
        default=True,
        help="Apply mask. (default: %(default)s)")
    global FLAGS
    FLAGS = parser.parse_args()
def print_user_flags(FLAGS, line_limit=80):
    log_strings = "\n" + "-" * line_limit + "\n"
    for flag_name in sorted(vars(FLAGS)):
        value = "{}".format(getattr(FLAGS, flag_name))
        log_string = flag_name
        log_string += "." * (line_limit - len(flag_name) - len(value))
        log_string += value
        log_strings = log_strings + log_string + "\n"
    log_strings += "-" * line_limit
    logger.info(log_strings)
def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None):
    if eval_set == "test":
        assert test_dataloader is not None
        dataloader = test_dataloader
    elif eval_set == "valid":
        assert valid_dataloader is not None
        dataloader = valid_dataloader
    else:
        raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))
    tot_acc = 0
    tot = 0
    losses = []
    with torch.no_grad():  # save memory
        for batch in dataloader:
            x, y = batch
            x, y = to_device(x, device), to_device(y, device)
            logits = child_model(x)
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = criterion(aux_logits, y)
            else:
                aux_loss = 0.
            loss = criterion(logits, y)
            loss = loss + aux_weight * aux_loss  # aux_weight is a module-level constant set below
            preds = logits.argmax(dim=1).long()
            acc = torch.eq(preds, y.long()).long().sum().item()
            losses.append(loss)
            tot_acc += acc
            tot += len(y)
    loss = torch.stack(losses).mean()
    if tot > 0:
        final_acc = float(tot_acc) / tot
    else:
        final_acc = 0
        logger.info("Error in calculating final_acc")
    return final_acc, loss
# cosine learning rate schedule with warm restarts
def update_lr(
        optimizer,
        epoch,
        l2_reg=1e-4,
        lr_warmup_val=None,
        lr_init=0.1,
        lr_decay_scheme="cosine",
        lr_max=0.002,
        lr_min=0.000000001,
        lr_T_0=4,
        lr_T_mul=1,
        sync_replicas=False,
        num_aggregate=None,
        num_replicas=None):
    if lr_decay_scheme == "cosine":
        assert lr_max is not None, "Need lr_max to use lr_cosine"
        assert lr_min is not None, "Need lr_min to use lr_cosine"
        assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
        assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
        # locate the current restart cycle: T_i is its length,
        # last_reset the epoch at which it began
        T_i = lr_T_0
        t_epoch = epoch
        last_reset = 0
        while True:
            t_epoch -= T_i
            if t_epoch < 0:
                break
            last_reset += T_i
            T_i *= lr_T_mul
        T_curr = epoch - last_reset
        rate = T_curr / T_i * math.pi
        learning_rate = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate))
    else:
        raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))
    # update lr in the optimizer
    for params_group in optimizer.param_groups:
        params_group['lr'] = learning_rate
    return learning_rate
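# For reference, the schedule above implements SGDR-style cosine annealing
# with warm restarts (Loshchilov & Hutter, 2017):
#     lr(T_curr) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * T_curr / T_i))
# where T_curr counts epochs since the last restart and T_i is the current
# cycle length. With the values passed from train() below (lr_T_0=10,
# lr_T_mul=2), restarts would fall at epochs 10, 30, 70, ...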
def train(device, output_dir='./output'):
    workers = 4
    data = 'cifar10'
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir
    checkpoint_dir = FLAGS.best_checkpoint_dir
    batch_size = FLAGS.batch_size
    eval_batch_size = FLAGS.eval_batch_size
    class_num = FLAGS.class_num
    epochs = FLAGS.epochs
    child_lr = FLAGS.child_lr
    is_cuda = FLAGS.is_cuda
    load_checkpoint = FLAGS.load_checkpoint
    log_every = FLAGS.log_every
    eval_every_epochs = FLAGS.eval_every_epochs
    child_grad_bound = FLAGS.child_grad_bound
    child_l2_reg = FLAGS.child_l2_reg

    logger.info("Build dataloader")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
    # hold out the last 10% of the training set as a validation split
    n_train = len(dataset_train)
    split = n_train // 10
    indices = list(range(n_train))
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
    train_dataloader = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=batch_size,
                                                   sampler=train_sampler,
                                                   num_workers=workers)
    valid_dataloader = torch.utils.data.DataLoader(dataset_train,
                                                   batch_size=batch_size,
                                                   sampler=valid_sampler,
                                                   num_workers=workers)
    test_dataloader = torch.utils.data.DataLoader(dataset_valid,
                                                  batch_size=batch_size,
                                                  num_workers=workers)

    criterion = nn.CrossEntropyLoss()
    # the initial lr here is overwritten every epoch by update_lr() below
    optimizer = torch.optim.SGD(child_model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4, nesterov=True)
    # optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg)
    # note: update_lr() reimplements the cosine schedule, so the built-in
    # scheduler is not used:
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.001)

    # move model to CPU/GPU device
    child_model.to(device)
    criterion.to(device)

    logger.info('Start training')
    start_time = time.time()
    step = 0

    # save path
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    # model_save_path = os.path.join(output_dir, "model.pth")
    # best_model_save_path = os.path.join(output_dir, "best_model.pth")
    best_acc = 0
    start_epoch = 0
    # TODO: load checkpoints

    # train
    for epoch in range(start_epoch, epochs):
        lr = update_lr(optimizer,
                       epoch,
                       l2_reg=1e-4,
                       lr_warmup_val=None,
                       lr_init=FLAGS.child_lr,
                       lr_decay_scheme=FLAGS.child_lr_decay_scheme,
                       lr_max=0.05,
                       lr_min=0.001,
                       lr_T_0=10,
                       lr_T_mul=2)
        child_model.train()
        for batch in train_dataloader:
            step += 1
            x, y = batch
            x, y = to_device(x, device), to_device(y, device)
            logits = child_model(x)
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = criterion(aux_logits, y)
            else:
                aux_loss = 0.
            acc = accuracy(logits, y)
            loss = criterion(logits, y)
            loss = loss + aux_weight * aux_loss
            optimizer.zero_grad()
            loss.backward()
            # clip the global gradient norm and keep its value for logging
            grad_norm = nn.utils.clip_grad_norm_(child_model.parameters(), child_grad_bound)
            optimizer.step()
            if step % log_every == 0:
                curr_time = time.time()
                log_string = ""
                log_string += "epoch={:<6d}".format(epoch)
                log_string += "ch_step={:<6d}".format(step)
                log_string += " loss={:<8.6f}".format(loss)
                log_string += " lr={:<8.4f}".format(lr)
                log_string += " |g|={:<8.4f}".format(grad_norm)
                log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0])
                log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60)
                logger.info(log_string)
        epoch += 1  # report the number of completed epochs
        save_state = {
            'step': step,
            'epoch': epoch,
            'child_model_state_dict': child_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}
        # torch.save(save_state, model_save_path)
        child_model.eval()
        logger.info("Epoch {}: Eval".format(epoch))
        eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader)
        logger.info(
            "ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss))
        if eval_acc > best_acc:
            best_acc = eval_acc
            logger.info("Save best model")
            save_best_checkpoint(checkpoint_dir, child_model, optimizer, epoch)
        result['accuracy'].append('Epoch {} acc: {:<6.4f}'.format(epoch, eval_acc))
        acc_l.append(eval_acc)
        print(result['accuracy'][-1])
    print('max acc %.4f at epoch: %i' % (max(acc_l), np.argmax(np.array(acc_l))))
    print('Time cost: %.4f hours' % (float(time.time() - start_time) / 3600.))
    return result
parse_args()
child_fixed_arc = FLAGS.selected_space_path  # e.g. './macro_selected_space.json'
search_for = FLAGS.search_for

# set random seeds
torch.manual_seed(FLAGS.trial_id)
torch.cuda.manual_seed_all(FLAGS.trial_id)
np.random.seed(FLAGS.trial_id)
random.seed(FLAGS.trial_id)

aux_weight = 0.4
result = {'accuracy': []}
acc_l = []
# decode the human-readable search space into model-ready indices
def convert_selected_space_format():
    with open(child_fixed_arc) as js:
        selected_space = json.load(js)
    ops = selected_space.pop('op_list')
    new_selected_space = {}
    for key, value in selected_space.items():
        if FLAGS.search_for == 'macro':
            new_key = key.split('_')[-1]
        elif FLAGS.search_for == 'micro':
            new_key = key
        if len(value) != 1:
            new_value = value
        elif value[0] in ops:
            new_value = ops.index(value[0])
        else:
            new_value = value[0]
        new_selected_space[new_key] = new_value
    return new_selected_space

fixed_arc = convert_selected_space_format()
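# Hypothetical example of the conversion (format inferred from the code above,
# not taken from a real search result): with --search_for macro, an input like
#     {"op_list": ["conv3", "sep5"], "mutable_0": ["sep5"], "mutable_1": [0, 1]}
# would become {"0": 1, "1": [0, 1]} -- a single named op maps to its index in
# op_list, while empty or multi-element lists pass through unchanged.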
# build the fixed child model for macro or micro search
if FLAGS.search_for == 'macro':
    child_model = GeneralNetwork()
elif FLAGS.search_for == 'micro':
    child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
apply_fixed_architecture(child_model, fixed_arc)
def dump_global_result(res_path, global_result, sort_keys=False):
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2)

def main():
    # note: CUDA_VISIBLE_DEVICES only takes effect if it is set before the
    # first CUDA context is created
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    device = torch.device("cuda" if FLAGS.is_cuda else "cpu")
    train(device)
    dump_global_result('result_retrain.json', result['accuracy'])

if __name__ == "__main__":
    main()
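# Example invocation (illustrative; the paths and flag values are assumptions,
# adjust them to your setup):
#   python retrain.py --search_for micro \
#       --selected_space_path ./selected_space.json \
#       --data_dir ./data --epochs 10 --batch_size 128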
