
pcdarts_retrain.py
import sys
sys.path.append('../..')
import os
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn
import numpy as np
# from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn

from model import CNN
from pytorch.fixed import apply_fixed_architecture
from pytorch.utils import set_seed, mkdirs, init_logger, save_best_checkpoint, AverageMeter
from pytorch.darts import utils
from pytorch.darts import datasets
from pytorch.retrainer import Retrainer

logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# writer = SummaryWriter()
class PCdartsRetrainer(Retrainer):
    def __init__(self, aux_weight, grad_clip, epochs, log_frequency):
        self.aux_weight = aux_weight
        self.grad_clip = grad_clip
        self.epochs = epochs
        self.log_frequency = log_frequency

    def train(self, train_loader, model, optimizer, criterion, epoch):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        cur_step = epoch * len(train_loader)
        cur_lr = optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        # writer.add_scalar("lr", cur_lr, global_step=cur_step)

        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            bs = x.size(0)

            optimizer.zero_grad()
            logits, aux_logits = model(x)
            loss = criterion(logits, y)
            if self.aux_weight > 0.:
                loss += self.aux_weight * criterion(aux_logits, y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            accuracy = utils.accuracy(logits, y, topk=(1, 5))
            losses.update(loss.item(), bs)
            top1.update(accuracy["acc1"], bs)
            top5.update(accuracy["acc5"], bs)
            # writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
            # writer.add_scalar("acc1/train", accuracy["acc1"], global_step=cur_step)
            # writer.add_scalar("acc5/train", accuracy["acc5"], global_step=cur_step)

            if step % self.log_frequency == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, self.epochs, step, len(train_loader) - 1, losses=losses,
                        top1=top1, top5=top5))
            cur_step += 1

        logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))

    def validate(self, valid_loader, model, criterion, epoch, cur_step):
        top1 = AverageMeter("top1")
        top5 = AverageMeter("top5")
        losses = AverageMeter("losses")

        model.eval()
        with torch.no_grad():
            for step, (X, y) in enumerate(valid_loader):
                X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
                bs = X.size(0)

                logits = model(X)
                loss = criterion(logits, y)

                accuracy = utils.accuracy(logits, y, topk=(1, 5))
                losses.update(loss.item(), bs)
                top1.update(accuracy["acc1"], bs)
                top5.update(accuracy["acc5"], bs)

                if step % self.log_frequency == 0 or step == len(valid_loader) - 1:
                    logger.info(
                        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1, self.epochs, step, len(valid_loader) - 1, losses=losses,
                            top1=top1, top5=top5))

        # writer.add_scalar("loss/test", losses.avg, global_step=cur_step)
        # writer.add_scalar("acc1/test", top1.avg, global_step=cur_step)
        # writer.add_scalar("acc5/test", top5.avg, global_step=cur_step)

        logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, self.epochs, top1.avg))
        return top1.avg
if __name__ == "__main__":
    parser = ArgumentParser("PCDARTS retrain")
    parser.add_argument("--data_dir", type=str,
                        default='./', help="dataset root directory")
    parser.add_argument("--result_path", type=str,
                        default='./result.json', help="training result")
    parser.add_argument("--log_path", type=str,
                        default='./log', help="log for info")
    parser.add_argument("--best_selected_space_path", type=str,
                        default='./best_selected_space.json', help="final best selected space")
    parser.add_argument("--best_checkpoint_dir", type=str,
                        default='', help="default name is best_checkpoint_epoch{}.pth")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0')
    parser.add_argument("--layers", default=20, type=int)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--batch_size", default=96, type=int)
    parser.add_argument("--log_frequency", default=10, type=int)
    parser.add_argument("--epochs", default=600, type=int)
    parser.add_argument("--aux_weight", default=0.4, type=float)
    parser.add_argument("--drop_path_prob", default=0.2, type=float)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--class_num", default=10, type=int, help="cifar10")
    parser.add_argument("--channels", default=36, type=int)
    parser.add_argument("--grad_clip", default=6., type=float)
    args = parser.parse_args()

    mkdirs(args.result_path, args.log_path, args.best_checkpoint_dir)
    init_logger(args.log_path)
    logger.info(args)
    set_seed(args.trial_id)

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16, root=args.data_dir)

    # build the fixed (non-search) network and load the architecture found during search
    model = CNN(32, 3, args.channels, args.class_num, args.layers, auxiliary=True, search=False)
    apply_fixed_architecture(model, args.best_selected_space_path)
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    criterion.to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)

    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    retrainer = PCdartsRetrainer(aux_weight=args.aux_weight,
                                 grad_clip=args.grad_clip,
                                 epochs=args.epochs,
                                 log_frequency=args.log_frequency)

    # result = {"Accuracy": [], "Cost_time": ''}
    best_top1 = 0.
    start_time = time.time()
    for epoch in range(args.epochs):
        # linearly ramp the drop-path probability over training
        drop_prob = args.drop_path_prob * epoch / args.epochs
        model.drop_path_prob(drop_prob)

        # training
        retrainer.train(train_loader, model, optimizer, criterion, epoch)

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = retrainer.validate(valid_loader, model, criterion, epoch, cur_step)
        # The backend filters the terminal output for lines like:
        # {"type": "Accuracy", "result": {"sequence": 1, "category": "epoch", "value": 96.7}}
        logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}})
        with open(args.result_path, "a") as file:
            file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": top1}}) + '\n')
        # result["Accuracy"].append(top1)

        best_top1 = max(best_top1, top1)
        lr_scheduler.step()

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    cost_time = time.time() - start_time
    # The backend filters the terminal output for lines like:
    # {"type": "Cost_time", "result": {"value": "* s"}}
    logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})
    with open(args.result_path, "a") as file:
        file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}))
    # result["Cost_time"] = str(cost_time) + ' s'
    # dump_global_result(args.result_path, result)

    save_best_checkpoint(args.best_checkpoint_dir, model, optimizer, epoch)
    logger.info("Save best checkpoint in {}".format(os.path.join(args.best_checkpoint_dir, "best_checkpoint_epoch{}.pth".format(epoch))))
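
The loops above lean on two small helpers: pytorch.utils.AverageMeter, used as a running average via update(value, batch_size) and an .avg attribute, and pytorch.darts.utils.accuracy, which must return a dict keyed "acc1" and "acc5". Their real implementations live elsewhere in this repository; a minimal sketch consistent with how they are called here (an assumption, not the repository's actual code) would be:

    import torch

    class AverageMeter:
        """Running average of a scalar metric (sketch; the repository's
        AverageMeter may track extra state such as the latest value)."""
        def __init__(self, name):
            self.name = name
            self.sum = 0.0
            self.count = 0
            self.avg = 0.0

        def update(self, val, n=1):
            # accumulate a value observed over a batch of n samples
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    def accuracy(logits, target, topk=(1,)):
        """Top-k accuracy returned as {"acc1": ..., "acc5": ...} fractions,
        matching how the loops above index and format the result."""
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the maxk highest-scoring classes, shape (maxk, batch)
        _, pred = logits.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return {"acc{}".format(k): correct[:k].reshape(-1).float().sum().item() / batch_size
                for k in topk}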
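A typical invocation, assuming CIFAR-10 is available under ./data and the search phase has already exported an architecture JSON (the paths below are placeholders and the flag values simply restate the script's defaults):

    python pcdarts_retrain.py \
        --data_dir ./data \
        --best_selected_space_path ./best_selected_space.json \
        --best_checkpoint_dir ./checkpoints \
        --batch_size 96 --lr 0.01 --epochs 600

Each epoch appends an accuracy record to result.json for the backend to parse, and the final checkpoint is saved under best_checkpoint_dir when training finishes.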
