
retrainer.py 18 kB

import sys
from utils import accuracy
import torch
from torch import nn
from torch.utils.data import DataLoader
import datasets
import time
import logging
import os
import argparse
import distutils.util
import numpy as np
import json
import random

sys.path.append('../..')
# import custom packages
from macro import GeneralNetwork
from micro import MicroNetwork
from pytorch.fixed import apply_fixed_architecture
from pytorch.retrainer import Retrainer
from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint, mkdirs


class EnasRetrainer(Retrainer):
    """
    ENAS retrainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    data_dir : str
        Path of the dataset.
    best_checkpoint_dir : str
        Directory for saving the best model checkpoint.
    batch_size : int
        Batch size for training.
    eval_batch_size : int
        Batch size for evaluation.
    num_epochs : int
        Number of epochs planned for training.
    lr : float
        Initial learning rate.
    is_cuda : bool
        Whether to use GPU for training.
    log_every : int
        Step count per logging.
    child_grad_bound : float
        Gradient clipping bound.
    child_l2_reg : float
        L2 regularization (weight decay) factor.
    eval_every_epochs : int
        Evaluate every this many epochs.
    logger : logging.Logger
        Logger for training messages.
    result_path : str
        File that per-epoch accuracy results are appended to.

    Attributes
    ----------
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    aux_weight : float
        Weight of auxiliary head loss. ``aux_weight * aux_loss`` is added to the total loss.
    workers : int
        Number of workers for data loading.
    """
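
    # A minimal usage sketch (assuming a model whose architecture has already been
    # fixed with apply_fixed_architecture, and a CIFAR-10 layout under ./data):
    #
    #   retrainer = EnasRetrainer(model=child_model, num_epochs=10,
    #                             result_path='./result.json')
    #   retrainer.train()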

    def __init__(self, model, data_dir='./data', best_checkpoint_dir='./best_checkpoint',
                 batch_size=1024, eval_batch_size=1024, num_epochs=2, lr=0.02, is_cuda=True,
                 log_every=40, child_grad_bound=0.5, child_l2_reg=3e-6, eval_every_epochs=2,
                 logger=logging.getLogger("enas-retrain"), result_path='./result.json'):
        self.aux_weight = 0.4
        # fall back to CPU when CUDA is disabled or unavailable
        self.device = torch.device("cuda:0" if is_cuda and torch.cuda.is_available() else "cpu")
        self.workers = 4
        self.child_model = model
        self.data_dir = data_dir
        self.best_checkpoint_dir = best_checkpoint_dir
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.num_epochs = num_epochs
        self.lr = lr
        self.is_cuda = is_cuda
        self.log_every = log_every
        self.child_grad_bound = child_grad_bound
        self.child_l2_reg = child_l2_reg
        self.eval_every_epochs = eval_every_epochs
        self.logger = logger
        # apply the configured L2 regularization as weight decay
        self.optimizer = torch.optim.SGD(self.child_model.parameters(), self.lr, momentum=0.9,
                                         weight_decay=self.child_l2_reg, nesterov=True)
        self.criterion = nn.CrossEntropyLoss()
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.num_epochs, eta_min=0.001)
        # load dataset
        self.init_dataloader()
        self.child_model.to(self.device)
        self.result_path = result_path
        # truncate any previous result file
        with open(self.result_path, "w") as file:
            file.write('')

    def train(self):
        """
        Train for ``num_epochs`` epochs.

        Runs validation after every epoch and saves the final model as the
        best checkpoint.
        """
        self.logger.info('** Start training **')
        self.start_time = time.time()
        for epoch in range(self.num_epochs):
            self.train_one_epoch(epoch)
            self.child_model.eval()
            self.logger.info("Epoch {}: Eval".format(epoch))
            self.validate_one_epoch(epoch)
            self.lr_scheduler.step()
        self.logger.info("** Save best model **")
        save_best_checkpoint(self.best_checkpoint_dir, self.child_model, self.optimizer, epoch)

    def validate(self):
        """
        Run one validation epoch.
        """
        pass

    def export(self, file):
        """
        Dump the architecture to ``file``.

        Parameters
        ----------
        file : str
            File path to export to. Expected to be a JSON.
        """
        pass

    def checkpoint(self):
        """
        Override to dump a checkpoint.
        """
        pass

    def init_dataloader(self):
        self.logger.info("Build dataloader")
        self.dataset_train, self.dataset_valid = datasets.get_dataset("cifar10", self.data_dir)
        # hold out the last 10% of the training set for validation
        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = DataLoader(self.dataset_train,
                                       batch_size=self.batch_size,
                                       sampler=train_sampler,
                                       num_workers=self.workers)
        self.valid_loader = DataLoader(self.dataset_train,
                                       batch_size=self.eval_batch_size,
                                       sampler=valid_sampler,
                                       num_workers=self.workers)
        # the original test split is kept as a separate loader
        self.test_loader = DataLoader(self.dataset_valid,
                                      batch_size=self.batch_size,
                                      num_workers=self.workers)

    def train_one_epoch(self, epoch):
        """
        Train one epoch.

        Parameters
        ----------
        epoch : int
            Epoch number starting from 0.
        """
        tot_acc = 0
        tot = 0
        losses = []
        step = 0
        self.child_model.train()
        meters = AverageMeterGroup()
        for batch in self.train_loader:
            step += 1
            x, y = batch
            x, y = to_device(x, self.device), to_device(y, self.device)
            logits = self.child_model(x)
            if isinstance(logits, tuple):
                # models with an auxiliary head return (logits, aux_logits)
                logits, aux_logits = logits
                aux_loss = self.criterion(aux_logits, y)
            else:
                aux_loss = 0.
            acc = accuracy(logits, y)
            loss = self.criterion(logits, y)
            loss = loss + self.aux_weight * aux_loss
            self.optimizer.zero_grad()
            loss.backward()
            grad_norm = 0
            if self.child_grad_bound is not None:
                # clip gradients and record the pre-clip norm
                grad_norm = nn.utils.clip_grad_norm_(self.child_model.parameters(),
                                                     self.child_grad_bound)
            self.optimizer.step()
            tot_acc += acc['acc1']
            tot += 1
            losses.append(loss.item())
            acc["loss"] = loss.item()
            meters.update(acc)
            if step % self.log_every == 0:
                curr_time = time.time()
                log_string = "epoch={:<6d}".format(epoch)
                log_string += "ch_step={:<6d}".format(step)
                log_string += " loss={:<8.6f}".format(loss.item())
                log_string += " lr={:<8.4f}".format(self.optimizer.param_groups[0]['lr'])
                log_string += " |g|={:<8.4f}".format(float(grad_norm))
                log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0])
                log_string += " mins={:<10.2f}".format(float(curr_time - self.start_time) / 60)
                self.logger.info(log_string)
        print("Model Epoch [%d/%d] %.3f mins %s \n" % (
            epoch + 1, self.num_epochs, float(time.time() - self.start_time) / 60, meters))
        final_acc = float(tot_acc) / tot
        mean_loss = float(np.mean(losses))
        self.logger.info("epoch={} train_accuracy={:<6.4f} train_loss={:<6.4f}".format(
            epoch, final_acc, mean_loss))

    def validate_one_epoch(self, epoch):
        tot_acc = 0
        tot = 0
        losses = []
        meters = AverageMeterGroup()
        with torch.no_grad():  # save memory
            for batch in self.valid_loader:
                x, y = batch
                x, y = to_device(x, self.device), to_device(y, self.device)
                logits = self.child_model(x)
                if isinstance(logits, tuple):
                    logits, aux_logits = logits
                    aux_loss = self.criterion(aux_logits, y)
                else:
                    aux_loss = 0.
                loss = self.criterion(logits, y)
                loss = loss + self.aux_weight * aux_loss
                preds = logits.argmax(dim=1).long()
                acc = torch.eq(preds, y.long()).long().sum().item()
                acc_v = accuracy(logits, y)
                losses.append(loss.item())
                tot_acc += acc
                tot += len(y)
                acc_v["loss"] = loss.item()
                meters.update(acc_v)
        mean_loss = float(np.mean(losses)) if losses else 0.
        if tot > 0:
            final_acc = float(tot_acc) / tot
        else:
            final_acc = 0
            self.logger.info("Error in calculating final_acc")
        # append one result line per epoch in the experiment's result format
        with open(self.result_path, "a") as file:
            file.write(str({"type": "Accuracy",
                            "result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) + '\n')
        print({"type": "Accuracy",
               "result": {"sequence": epoch, "category": "epoch", "value": final_acc}})
        self.logger.info("test_accuracy={:<6.4f} test_loss={:<6.4f}".format(final_acc, mean_loss))


logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                    level=logging.INFO,
                    filename='./retrain.log',
                    filemode='a')
logger = logging.getLogger("enas-retrain")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data",
        help="Directory containing the dataset and embedding file. (default: %(default)s)")
    parser.add_argument(
        "--model_selected_space_path",
        type=str,
        default="./model_selected_space.json",
        help="Architecture json file. (default: %(default)s)")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="Path of the result file. (default: %(default)s)")
    parser.add_argument("--search_space_path", type=str,
                        default='./search_space.json', help="Search space json file. (default: %(default)s)")
    parser.add_argument("--log_path", type=str, default='output/log')
    parser.add_argument(
        "--best_selected_space_path",
        type=str,
        default="./best_selected_space.json",
        help="Best architecture json file selected by the experiment. (default: %(default)s)")
    parser.add_argument(
        "--best_checkpoint_dir",
        type=str,
        default="best_checkpoint",
        help="Path for saved checkpoints. (default: %(default)s)")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='Trial id, starting from 0.')
    parser.add_argument("--search_for",
                        choices=["macro", "micro"],
                        default="macro")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for training. (default: %(default)s)")
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=128,
        help="Number of samples per batch for evaluation. (default: %(default)s)")
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of training epochs. (default: %(default)s)")
    parser.add_argument(
        "--lr",
        type=float,
        default=0.02,
        help="Initial learning rate. (default: %(default)s)")
    parser.add_argument(
        "--is_cuda",
        type=distutils.util.strtobool,
        default=True,
        help="Whether to train on GPU. (default: %(default)s)")
    parser.add_argument(
        "--load_checkpoint",
        type=distutils.util.strtobool,
        default=False,
        help="Whether to load a checkpoint. (default: %(default)s)")
    parser.add_argument(
        "--log_every",
        type=int,
        default=50,
        help="Log every this many steps. (default: %(default)s)")
    parser.add_argument(
        "--eval_every_epochs",
        type=int,
        default=1,
        help="Evaluate every this many epochs. (default: %(default)s)")
    parser.add_argument(
        "--child_grad_bound",
        type=float,
        default=5.0,
        help="Threshold for gradient clipping. (default: %(default)s)")
    parser.add_argument(
        "--child_l2_reg",
        type=float,
        default=3e-6,
        help="Weight decay factor. (default: %(default)s)")
    parser.add_argument(
        "--child_lr_decay_scheme",
        type=str,
        default="cosine",
        help="Learning rate annealing strategy; only 'cosine' is supported. (default: %(default)s)")  # TODO: remove
    global FLAGS
    FLAGS = parser.parse_args()


# decode the human-readable selected space into the fixed-architecture format
def convert_selected_space_format(child_fixed_arc):
    with open(child_fixed_arc) as js:
        selected_space = json.load(js)
    ops = selected_space['op_list']
    selected_space.pop('op_list')
    new_selected_space = {}
    for key, value in selected_space.items():
        if FLAGS.search_for == 'macro':
            # macro keys keep only their trailing index
            new_key = key.split('_')[-1]
        elif FLAGS.search_for == 'micro':
            new_key = key
        if len(value) > 1 or len(value) == 0:
            new_value = value
        elif value[0] in ops:
            # a single op name is replaced by its index in op_list
            new_value = ops.index(value[0])
        else:
            new_value = value[0]
        new_selected_space[new_key] = new_value
    return new_selected_space
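
# A hedged illustration of the conversion above (the exact JSON schema is an
# assumption inferred from the code, not documented here): in macro mode an input
#
#   {"op_list": ["conv3x3", "conv5x5", "maxpool"],
#    "LayerChoice_1": ["conv5x5"],
#    "InputChoice_2": [0, 1]}
#
# would become {"1": 1, "2": [0, 1]}: single op names map to their op_list index,
# while empty or multi-valued entries pass through unchanged.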


def set_random_seed(seed):
    logger.info("set random seed for data reading: {}".format(seed))
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if FLAGS.is_cuda:
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


def main():
    parse_args()
    child_fixed_arc = FLAGS.best_selected_space_path
    # seed from the trial id so each trial is reproducible
    set_random_seed(FLAGS.trial_id)
    mkdirs(FLAGS.result_path, FLAGS.log_path, FLAGS.best_checkpoint_dir)
    # define and load model
    logger.info('** ' + FLAGS.search_for + ' search **')
    fixed_arc = convert_selected_space_format(child_fixed_arc)
    # model: macro search or micro search
    if FLAGS.search_for == 'macro':
        child_model = GeneralNetwork()
    elif FLAGS.search_for == 'micro':
        child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5,
                                   dropout_rate=0.1, use_aux_heads=True)
    apply_fixed_architecture(child_model, fixed_arc)
    # optionally resume from a saved checkpoint
    if FLAGS.load_checkpoint:
        print('** Load model **')
        logger.info('** Load model **')
        child_model.load_state_dict(torch.load(FLAGS.best_checkpoint_dir)['child_model_state_dict'])
    retrainer = EnasRetrainer(model=child_model,
                              data_dir=FLAGS.data_dir,
                              best_checkpoint_dir=FLAGS.best_checkpoint_dir,
                              batch_size=FLAGS.batch_size,
                              eval_batch_size=FLAGS.eval_batch_size,
                              num_epochs=FLAGS.epochs,
                              lr=FLAGS.lr,
                              is_cuda=FLAGS.is_cuda,
                              log_every=FLAGS.log_every,
                              child_grad_bound=FLAGS.child_grad_bound,
                              child_l2_reg=FLAGS.child_l2_reg,
                              eval_every_epochs=FLAGS.eval_every_epochs,
                              logger=logger,
                              result_path=FLAGS.result_path)
    t1 = time.time()
    retrainer.train()
    print('cost time for retrain:', time.time() - t1)


if __name__ == "__main__":
    main()
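
# Example invocation (a sketch; flag values are illustrative, not prescribed):
#
#   python retrainer.py --search_for macro \
#       --best_selected_space_path ./best_selected_space.json \
#       --data_dir ./data --epochs 10 --batch_size 128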
