
scratch.py
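This script trains a fixed architecture exported by SPOS (Single Path One-Shot) architecture search from scratch on ImageNet: it builds a ShuffleNetV2OneShot supernet, fixes it to the architecture JSON given by --architecture, and trains with label-smoothed cross entropy and SGD under a linear or cosine learning-rate schedule, reading data through NVIDIA DALI and logging metrics to TensorBoard.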
import os
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
from dataloader import get_imagenet_iter_dali
from pytorch.fixed import apply_fixed_architecture
from pytorch.utils import AverageMeterGroup
from torch.utils.tensorboard import SummaryWriter
# import torch.distributed as dist
from network import ShuffleNetV2OneShot
from utils import CrossEntropyLabelSmooth, accuracy

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

logger = logging.getLogger("nni.spos.scratch")


def train(epoch, model, criterion, optimizer, loader, writer, args):
    model.train()
    meters = AverageMeterGroup()
    cur_lr = optimizer.param_groups[0]["lr"]
    for step, (x, y) in enumerate(loader):
        cur_step = len(loader) * epoch + step
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        metrics = accuracy(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)
        writer.add_scalar("lr", cur_lr, global_step=cur_step)
        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
        writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
        writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)
        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
                        args.epochs, step + 1, len(loader), meters)
        if step > len(loader):
            # safety stop: a DALI iterator can yield more batches than len(loader) reports
            break
    logger.info("Epoch %d training summary: %s", epoch + 1, meters)


def validate(epoch, model, criterion, loader, writer, args):
    model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if step % args.log_frequency == 0 or step + 1 == len(loader):
                logger.info("Epoch [%d/%d] Validation Step [%d/%d] %s", epoch + 1,
                            args.epochs, step + 1, len(loader), meters)
            if step > len(loader):
                # safety stop, same reason as in train()
                break
    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)
    logger.info("Epoch %d validation: top1 = %f, top5 = %f",
                epoch + 1, meters.acc1.avg, meters.acc5.avg)


def dump_checkpoint(model, epoch, checkpoint_dir):
    # unwrap DataParallel so checkpoint keys carry no "module." prefix
    if isinstance(model, nn.DataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    dest_path = os.path.join(checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
    logger.info("Saving model to %s", dest_path)
    torch.save(state_dict, dest_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("SPOS Training From Scratch")
    parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet")  # ./data/imagenet
    parser.add_argument("--tb-dir", type=str, default="runs")
    parser.add_argument("--architecture", type=str, default="./checkpoints/037_034.json")  # "architecture_final.json"
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--batch-size", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=240)
    parser.add_argument("--learning-rate", type=float, default=0.5)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=4E-5)
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--lr-decay", type=str, default="linear")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--spos-preprocessing", default=False, action="store_true")
    parser.add_argument("--label-smoothing", type=float, default=0.1)
    # only used by the commented-out DistributedDataParallel path below
    parser.add_argument("--local_rank", default=[0, 1, 2, 3])
    args = parser.parse_args()

    # seed everything for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    model = ShuffleNetV2OneShot(affine=True)
    model.cuda("cuda:0")
    apply_fixed_architecture(model, args.architecture)
    # check whether the state_dict changed after applying the fixed architecture
    state_dict = model.state_dict()
    # TODO: setup for DDP (DistributedDataParallel)
    # dist.init_process_group(backend="nccl")
    if torch.cuda.device_count() > 1:
        # exclude the last GPU, saving it for data preprocessing on GPU
        model = nn.DataParallel(model,
                                device_ids=list(range(0, torch.cuda.device_count() - 1)))
    # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.local_rank)

    criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    if args.lr_decay == "linear":
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.epochs) if step <= args.epochs else 0,
            last_epoch=-1)
    elif args.lr_decay == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 1E-3)
    else:
        raise ValueError("'%s' not supported." % args.lr_decay)

    writer = SummaryWriter(log_dir=args.tb_dir)
    train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing)
    val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
                                        spos_preprocessing=args.spos_preprocessing)

    for epoch in range(args.epochs):
        train(epoch, model, criterion, optimizer, train_loader, writer, args)
        validate(epoch, model, criterion, val_loader, writer, args)
        scheduler.step()
        dump_checkpoint(model, epoch, "scratch_checkpoints")
    writer.close()
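A typical invocation, assuming an ImageNet layout under ./data/imagenet and an exported architecture file named architecture_final.json (both paths are placeholders taken from the comments in the argument defaults above):

    python scratch.py --imagenet-dir ./data/imagenet --architecture architecture_final.json \
        --batch-size 1024 --epochs 240 --lr-decay linear --spos-preprocessing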

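CrossEntropyLabelSmooth is imported from utils, whose source is not part of this file. Below is a minimal sketch of the uniform label-smoothing loss it presumably implements; the class name and constructor arguments match the call in the script, but the body is an assumption, not the repository's code:

    import torch
    import torch.nn as nn

    class CrossEntropyLabelSmooth(nn.Module):
        """Cross entropy against targets smoothed toward the uniform distribution (sketch)."""

        def __init__(self, num_classes, epsilon):
            super().__init__()
            self.num_classes = num_classes
            self.epsilon = epsilon
            self.logsoftmax = nn.LogSoftmax(dim=1)

        def forward(self, logits, targets):
            log_probs = self.logsoftmax(logits)
            # one-hot targets mixed with a uniform distribution over all classes
            one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
            smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
            return (-smoothed * log_probs).sum(dim=1).mean()

With epsilon = 0.1 and 1000 classes, the true class receives target probability 0.9 + 0.1/1000 and every other class 0.1/1000, which discourages over-confident logits.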
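The accuracy helper is also imported from utils; the TensorBoard calls above imply it returns a dict with "acc1" and "acc5" keys. A plausible sketch under that assumption (again, not the repository's code):

    def accuracy(logits, targets, topk=(1, 5)):
        """Top-k accuracy in percent, returned as {"acc1": ..., "acc5": ...}."""
        maxk = max(topk)
        # indices of the top-k predictions per sample: shape (N, maxk)
        _, pred = logits.topk(maxk, dim=1)
        correct = pred.eq(targets.view(-1, 1).expand_as(pred))
        res = {}
        for k in topk:
            correct_k = correct[:, :k].float().sum()
            res["acc%d" % k] = (correct_k * 100.0 / targets.size(0)).item()
        return res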