You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

tester.py 7.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import argparse
  2. # import logging
  3. import random
  4. import time
  5. from itertools import cycle
  6. import json
  7. import sys
  8. sys.path.insert(0,"/home/hanjiayi/Document/nni")
  9. sys.path.insert(0,"/home/hanjiayi/Document/nni/examples")
  10. sys.path.insert(0,"/home/hanjiayi/Document/nni/nni/algorithms")
  11. import os
  12. # 指定写入文件的环境变量,原实现于nnictl_utils.py的search_space_auto_gen
  13. os.environ["NNI_GEN_SEARCH_SPACE"] = "auto_gen_search_space.json"
  14. os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
  15. import numpy as np
  16. import torch
  17. import torch.nn as nn
  18. from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture
  19. from nni.nas.pytorch.fixed import apply_fixed_architecture
  20. from nni.nas.pytorch.utils import AverageMeterGroup
  21. from dataloader import get_imagenet_iter_dali
  22. from network import ShuffleNetV2OneShot, load_and_parse_state_dict
  23. from utils import CrossEntropyLabelSmooth, accuracy
  24. # logger = logging.getLogger("nni.spos.tester") # "nni.spos.tester"
  25. print("Evolution Beginning...")
  26. def retrain_bn(model, criterion, max_iters, log_freq, loader):
  27. with torch.no_grad():
  28. # logger.info("Clear BN statistics...")
  29. print("clear BN statistics")
  30. for m in model.modules():
  31. if isinstance(m, nn.BatchNorm2d):
  32. m.running_mean = torch.zeros_like(m.running_mean)
  33. m.running_var = torch.ones_like(m.running_var)
  34. # logger.info("Train BN with training set (BN sanitize)...")
  35. print("Train BN with training set (BN sanitize)...")
  36. model.train()
  37. meters = AverageMeterGroup()
  38. start_time = time.time()
  39. for step in range(max_iters):
  40. inputs, targets = next(loader)
  41. logits = model(inputs)
  42. loss = criterion(logits, targets)
  43. metrics = accuracy(logits, targets)
  44. metrics["loss"] = loss.item()
  45. meters.update(metrics)
  46. if step % log_freq == 0 or step + 1 == max_iters:
  47. # logger.info("Train Step [%d/%d] %s time %.3fs ", step + 1, max_iters, meters, time.time() - start_time)
  48. print("Train Step [%d/%d] %s time %.3fs "% (step + 1, max_iters, meters, time.time() - start_time))
  49. def test_acc(model, criterion, log_freq, loader):
  50. # logger.info("Start testing...")
  51. print("start testing...")
  52. model.eval()
  53. meters = AverageMeterGroup()
  54. start_time = time.time()
  55. with torch.no_grad():
  56. for step, (inputs, targets) in enumerate(loader):
  57. logits = model(inputs)
  58. loss = criterion(logits, targets)
  59. metrics = accuracy(logits, targets)
  60. metrics["loss"] = loss.item()
  61. meters.update(metrics)
  62. if step % log_freq == 0 or step + 1 == len(loader):
  63. # logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f",
  64. # step + 1, len(loader), time.time() - start_time,
  65. # meters.acc1.avg, meters.acc5.avg, meters.loss.avg)
  66. print("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f"%
  67. (step + 1, len(loader), time.time() - start_time,
  68. meters.acc1.avg, meters.acc5.avg, meters.loss.avg))
  69. if step>len(loader): # 遍历一遍就停止
  70. break
  71. return meters.acc1.avg
  72. def evaluate_acc(model, criterion, args, loader_train, loader_test):
  73. retrain_bn(model, criterion, args.train_iters, args.log_frequency, loader_train) # todo
  74. acc = test_acc(model, criterion, args.log_frequency, loader_test)
  75. assert isinstance(acc, float)
  76. torch.cuda.empty_cache()
  77. return acc
  78. if __name__ == "__main__":
  79. parser = argparse.ArgumentParser("SPOS Candidate Tester")
  80. parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet") # ./data/imagenet
  81. parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar") # ./data/checkpoint-150000.pth.tar
  82. parser.add_argument("--spos-preprocessing", default=True,
  83. help="When true, image values will range from 0 to 255 and use BGR "
  84. "(as in original repo).") # , action="store_true"
  85. parser.add_argument("--seed", type=int, default=42)
  86. parser.add_argument("--workers", type=int, default=6) # 线程数
  87. parser.add_argument("--train-batch-size", type=int, default=128)
  88. parser.add_argument("--train-iters", type=int, default=200)
  89. parser.add_argument("--test-batch-size", type=int, default=512) # nni中为512,官方repo为200
  90. parser.add_argument("--log-frequency", type=int, default=10)
  91. parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval")
  92. parser.add_argument("--mode", type=str, default="gen", help="there are two modes here: gen mode for generating architecture, and evl mode for evaluation model")
  93. args = parser.parse_args()
  94. # use a fixed set of image will improve the performance
  95. torch.manual_seed(args.seed)
  96. torch.cuda.manual_seed_all(args.seed)
  97. np.random.seed(args.seed)
  98. random.seed(args.seed)
  99. torch.backends.cudnn.deterministic = True
  100. assert torch.cuda.is_available()
  101. model = ShuffleNetV2OneShot()
  102. criterion = CrossEntropyLabelSmooth(1000, 0.1)
  103. if args.mode == "gen":
  104. get_and_apply_next_architecture(model)
  105. model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))
  106. else: # evaluate the model
  107. print("## test&retrain -- load model ## begin to load model")
  108. model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint))
  109. print("## test&retrain -- load model ## model loaded")
  110. print("## test&retrain -- apply architecture ## begin to apply architecture to model")
  111. apply_fixed_architecture(model, args.architecture)
  112. print("## test&retrain -- apply architecture ## architecture applied")
  113. model.cuda(0)
  114. print("## test&retrain -- load train data ## begin to load train data")
  115. train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.train_batch_size, args.workers,
  116. spos_preprocessing=args.spos_preprocessing,
  117. seed=args.seed, device_id=0)
  118. print("## test&retrain -- load train data ## train data loaded")
  119. print("## test&retrain -- load test data ## begin to load test data")
  120. val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.test_batch_size, args.workers,
  121. spos_preprocessing=args.spos_preprocessing, shuffle=True,
  122. seed=args.seed, device_id=0)
  123. print("## test&retrain -- load test date ## test data loaded")
  124. train_loader = cycle(train_loader)
  125. acc = evaluate_acc(model, criterion, args, train_loader, val_loader)
  126. # 把模型最终的准确率写入一个文件中
  127. os.makedirs("./acc", exist_ok=True)
  128. with open("./acc/{}".format(args.architecture[-12:]), "w") as f: # [-12:] 代表没有路径的文件名
  129. # {filename1: acc,
  130. # filename2: acc,
  131. # 000_000.json: acc,
  132. # 000_001.json: acc,
  133. # ......
  134. # }
  135. json.dump({args.architecture: acc}, f)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能