import argparse # import logging import random import time from itertools import cycle import json import sys sys.path.insert(0,"/home/hanjiayi/Document/nni") sys.path.insert(0,"/home/hanjiayi/Document/nni/examples") sys.path.insert(0,"/home/hanjiayi/Document/nni/nni/algorithms") import os # 指定写入文件的环境变量,原实现于nnictl_utils.py的search_space_auto_gen os.environ["NNI_GEN_SEARCH_SPACE"] = "auto_gen_search_space.json" os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3" import numpy as np import torch import torch.nn as nn from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.utils import AverageMeterGroup from dataloader import get_imagenet_iter_dali from network import ShuffleNetV2OneShot, load_and_parse_state_dict from utils import CrossEntropyLabelSmooth, accuracy # logger = logging.getLogger("nni.spos.tester") # "nni.spos.tester" print("Evolution Beginning...") def retrain_bn(model, criterion, max_iters, log_freq, loader): with torch.no_grad(): # logger.info("Clear BN statistics...") print("clear BN statistics") for m in model.modules(): if isinstance(m, nn.BatchNorm2d): m.running_mean = torch.zeros_like(m.running_mean) m.running_var = torch.ones_like(m.running_var) # logger.info("Train BN with training set (BN sanitize)...") print("Train BN with training set (BN sanitize)...") model.train() meters = AverageMeterGroup() start_time = time.time() for step in range(max_iters): inputs, targets = next(loader) logits = model(inputs) loss = criterion(logits, targets) metrics = accuracy(logits, targets) metrics["loss"] = loss.item() meters.update(metrics) if step % log_freq == 0 or step + 1 == max_iters: # logger.info("Train Step [%d/%d] %s time %.3fs ", step + 1, max_iters, meters, time.time() - start_time) print("Train Step [%d/%d] %s time %.3fs "% (step + 1, max_iters, meters, time.time() - start_time)) def test_acc(model, criterion, log_freq, loader): # logger.info("Start testing...") print("start testing...") model.eval() meters = AverageMeterGroup() start_time = time.time() with torch.no_grad(): for step, (inputs, targets) in enumerate(loader): logits = model(inputs) loss = criterion(logits, targets) metrics = accuracy(logits, targets) metrics["loss"] = loss.item() meters.update(metrics) if step % log_freq == 0 or step + 1 == len(loader): # logger.info("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f", # step + 1, len(loader), time.time() - start_time, # meters.acc1.avg, meters.acc5.avg, meters.loss.avg) print("Valid Step [%d/%d] time %.3fs acc1 %.4f acc5 %.4f loss %.4f"% (step + 1, len(loader), time.time() - start_time, meters.acc1.avg, meters.acc5.avg, meters.loss.avg)) if step>len(loader): # 遍历一遍就停止 break return meters.acc1.avg def evaluate_acc(model, criterion, args, loader_train, loader_test): retrain_bn(model, criterion, args.train_iters, args.log_frequency, loader_train) # todo acc = test_acc(model, criterion, args.log_frequency, loader_test) assert isinstance(acc, float) torch.cuda.empty_cache() return acc if __name__ == "__main__": parser = argparse.ArgumentParser("SPOS Candidate Tester") parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet") # ./data/imagenet parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar") # ./data/checkpoint-150000.pth.tar parser.add_argument("--spos-preprocessing", default=True, help="When true, image values will range from 0 to 255 and use BGR " "(as in original repo).") # , action="store_true" parser.add_argument("--seed", type=int, default=42) parser.add_argument("--workers", type=int, default=6) # 线程数 parser.add_argument("--train-batch-size", type=int, default=128) parser.add_argument("--train-iters", type=int, default=200) parser.add_argument("--test-batch-size", type=int, default=512) # nni中为512,官方repo为200 parser.add_argument("--log-frequency", type=int, default=10) parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval") parser.add_argument("--mode", type=str, default="gen", help="there are two modes here: gen mode for generating architecture, and evl mode for evaluation model") args = parser.parse_args() # use a fixed set of image will improve the performance torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) np.random.seed(args.seed) random.seed(args.seed) torch.backends.cudnn.deterministic = True assert torch.cuda.is_available() model = ShuffleNetV2OneShot() criterion = CrossEntropyLabelSmooth(1000, 0.1) if args.mode == "gen": get_and_apply_next_architecture(model) model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint)) else: # evaluate the model print("## test&retrain -- load model ## begin to load model") model.load_state_dict(load_and_parse_state_dict(filepath=args.checkpoint)) print("## test&retrain -- load model ## model loaded") print("## test&retrain -- apply architecture ## begin to apply architecture to model") apply_fixed_architecture(model, args.architecture) print("## test&retrain -- apply architecture ## architecture applied") model.cuda(0) print("## test&retrain -- load train data ## begin to load train data") train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.train_batch_size, args.workers, spos_preprocessing=args.spos_preprocessing, seed=args.seed, device_id=0) print("## test&retrain -- load train data ## train data loaded") print("## test&retrain -- load test data ## begin to load test data") val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.test_batch_size, args.workers, spos_preprocessing=args.spos_preprocessing, shuffle=True, seed=args.seed, device_id=0) print("## test&retrain -- load test date ## test data loaded") train_loader = cycle(train_loader) acc = evaluate_acc(model, criterion, args, train_loader, val_loader) # 把模型最终的准确率写入一个文件中 os.makedirs("./acc", exist_ok=True) with open("./acc/{}".format(args.architecture[-12:]), "w") as f: # [-12:] 代表没有路径的文件名 # {filename1: acc, # filename2: acc, # 000_000.json: acc, # 000_001.json: acc, # ...... # } json.dump({args.architecture: acc}, f)