
dartstrainer.py 10 kB

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import json
import logging

import torch
import torch.nn as nn

from pytorch.trainer import Trainer
from pytorch.utils import AverageMeterGroup, dump_global_result
from pytorch.darts.dartsmutator import DartsMutator

logger = logging.getLogger(__name__)


class DartsTrainer(Trainer):
    """
    DARTS trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset_train : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    dataset_valid : Dataset
        Dataset for testing.
    search_space_path : str
        Path to which the generated search space is dumped. Skipped if ``None``.
    result_path : str
        Path of the file to which per-epoch accuracy results are appended.
    num_pre_epochs : int
        Number of warm-up epochs during which only the child network weights
        (not the architecture parameters) are trained.
    mutator : DartsMutator
        Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    callbacks : list of Callback
        List of callbacks to trigger at events.
    arch_lr : float
        Learning rate of architecture parameters.
    unrolled : bool
        ``True`` if using second order optimization, else first order optimization.
    """

    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid, search_space_path, result_path,
                 num_pre_epochs=0, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None, arch_lr=3.0E-4, unrolled=False):
        super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                         loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                         batch_size, workers, device, log_frequency, callbacks)
        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arch_lr, betas=(0.5, 0.999), weight_decay=1.0E-3)
        self.unrolled = unrolled
        self.num_pre_epochs = num_pre_epochs
        self.result_path = result_path
        with open(self.result_path, "w") as file:
            file.write('')

        n_train = len(self.dataset_train)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=batch_size,
                                                       num_workers=workers)
        if search_space_path is not None:
            dump_global_result(search_space_path, self.mutator._generate_search_space())
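
    # ``dataset_train`` is split 50/50 above: the first half (``train_loader``)
    # updates the child network weights and the second half (``valid_loader``)
    # updates the architecture parameters, matching the bi-level formulation of
    # DARTS. ``dataset_valid`` is held out and only consumed by
    # ``validate_one_epoch`` through ``test_loader``.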

    def train_one_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
            val_X, val_y = val_X.to(self.device), val_y.to(self.device)

            if epoch >= self.num_pre_epochs:
                # phase 1: architecture step
                self.ctrl_optim.zero_grad()
                if self.unrolled:
                    self._unrolled_backward(trn_X, trn_y, val_X, val_y)
                else:
                    self._backward(val_X, val_y)
                self.ctrl_optim.step()

            # phase 2: child network step
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss(trn_X, trn_y)
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)  # gradient clipping
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                            self.num_epochs, step + 1, len(self.train_loader), meters)
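
    # The loop above performs the DARTS bi-level update: the architecture
    # parameters alpha are stepped on a batch from the validation split
    # (first-order gradient d L_val / d alpha, or the unrolled second-order
    # estimate when ``unrolled=True``), and only then are the child network
    # weights w stepped on a batch from the training split, with gradients
    # clipped to an L2 norm of 5. During the first ``num_pre_epochs`` epochs
    # the architecture step is skipped, so only the weights are warmed up.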

    def validate_one_epoch(self, epoch, log_print=True):
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            self.mutator.reset()
            for step, (X, y) in enumerate(self.test_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.test_loader), meters)
        if log_print:
            # The backend filters this record from the terminal output, e.g.
            # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
            logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}})
            with open(self.result_path, "a") as file:
                file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + '\n')
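
    # Note: each epoch appends one Python dict literal (via ``str``) to
    # ``result_path``. If a downstream consumer expects strict JSON per line,
    # ``json.dumps`` (``json`` is already imported above) would be the natural
    # substitute; this is an observation, not part of the original behaviour.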

    def _logits_and_loss(self, X, y):
        self.mutator.reset()
        logits = self.model(X)
        loss = self.loss(logits, y)
        # self._write_graph_status()
        return logits, loss

    def _backward(self, val_X, val_y):
        """
        Simple backward with gradient descent
        """
        _, loss = self._logits_and_loss(val_X, val_y)
        loss.backward()
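
    # First-order DARTS: alpha is updated with d L_val(w, alpha) / d alpha,
    # treating the current weights w as constants. This is cheaper than the
    # unrolled variant below but ignores how w would react to a change in alpha.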

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
        """
        Compute unrolled loss and backward its gradients
        """
        backup_params = copy.deepcopy(tuple(self.model.parameters()))

        # do virtual step on training data
        lr = self.optimizer.param_groups[0]["lr"]
        momentum = self.optimizer.param_groups[0]["momentum"]
        weight_decay = self.optimizer.param_groups[0]["weight_decay"]
        self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)

        # calculate unrolled loss on validation data
        # keep gradients for the model here to compute the hessian later
        _, loss = self._logits_and_loss(val_X, val_y)
        w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
        d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]

        # compute hessian and final gradients
        hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
        with torch.no_grad():
            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
                # gradient = dalpha - lr * hessian
                param.grad = d - lr * h

        # restore weights
        self._restore_weights(backup_params)
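
    # Second-order DARTS (Liu et al., 2019): after the virtual step
    # w' = w - lr * d L_trn(w, alpha) / dw (plus the optimizer's momentum and
    # weight-decay terms, as applied in ``_compute_virtual_model``), the
    # architecture gradient is approximated as
    #     d L_val(w', alpha) / d alpha
    #         - lr * (d^2 L_trn / (d alpha d w)) @ (d L_val(w', alpha) / d w'),
    # and the mixed second-derivative term is itself approximated by the
    # finite difference computed in ``_compute_hessian`` below.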

    def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
        """
        Compute unrolled weights w`
        """
        # don't need zero_grad, using autograd to calculate gradients
        _, loss = self._logits_and_loss(X, y)
        gradients = torch.autograd.grad(loss, self.model.parameters())
        with torch.no_grad():
            for w, g in zip(self.model.parameters(), gradients):
                m = self.optimizer.state[w].get("momentum_buffer", 0.)
                # in-place SGD-with-momentum step so the virtual weights w` actually replace w
                w -= lr * (momentum * m + g + weight_decay * w)

    def _restore_weights(self, backup_params):
        with torch.no_grad():
            for param, backup in zip(self.model.parameters(), backup_params):
                param.copy_(backup)

    def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        self._restore_weights(backup_params)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm
        if norm < 1E-8:
            logger.warning("In computing hessian, norm is smaller than 1E-8, causing eps to be %.6f.", eps.item())

        dalphas = []
        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, then w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.model.parameters(), dw):
                    p += e * d
            _, loss = self._logits_and_loss(trn_X, trn_y)
            dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
        dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
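
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the trainer). The search model
# ``CNN``, the ``accuracy`` helper and the CIFAR-10 datasets are placeholders
# borrowed from typical DARTS examples, and the sketch assumes the base
# ``pytorch.trainer.Trainer`` exposes a ``train()`` entry point that alternates
# ``train_one_epoch`` and ``validate_one_epoch``. Note the weight optimizer must
# be SGD-like, since ``_unrolled_backward`` reads its ``momentum`` and
# ``weight_decay`` settings.
#
#   model = CNN(input_size=32, in_channels=3, channels=16, n_layers=8, n_classes=10)
#   optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
#   trainer = DartsTrainer(model,
#                          loss=nn.CrossEntropyLoss(),
#                          metrics=lambda logits, target: {"acc": accuracy(logits, target)},
#                          optimizer=optimizer,
#                          num_epochs=50,
#                          dataset_train=dataset_train,          # e.g. CIFAR-10 train split
#                          dataset_valid=dataset_valid,          # e.g. CIFAR-10 test split
#                          search_space_path="search_space.json",
#                          result_path="result.txt",
#                          batch_size=64,
#                          log_frequency=10,
#                          unrolled=False)
#   trainer.train()
# ---------------------------------------------------------------------------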
