trainer.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import json
import logging
from copy import deepcopy

import numpy as np
import torch

from pytorch.trainer import Trainer
from pytorch.utils import AverageMeterGroup
from .utils import accuracy, reduce_metrics

logger = logging.getLogger(__name__)

class CreamSupernetTrainer(Trainer):
    """
    This trainer trains a supernet and outputs prioritized architectures
    that can be used for other tasks.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    val_loss : callable
        Called with logits and targets for validation only. Returns a loss tensor.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raises ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raises ``StopIteration`` when one epoch is exhausted.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    batch_size : int
        Batch size.
    log_frequency : int
        Number of mini-batches between metric logs.
    meta_sta_epoch : int
        Epoch at which the meta matching network starts picking the teacher architecture.
    update_iter : int
        Interval (in batches) between updates of the meta matching network.
    slices : int
        Batch size of the mini training data used to train the meta matching network.
    pool_size : int
        Size of the prioritized board.
    pick_method : str
        How to pick the teacher network.
    choice_num : int
        Number of operations in the supernet.
    sta_num : tuple of int
        Number of layers in each stage of the supernet (5 stages in total).
    acc_gap : int
        Minimum accuracy improvement that waives the FLOPs limitation.
    flops_dict : dict
        Dictionary of the FLOPs of each layer's operations in the supernet.
    flops_fixed : int
        FLOPs of the fixed part of the supernet.
    local_rank : int
        Index of the current rank.
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
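
    Examples
    --------
    A minimal construction sketch; ``my_model``, ``my_mutator``, the two
    data loaders and ``my_flops_dict`` are hypothetical placeholders rather
    than objects defined in this file::

        trainer = CreamSupernetTrainer(
            'selected_space.json', my_model, train_criterion, val_criterion,
            sgd_optimizer, num_epochs=120,
            train_loader=my_train_loader, valid_loader=my_valid_loader,
            mutator=my_mutator, batch_size=128, log_frequency=10,
            flops_dict=my_flops_dict, result_path='result.json')
        trainer.train()  # assuming the base Trainer exposes train()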
  62. """
    def __init__(self, selected_space, model, loss, val_loss,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, log_frequency=None,
                 meta_sta_epoch=20, update_iter=200, slices=2,
                 pool_size=10, pick_method='meta', choice_num=6,
                 sta_num=(4, 4, 4, 4, 4), acc_gap=5,
                 flops_dict=None, flops_fixed=0, local_rank=0,
                 callbacks=None, result_path=None):
        assert torch.cuda.is_available()
        super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
                                                   optimizer, num_epochs, None, None,
                                                   batch_size, None, None,
                                                   log_frequency, callbacks)
        self.selected_space = selected_space
        self.model = model
        self.loss = loss
        self.val_loss = val_loss
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.meta_sta_epoch = meta_sta_epoch
        self.update_iter = update_iter
        self.slices = slices
        self.pick_method = pick_method
        self.pool_size = pool_size
        self.local_rank = local_rank
        self.choice_num = choice_num
        self.sta_num = sta_num
        self.acc_gap = acc_gap
        self.flops_dict = flops_dict
        self.flops_fixed = flops_fixed
        self.current_student_arch = None
        self.current_teacher_arch = None
        self.main_proc = (local_rank == 0)
        self.current_epoch = 0
        self.prioritized_board = []
        self.result_path = result_path
    # size of prioritized board
    def _board_size(self):
        return len(self.prioritized_board)

    # select teacher architecture according to the logit difference
    def _select_teacher(self):
        self._replace_mutator_cand(self.current_student_arch)

        if self.pick_method == 'top1':
            meta_value, teacher_cand = 0.5, sorted(
                self.prioritized_board, reverse=True)[0][3]
        elif self.pick_method == 'meta':
            meta_value, cand_idx, teacher_cand = -1000000000, -1, None
            for now_idx, item in enumerate(self.prioritized_board):
                inputx = item[4]
                output = torch.nn.functional.softmax(self.model(inputx), dim=1)
                # score the candidate by the meta matching network
                weight = self.model.forward_meta(output - item[5])
                if weight > meta_value:
                    meta_value = weight
                    cand_idx = now_idx
                    teacher_cand = self.prioritized_board[cand_idx][3]
            assert teacher_cand is not None
            meta_value = torch.sigmoid(-weight)
        else:
            raise ValueError('Method not supported')

        return meta_value, teacher_cand

    # check whether to update prioritized board
    def _isUpdateBoard(self, prec1, flops):
        if self.current_epoch <= self.meta_sta_epoch:
            return False
        if len(self.prioritized_board) < self.pool_size:
            return True
        if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
            return True
        if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
            return True
        return False

    # update prioritized board
    def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
        if self._isUpdateBoard(prec1, flops):
            val_prec1 = prec1
            training_data = deepcopy(inputs[:self.slices].detach())
            if len(self.prioritized_board) == 0:
                features = deepcopy(outputs[:self.slices].detach())
            else:
                features = deepcopy(teacher_output[:self.slices].detach())
            # each board entry: (val_prec1, prec1, flops, architecture,
            # cached input slice, cached soft labels)
            self.prioritized_board.append(
                (val_prec1,
                 prec1,
                 flops,
                 self.current_student_arch,
                 training_data,
                 torch.nn.functional.softmax(features, dim=1)))
            self.prioritized_board = sorted(self.prioritized_board, reverse=True)

        if len(self.prioritized_board) > self.pool_size:
            self.prioritized_board = sorted(self.prioritized_board, reverse=True)
            del self.prioritized_board[-1]

    # only update student network weights
    def _update_student_weights_only(self, grad_1):
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            weight.grad = grad_item
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(self.current_student_arch), 1)
        self.optimizer.step()
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            del weight.grad

    # only update meta network weights
    def _update_meta_weights_only(self, teacher_cand, grad_teacher):
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            weight.grad = grad_item
        # clip gradients
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(
                self.current_student_arch, self.pick_method == 'meta'), 1)
        self.optimizer.step()
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            del weight.grad

    # simulate SGD updating
    def _simulate_sgd_update(self, w, g, optimizer):
        return g * optimizer.param_groups[-1]['lr'] + w
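
    # Note: the simulated weight
    #     w' = w + lr * g
    # is never written back into the model; it only creates a
    # differentiable path from the KD gradient through the student
    # update, so the second-order gradient in _calculate_2nd_gradient
    # can reach the meta matching network. For example, with lr = 0.1,
    # w = 1.0 and g = 0.5, the simulated weight is 1.0 + 0.1 * 0.5 = 1.05.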
    # split training images into several slices
    def _get_minibatch_input(self, input):
        slice = self.slices
        x = deepcopy(input[:slice].clone().detach())
        return x

    # calculate 1st gradient of student architectures
    def _calculate_1st_gradient(self, kd_loss):
        self.optimizer.zero_grad()
        grad = torch.autograd.grad(
            kd_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            create_graph=True)
        return grad

    # calculate 2nd gradient of meta networks
    def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
        self.optimizer.zero_grad()
        grad_student_val = torch.autograd.grad(
            validation_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            retain_graph=True)
        grad_teacher = torch.autograd.grad(
            students_weight[0],
            self.model.module.rand_parameters(
                teacher_cand,
                self.pick_method == 'meta'),
            grad_outputs=grad_student_val)
        return grad_teacher
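
    # Note: passing grad_student_val as grad_outputs applies the chain rule
    #     dL_val/d(theta_teacher) = dL_val/d(w') * d(w')/d(theta_teacher),
    # where w' are the simulated student weights from _simulate_sgd_update,
    # so the meta matching network is credited by how much its teacher
    # choice improves the student's validation loss.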
    # forward training data
    def _forward_training(self, x, meta_value):
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(x)

        with torch.no_grad():
            self._replace_mutator_cand(self.current_teacher_arch)
            teacher_output = self.model(x)
            soft_label = torch.nn.functional.softmax(teacher_output, dim=1)

        kd_loss = meta_value * \
            self._cross_entropy_loss_with_soft_target(output, soft_label)
        return kd_loss

    # calculate soft target loss
    def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

    # forward validation data
    def _forward_validation(self, input, target):
        slice = self.slices
        x = input[slice:slice * 2].clone()
        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)
        validation_loss = self.loss(output_2, target[slice:slice * 2])
        return validation_loss

    # check whether to update the meta matching network
    def _isUpdateMeta(self, batch_idx):
        isUpdate = True
        isUpdate &= (self.current_epoch > self.meta_sta_epoch)
        isUpdate &= (batch_idx > 0)
        isUpdate &= (batch_idx % self.update_iter == 0)
        isUpdate &= (self._board_size() > 0)
        return isUpdate
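
    # With the defaults (meta_sta_epoch=20, update_iter=200), this means
    # the meta matching network is updated once every 200 training batches,
    # and only after epoch 20 once the prioritized board is non-empty.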
    def _replace_mutator_cand(self, cand):
        self.mutator._cache = cand

    # update meta matching networks
    def _run_update(self, input, target, batch_idx):
        if self._isUpdateMeta(batch_idx):
            x = self._get_minibatch_input(input)
            meta_value, teacher_cand = self._select_teacher()
            kd_loss = self._forward_training(x, meta_value)

            # calculate 1st gradient
            grad_1st = self._calculate_1st_gradient(kd_loss)

            # simulate updated student weights
            students_weight = [
                self._simulate_sgd_update(
                    p, grad_item, self.optimizer) for p, grad_item in zip(
                    self.model.module.rand_parameters(self.current_student_arch), grad_1st)]

            # update student weights
            self._update_student_weights_only(grad_1st)

            validation_loss = self._forward_validation(input, target)

            # calculate 2nd gradient
            grad_teacher = self._calculate_2nd_gradient(validation_loss,
                                                        teacher_cand,
                                                        students_weight)

            # update meta matching networks
            self._update_meta_weights_only(teacher_cand, grad_teacher)

            # delete intermediate variables
            del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight

    # calculate the FLOPs of a candidate architecture
    def _get_cand_flops(self, cand):
        flops = 0
        for block_id, block in enumerate(cand):
            # these two choice blocks are excluded from the FLOPs count
            if block in ('LayerChoice1', 'LayerChoice23'):
                continue
            for idx, choice in enumerate(cand[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed
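
    # flops_dict is expected to map each choice block to the per-operation
    # FLOPs of its candidates, e.g. {0: [11.2e6, 15.8e6, ...], 1: [...], ...}
    # (illustrative numbers only); the one-hot `choice` entries select which
    # operations contribute to the total.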
    def train_one_epoch(self, epoch):
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)
                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)
                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()
                    soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
                kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)

                # mix the distillation loss with the ground-truth loss,
                # weighted by the teacher's meta value
                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data,
                                           teacher_output,
                                           output,
                                           metrics['prec1'],
                                           cand_flops)

            if self.main_proc and (
                    step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
                            step + 1, len(self.train_loader), meters)

        # dump the best architecture on the prioritized board
        arch_list = []
        # if self.main_proc and self.num_epochs == epoch + 1:
        for idx, i in enumerate(self.prioritized_board):
            # logger.info("prioritized_board: No.%s %s", idx, i[:4])
            if idx == 0:
                for arch in list(i[3].values()):
                    _ = arch.numpy()
                    _ = np.where(_)[0].tolist()
                    arch_list.append(_)
        if len(arch_list) > 0:
            with open(self.selected_space, "w") as f:
                print("dump selected space.")
                json.dump({'selected_space': arch_list}, f)

    def validate_one_epoch(self, epoch):
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)
        # append this epoch's accuracy record as one JSON line
        if self.result_path is not None:
            with open(self.result_path, "a") as ss_file:
                ss_file.write(json.dumps(
                    {'type': 'Accuracy',
                     'result': {'sequence': epoch,
                                'category': 'epoch',
                                'value': metrics["prec1"]}}) + '\n')
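
For downstream tooling: validate_one_epoch appends one JSON object per epoch to result_path. A minimal reader sketch, assuming the trainer was given result_path='result.json' (hypothetical filename):

    import json

    with open('result.json') as f:               # the path passed as result_path
        for line in f:                           # one JSON object per line/epoch
            record = json.loads(line)
            print(record['result']['sequence'],  # epoch index
                  record['result']['value'])     # top-1 accuracy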
