You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

evaluator.py 9.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import os
  2. import json
  3. import time
  4. import random
  5. import argparse
  6. import numpy as np
  7. # import logging
  8. from itertools import cycle
  9. import sys
  10. sys.path.append("../")
  11. sys.path.append("../../")
  12. import torch
  13. import torch.nn as nn
  14. os.environ["NNI_GEN_SEARCH_SPACE"] = "auto_gen_search_space.json"
  15. os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
  16. from pytorch.fixed import apply_fixed_architecture
  17. from pytorch.utils import AverageMeterGroup
  18. from dataloader import get_imagenet_iter_dali
  19. from network import ShuffleNetV2OneShot, load_and_parse_state_dict
  20. from utils import CrossEntropyLabelSmooth, accuracy
  21. # logger = logging.getLogger("nni.spos.tester") # "nni.spos.tester"
  22. print("Evolution Beginning...")
  23. class Evaluator:
  24. """
  25. retrain the BN layer in specified model and evaluate it
  26. """
  27. def __init__(self, imagenet_dir="/mnt/local/hanjiayi/imagenet", # imagenet dataset
  28. checkpoint="./data/checkpoint-150000.pth.tar", # fine model from supernet
  29. spos_preprocessing=True, # RGB or BGR
  30. seed=42, # torch.manual_seed
  31. workers=1, # the number of subprocess
  32. train_batch_size=128,
  33. train_iters=200,
  34. test_batch_size=512,
  35. log_frequency=10,
  36. ):
  37. self.imagenet_dir = imagenet_dir
  38. self.checkpoint = checkpoint
  39. self.spos_preprocessing = spos_preprocessing
  40. self.seed = seed
  41. self.workers = workers
  42. self.train_batch_size = train_batch_size
  43. self.train_iters = train_iters
  44. self.test_batch_size = test_batch_size
  45. self.log_frequency = log_frequency
  46. print("### program interval 1 ###")
  47. self.model = ShuffleNetV2OneShot()
  48. print("### program interval 2 ###")
  49. print("## test&retrain -- load model ## begin to load model")
  50. self.model.load_state_dict(load_and_parse_state_dict(filepath=self.checkpoint))
  51. print("## test&retrain -- load model ## model loaded")
  52. torch.manual_seed(self.seed)
  53. torch.cuda.manual_seed_all(self.seed)
  54. np.random.seed(self.seed)
  55. random.seed(self.seed)
  56. torch.backends.cudnn.deterministic = True
  57. assert torch.cuda.is_available()
  58. self.criterion = CrossEntropyLabelSmooth(1000, 0.1)
  59. print("##### load training data #####")
  60. self.train_loader = get_imagenet_iter_dali("train", self.imagenet_dir, self.train_batch_size, self.workers,
  61. spos_preprocessing=self.spos_preprocessing,
  62. seed=self.seed, device_id=0)
  63. print("##### training data loaded finished #####")
  64. print("##### load validating data #####")
  65. self.val_loader = get_imagenet_iter_dali("val", self.imagenet_dir, self.test_batch_size, self.workers,
  66. spos_preprocessing=self.spos_preprocessing, shuffle=True,
  67. seed=self.seed, device_id=0)
  68. print("##### validating data loaded finished #####")
  69. def retrain_bn(self, model, criterion, max_iters, log_freq, loader):
  70. with torch.no_grad():
  71. # logger.info("Clear BN statistics...")
  72. print("clear BN statistics")
  73. for m in model.modules():
  74. if isinstance(m, nn.BatchNorm2d):
  75. m.running_mean = torch.zeros_like(m.running_mean)
  76. m.running_var = torch.ones_like(m.running_var)
  77. # logger.info("Train BN with training set (BN sanitize)...")
  78. print("Train BN with training set (BN sanitize)...")
  79. model.train()
  80. meters = AverageMeterGroup()
  81. start_time = time.time()
  82. for step in range(max_iters):
  83. inputs, targets = next(loader)
  84. logits = model(inputs)
  85. loss = criterion(logits, targets)
  86. metrics = accuracy(logits, targets)
  87. metrics["loss"] = loss.item()
  88. meters.update(metrics)
  89. if step % log_freq == 0 or step + 1 == max_iters:
  90. # logger.info("Train Step [%d/%d] %s time %.3fs ", step + 1, max_iters, meters, time.time() - start_time)
  91. print("Train Step [%d/%d] %s time %.3fs "% (step + 1, max_iters, meters, time.time() - start_time))
  92. def test_acc(self, model, criterion, log_freq, loader):
  93. # logger.info("Start testing...")
  94. print("start testing...")
  95. model.eval()
  96. meters = AverageMeterGroup()
  97. start_time = time.time()
  98. with torch.no_grad():
  99. for step, (inputs, targets) in enumerate(loader):
  100. logits = model(inputs)
  101. loss = criterion(logits, targets)
  102. metrics = accuracy(logits, targets)
  103. metrics["loss"] = loss.item()
  104. meters.update(metrics)
  105. if step % log_freq == 0 or step + 1 == len(loader):
  106. # logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f",
  107. # step + 1, len(loader), time.time() - start_time,
  108. # meters.acc1.avg, meters.acc5.avg, meters.loss.avg)
  109. print("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f"%
  110. (step + 1, len(loader), time.time() - start_time,
  111. meters.acc1.avg, meters.acc5.avg, meters.loss.avg))
  112. if step>len(loader): # 遍历一遍就停止
  113. break
  114. return meters.acc1.avg
  115. def evaluate_acc(self, model, criterion, loader_train, loader_test):
  116. self.retrain_bn(model, criterion, self.train_iters, self.log_frequency, loader_train) # todo
  117. acc = self.test_acc(model, criterion, self.log_frequency, loader_test)
  118. assert isinstance(acc, float)
  119. torch.cuda.empty_cache()
  120. return acc
  121. def eval_model(self, epoch, architecture):
  122. # evaluate the model
  123. print("## test&retrain -- apply architecture ## begin to apply architecture to model")
  124. apply_fixed_architecture(self.model, architecture)
  125. print("## test&retrain -- apply architecture ## architecture applied")
  126. self.model.cuda(0)
  127. self.train_loader = cycle(self.train_loader)
  128. acc = self.evaluate_acc(self.model, self.criterion, self.train_loader, self.val_loader)
  129. # 把模型最终的准确率写入一个文件中
  130. os.makedirs("./acc", exist_ok=True)
  131. with open("./acc/{}".format(architecture[-12:]), "w") as f: # [-12:] 代表没有路径的文件名
  132. # {filename1: acc,
  133. # filename2: acc,
  134. # 000_000.json: acc,
  135. # 000_001.json: acc,
  136. # ......
  137. # }
  138. json.dump({architecture: acc}, f)
  139. if __name__ == "__main__":
  140. parser = argparse.ArgumentParser("SPOS Candidate Evaluator")
  141. parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet") # ./data/imagenet
  142. parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar") # ./data/checkpoint-150000.pth.tar
  143. parser.add_argument("--spos-preprocessing", default=True,
  144. help="When true, image values will range from 0 to 255 and use BGR "
  145. "(as in original repo).") # , action="store_true"
  146. parser.add_argument("--seed", type=int, default=42)
  147. parser.add_argument("--workers", type=int, default=1) # 线程数
  148. parser.add_argument("--train-batch-size", type=int, default=128)
  149. parser.add_argument("--train-iters", type=int, default=200)
  150. parser.add_argument("--test-batch-size", type=int, default=512) # nni中为512,官方repo为200
  151. parser.add_argument("--log-frequency", type=int, default=10)
  152. parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval")
  153. parser.add_argument("--epoch", type=int, default=0, help="when epoch=0, this file should generate an architecture file")
  154. args = parser.parse_args()
  155. evl = Evaluator(imagenet_dir=args.imagenet_dir, # imagenet dataset
  156. checkpoint=args.checkpoint, # fine model from supernet
  157. spos_preprocessing=args.spos_preprocessing, # RGB or BGR
  158. seed=args.seed, # torch.manual_seed
  159. workers=args.workers, # the number of subprocess
  160. train_batch_size=args.train_batch_size,
  161. train_iters=args.train_iters,
  162. test_batch_size=args.test_batch_size,
  163. log_frequency=args.log_frequency,
  164. )
  165. evl.eval_model(args.epoch, args.architecture)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能