
darts_retrain.py
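"""Retrain a fixed architecture selected by a DARTS search (summary added).

The script loads the best selected space (cell architecture) from JSON,
applies it to the CNN via apply_fixed_architecture, retrains the network
on CIFAR-10 with an auxiliary head, drop path, and cosine LR annealing,
and writes per-epoch accuracy plus total cost time to the result file.
"""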
import sys
sys.path.append('../..')
import os
import json
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter

import datasets
import utils
from model import CNN
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter
from pytorch.fixed import apply_fixed_architecture
from pytorch.retrainer import Retrainer

logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# writer = SummaryWriter()
class DartsRetrainer(Retrainer):
    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        self.aux_weight = aux_weight        # weight of the auxiliary-head loss
        self.grad_clip = grad_clip          # max gradient norm for clipping
        self.epochs = epochs
        self.log_frequency = log_frequency  # log every N steps

    def train(self, train_loader, model, optimizer, criterion, epoch):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        cur_step = epoch * len(train_loader)
        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        # writer.add_scalar("lr", cur_lr, global_step=cur_step)

        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            bs = x.size(0)

            optimizer.zero_grad()
            logits, aux_logits = model(x)
            loss = criterion(logits, y)
            if self.aux_weight > 0.:
                # add the auxiliary-head loss, scaled by aux_weight
                loss += self.aux_weight * criterion(aux_logits, y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), bs)
            top1.update(accuracy["acc1"], bs)
            top5.update(accuracy["acc5"], bs)
            # writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
            # writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
            # writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)

            if step % self.log_frequency == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses,
                        top1=top1, top5=top5))
            cur_step += 1

        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        model.eval()
        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                bs = X.size(0)

                # the auxiliary head is only used during training,
                # so the model returns just the main logits here
                logits = model(X)
                loss = criterion(logits, y)

                accuracy = utils.accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), bs)
                top1.update(accuracy["acc1"], bs)
                top5.update(accuracy["acc5"], bs)

                if step % self.log_frequency == 0 or step == len(valid_loader) - 1:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses,
                            top1=top1, top5=top5))

        # writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
        # writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
        # writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)
        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))
        return top1.avg

if __name__ == "__main__":
    parser = ArgumentParser("DARTS retrain")
    parser.add_argument("--data_dir", type=str,
                        default='./data/', help="dataset directory")
    parser.add_argument("--result_path", type=str,
                        default='.0/result.json', help="path of the training result file")
    parser.add_argument("--log_path", type=str,
                        default='.0/log', help="path of the info log")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument("--best_checkpoint_dir", type=str,
                        default='./', help="default file name is best_checkpoint_epoch{}.pth")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help="trial id, starting from 0")
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--lr", default=0.025, type=float)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=5, type=int)
    parser.add_argument("--aux_weight", default=0.4, type=float)
    parser.add_argument("--drop_path_prob", default=0.2, type=float)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--channels", default=36, type=int)
    parser.add_argument("--grad_clip", default=5., type=float)
    parser.add_argument("--class_num", default=10, type=int, help="number of classes (10 for CIFAR-10)")
    args = parser.parse_args()
    mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir)
    init_logger(args.log_path)
    logger.info(args)
    set_seed(args.trial_id)

    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir)
    # 32x32 inputs, 3 channels; auxiliary head enabled for the retrain phase
    model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True)
    apply_fixed_architecture(model, args.best_selected_space_path)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    retrainer = DartsRetrainer(aux_weight=args.aux_weight,
                               grad_clip=args.grad_clip,
                               epochs=args.epochs,
                               log_frequency=args.log_frequency)
    # result = {"Accuracy": [], "Cost_time": ''}
    best_top1 = 0.
    start_time = time.time()
    # truncate any previous result file; per-epoch records are appended below
    with open(args.result_path, "w") as file:
        file.write('')

    for epoch in range(args.epochs):
        # linearly ramp the drop-path probability up over training
        drop_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob(drop_prob)

        # training
        retrainer.train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step)

        # The backend filters the terminal output for lines of the form
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}},
        # so emit proper JSON (str(dict) would produce single-quoted keys)
        accuracy_record = {"type": "Accuracy",
                           "result": {"sequence": epoch, "category": "epoch", "value": float(top1)}}
        logger.info(json.dumps(accuracy_record))
        with open(args.result_path, "a") as file:
            file.write(json.dumps(accuracy_record) + '\n')
        # result["Accuracy"].append(top1)
        best_top1 = max(best_top1, top1)
        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    cost_time = time.time() - start_time
    # The backend filters the terminal output for {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(json.dumps(cost_record))
    with open(args.result_path, "a") as file:
        file.write(json.dumps(cost_record))
    # result["Cost_time"] = str(cost_time) + ' s'
    # dump_global_result(args.result_path, result)

    save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)
    logger.info("Save best checkpoint in {}".format(
        os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))))
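
For reference, a typical invocation with the defaults above might look like the following (the data and output paths are illustrative):

    python darts_retrain.py --data_dir ./data/ --best_selected_space_path ./best_selected_space.json --batch_size 128 --epochs 5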
