You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

supernet.py 4.2 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import argparse
  2. import logging
  3. import random
  4. import numpy as np
  5. import torch
  6. import torch.nn as nn
  7. from pytorch.callbacks import LRSchedulerCallback
  8. from pytorch.callbacks import ModelCheckpoint
  9. from algorithms.spos import SPOSSupernetTrainingMutator, SPOSSupernetTrainer
  10. from dataloader import get_imagenet_iter_dali
  11. from network import ShuffleNetV2OneShot, load_and_parse_state_dict
  12. from utils import CrossEntropyLabelSmooth, accuracy
  13. import os
  14. os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
  15. logger = logging.getLogger("nni.spos.supernet")
  16. if __name__ == "__main__":
  17. parser = argparse.ArgumentParser("SPOS Supernet Training")
  18. # 数据的路径需要修改,由于home的容量较小,数据存储在local下面
  19. # default="./data/imagenet"
  20. parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/imagenet")
  21. parser.add_argument("--load-checkpoint", action="store_true", default=False)
  22. parser.add_argument("--spos-preprocessing", action="store_true", default=False,
  23. help="When true, image values will range from 0 to 255 and use BGR "
  24. "(as in original repo).")
  25. parser.add_argument("--workers", type=int, default=4)
  26. parser.add_argument("--batch-size", type=int, default=512) # 原始大小为768
  27. parser.add_argument("--epochs", type=int, default=120) # 原始大小是120
  28. parser.add_argument("--learning-rate", type=float, default=0.5)
  29. parser.add_argument("--momentum", type=float, default=0.9)
  30. parser.add_argument("--weight-decay", type=float, default=4E-5)
  31. parser.add_argument("--label-smooth", type=float, default=0.1)
  32. parser.add_argument("--log-frequency", type=int, default=10)
  33. parser.add_argument("--seed", type=int, default=42)
  34. parser.add_argument("--label-smoothing", type=float, default=0.1)
  35. args = parser.parse_args()
  36. torch.manual_seed(args.seed)
  37. torch.cuda.manual_seed_all(args.seed)
  38. np.random.seed(args.seed)
  39. random.seed(args.seed)
  40. torch.backends.cudnn.deterministic = True
  41. model = ShuffleNetV2OneShot()
  42. flops_func = model.get_candidate_flops
  43. if args.load_checkpoint:
  44. if not args.spos_preprocessing:
  45. logger.warning("You might want to use SPOS preprocessing if you are loading their checkpoints.")
  46. model.load_state_dict(load_and_parse_state_dict())
  47. model.cuda()
  48. if torch.cuda.device_count() > 1: # exclude last gpu, saving for data preprocessing on gpu
  49. model = nn.DataParallel(model, device_ids=list(range(0, torch.cuda.device_count() - 1)))
  50. mutator = SPOSSupernetTrainingMutator(model, flops_func=flops_func,
  51. flops_lb=290E6, flops_ub=360E6)
  52. criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
  53. optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
  54. momentum=args.momentum, weight_decay=args.weight_decay)
  55. scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
  56. lambda step: (1.0 - step / args.epochs)
  57. if step <= args.epochs else 0,
  58. last_epoch=-1)
  59. device_id = 3
  60. train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
  61. spos_preprocessing=args.spos_preprocessing,device_id=device_id)
  62. valid_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
  63. spos_preprocessing=args.spos_preprocessing,device_id=device_id)
  64. trainer = SPOSSupernetTrainer(model, criterion, accuracy, optimizer,
  65. args.epochs, train_loader, valid_loader,
  66. mutator=mutator, batch_size=args.batch_size,
  67. log_frequency=args.log_frequency, workers=args.workers,
  68. callbacks=[LRSchedulerCallback(scheduler),
  69. ModelCheckpoint("./checkpoints")])
  70. trainer.train()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能