pdarts_retrain.py

import sys
import os
# Make the project root (two levels up) importable.
sys.path.append(os.path.join('..', '..'))
import logging
import time
import json
from argparse import ArgumentParser

import torch
import torch.nn as nn
# from torch.utils.tensorboard import SummaryWriter

from model import CNN
from pytorch.fixed import apply_fixed_architecture
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter
from pytorch.darts import utils
from pytorch.darts import datasets
from pytorch.retrainer import Retrainer

logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# writer = SummaryWriter()

class PdartsRetrainer(Retrainer):
    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        self.aux_weight = aux_weight
        self.grad_clip = grad_clip
        self.epochs = epochs
        self.log_frequency = log_frequency

    def train(self, train_loader, model, optimizer, criterion, epoch):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        cur_step = epoch * len(train_loader)
        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        # writer.add_scalar("lr", cur_lr, global_step=cur_step)

        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            bs = x.size(0)

            optimizer.zero_grad()
            logits, aux_logits = model(x)
            loss = criterion(logits, y)
            if self.aux_weight > 0.:
                loss += self.aux_weight * criterion(aux_logits, y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), bs)
            top1.update(accuracy["acc1"], bs)
            top5.update(accuracy["acc5"], bs)
            # writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
            # writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
            # writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)

            if step % self.log_frequency == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses,
                        top1=top1, top5=top5))
            cur_step += 1
        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        model.eval()
        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                bs = X.size(0)

                logits = model(X)
                loss = criterion(logits, y)

                accuracy = utils.accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), bs)
                top1.update(accuracy["acc1"], bs)
                top5.update(accuracy["acc5"], bs)
                if step % self.log_frequency == 0 or step == len(valid_loader) - 1:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses,
                            top1=top1, top5=top5))

        # writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
        # writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
        # writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)
        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))
        return top1.avg

if __name__ == "__main__":
    parser = ArgumentParser("Pdarts retrain")
    parser.add_argument("--data_dir", type=str,
                        default='./', help="dataset root directory")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="path of the training result file")
    parser.add_argument("--log_path", type=str,
                        default='./log', help="path of the info log")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument("--best_checkpoint_dir", type=str,
                        default='', help="default name is best_checkpoint_epoch{}.pth")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--lr", default=0.025, type=float)
    parser.add_argument("--channels", default=36, type=int)
    parser.add_argument("--aux_weight", default=0.4, type=float)
    parser.add_argument("--drop_path_prob", default=0.3, type=float)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--grad_clip", default=5., type=float)
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir)
    init_logger(args.log_path)
    logger.info(args)
    set_seed(args.trial_id)
    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir)
    # CIFAR-10: 32x32 inputs, 3 channels, 10 classes; channel width taken from --channels.
    model = CNN(32, 3, args.channels, 10, args.layers, auxiliary=True, search=False, dropout_rate=0.0)
    if isinstance(args.best_selected_space_path, str):
        with open(args.best_selected_space_path) as f:
            fixed_arc = json.load(f)
        apply_fixed_architecture(model, fixed_arc=fixed_arc["best_selected_space"])
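        # Sketch of the expected best_selected_space.json layout, assuming the
        # fixed-architecture format consumed by apply_fixed_architecture; the
        # inner keys are illustrative, not authoritative:
        # {
        #     "best_selected_space": {"normal_n2_p0": "...", ...}
        # }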
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)
    retrainer = PdartsRetrainer(aux_weight=args.aux_weight,
                                grad_clip=args.grad_clip,
                                epochs=args.epochs,
                                log_frequency=args.log_frequency)
    best_top1 = 0.
    start_time = time.time()
    for epoch in range(args.epochs):
        # Linearly ramp up the drop-path probability over training.
        drop_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob(drop_prob)

        # training
        retrainer.train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step)
        # The backend filters terminal output for lines of the form
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
        accuracy_record = {"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}
        logger.info(json.dumps(accuracy_record))
        with open(args.result_path, "a") as file:
            file.write(json.dumps(accuracy_record) + '\n')
        best_top1 = max(best_top1, top1)
        lr_scheduler.step()
    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))

    cost_time = time.time() - start_time
    # The backend filters terminal output for {"type": "Cost_time", "result": {"value": "* s"}}
    cost_record = {"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}
    logger.info(json.dumps(cost_record))
    with open(args.result_path, "a") as file:
        file.write(json.dumps(cost_record) + '\n')
    # result["Cost_time"] = str(cost_time) + ' s'
    # dump_global_result(args.result_path, result)

    save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)
    logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))))
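A minimal invocation sketch, assuming the defaults above and a CIFAR-10 root on disk; all paths are illustrative:

    python pdarts_retrain.py --data_dir ./data --best_selected_space_path ./best_selected_space.json --best_checkpoint_dir ./checkpoints

Each epoch appends one {"type": "Accuracy", ...} JSON line to --result_path, and a final {"type": "Cost_time", ...} line is written when training completes.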
