# https://github.com/microsoft/nni/blob/v2.0/examples/nas/cream/train.py
import sys
sys.path.append('../..')

import os
import time
import json

import numpy as np
import torch
import torch.nn as nn
from argparse import ArgumentParser

# import timm packages
from timm.loss import LabelSmoothingCrossEntropy
from timm.data import Dataset, create_loader
from timm.models import resume_checkpoint

# import apex as distributed package
# try:
#     from apex.parallel import DistributedDataParallel as DDP
#     from apex.parallel import convert_syncbn_model
#     USE_APEX = True
# except ImportError as e:
#     print(e)
#     from torch.nn.parallel import DistributedDataParallel as DDP
#     USE_APEX = False

# import models and training functions
from lib.utils.flops_table import FlopsEst
from lib.models.structures.supernet import gen_supernet
from lib.config import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from lib.utils.util import get_logger, \
    create_optimizer_supernet, create_supernet_scheduler

from pytorch.utils import mkdirs, str2bool
from pytorch.callbacks import LRSchedulerCallback
from pytorch.callbacks import ModelCheckpoint
from algorithms import CreamSupernetTrainer
from algorithms import RandomMutator


def parse_args():
    """See lib.utils.config"""
    parser = ArgumentParser()

    # path
    parser.add_argument("--checkpoint_dir", type=str, default='')
    parser.add_argument("--data_dir", type=str, default='./data')
    parser.add_argument("--experiment_dir", type=str, default='./')
    parser.add_argument("--model_name", type=str, default='trainer')
    parser.add_argument("--log_path", type=str, default='output/log')
    parser.add_argument("--result_path", type=str, default='output/result.json')
    parser.add_argument("--search_space_path", type=str, default='output/search_space.json')
    parser.add_argument("--best_selected_space_path", type=str, default='output/selected_space.json')

    # int
    parser.add_argument("--acc_gap", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--flops_minimum", type=int, default=0)
    parser.add_argument("--flops_maximum", type=int, default=200)
    parser.add_argument("--image_size", type=int, default=224)
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--meta_sta_epoch", type=int, default=20)
    parser.add_argument("--num_classes", type=int, default=1000)
    parser.add_argument("--num_gpu", type=int, default=1)
    parser.add_argument("--pool_size", type=int, default=10)
    parser.add_argument("--trial_id", type=int, default=42)
    parser.add_argument("--slice_num", type=int, default=4)
    parser.add_argument("--tta", type=int, default=0)
    parser.add_argument("--update_iter", type=int, default=1300)
    parser.add_argument("--workers", type=int, default=4)

    # float
    parser.add_argument("--color_jitter", type=float, default=0.4)
    parser.add_argument("--dropout_rate", type=float, default=0.0)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--meta_lr", type=float, default=1e-4)
    parser.add_argument("--opt_eps", type=float, default=1e-2)
    parser.add_argument("--re_prob", type=float, default=0.2)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--smoothing", type=float, default=0.1)
    parser.add_argument("--weight_decay", type=float, default=1e-4)

    # bool
    parser.add_argument("--auto_resume", type=str2bool, default='False')
    parser.add_argument("--dil_conv", type=str2bool, default='False')
    parser.add_argument("--resunit", type=str2bool, default='False')
    parser.add_argument("--sync_bn", type=str2bool, default='False')
    parser.add_argument("--verbose", type=str2bool, default='False')

    # str
    # gp: type of global pool ["avg", "max", "avgmax", "avgmaxc"]
    parser.add_argument("--gp", type=str, default='avg')
    parser.add_argument("--interpolation", type=str, default='bilinear')
    parser.add_argument("--opt", type=str, default='sgd')
    parser.add_argument("--pick_method", type=str, default='meta')
    parser.add_argument("--re_mode", type=str, default='pixel')

    args = parser.parse_args()

    # force-disable sync BN and verbose output in this single-process setup
    args.sync_bn = False
    args.verbose = False
    # the data loaders below expect ImageNet under <data_dir>/imagenet
    args.data_dir = args.data_dir + "/imagenet"
    return args
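# Note: the boolean flags above are parsed with `str2bool` from pytorch.utils, so
# they are passed as strings on the command line (e.g. `--dil_conv True`). A minimal
# sketch of such a converter, assuming the usual truthy-string convention (the real
# implementation lives in pytorch.utils and may differ):
#
#     def str2bool(value):
#         # treat "true"/"1"/"yes" (any case) as True, everything else as False
#         return str(value).lower() in ('true', '1', 'yes')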
def main():
    args = parse_args()
    mkdirs(args.experiment_dir, args.best_selected_space_path,
           args.search_space_path, args.result_path, args.log_path)

    # truncate the result file
    with open(args.result_path, "w") as ss_file:
        ss_file.write('')

    # resolve logging: only set up a timestamped checkpoint dir when one is given
    if len(args.checkpoint_dir) > 1:
        mkdirs(args.checkpoint_dir + "/")
        args.checkpoint_dir = os.path.join(
            args.checkpoint_dir,
            "{}_{}".format(args.model_name,
                           time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())))
        if not os.path.exists(args.checkpoint_dir):
            os.mkdir(args.checkpoint_dir)

    if args.local_rank == 0:
        logger = get_logger(args.log_path)
    else:
        logger = None

    # initialize distributed parameters
    torch.cuda.set_device(args.local_rank)
    # torch.distributed.init_process_group(backend='nccl', init_method='env://')
    if args.local_rank == 0:
        logger.info('Training on Process %d with %d GPUs.',
                    args.local_rank, args.num_gpu)

    # fix random seeds
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # generate supernet and optimizer
    model, sta_num, resolution, search_space = gen_supernet(
        flops_minimum=args.flops_minimum,
        flops_maximum=args.flops_maximum,
        num_classes=args.num_classes,
        drop_rate=args.dropout_rate,
        global_pool=args.gp,
        resunit=args.resunit,
        dil_conv=args.dil_conv,
        slice=args.slice_num,
        verbose=args.verbose,
        logger=logger)
    optimizer = create_optimizer_supernet(args, model)

    # number of choice blocks in the supernet
    choice_num = len(model.blocks[7])
    if args.local_rank == 0:
        logger.info('Supernet created, param count: %d',
                    sum(m.numel() for m in model.parameters()))
        logger.info('resolution: %d', resolution)
        logger.info('choice number: %d', choice_num)

    # dump the search space description
    with open(args.search_space_path, "w") as f:
        print("dump search space.")
        json.dump({'search_space': search_space}, f)

    # initialize flops look-up table
    model_est = FlopsEst(model)
    flops_dict, flops_fixed = model_est.flops_dict, model_est.flops_fixed

    model = model.cuda()

    # convert model to distributed mode
    if args.sync_bn:
        try:
            # if USE_APEX:
            #     model = convert_syncbn_model(model)
            # else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            if args.local_rank == 0:
                logger.info('Converted model to use Synchronized BatchNorm.')
        except Exception as exception:
            logger.info(
                'Failed to enable Synchronized BatchNorm. '
                'Install Apex or Torch >= 1.1 with Exception %s', exception)
    # if USE_APEX:
    #     model = DDP(model, delay_allreduce=True)
    # else:
    #     if args.local_rank == 0:
    #         logger.info("Using torch DistributedDataParallel. "
    #                     "Install NVIDIA Apex for Apex DDP.")
    #     # can use device str in Torch >= 1.1
    #     model = DDP(model, device_ids=[args.local_rank], find_unused_parameters=True)
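    # A minimal sketch of how the commented-out distributed path could be re-enabled
    # without Apex, assuming the script is launched with torch.distributed (e.g. via
    # `python -m torch.distributed.launch --nproc_per_node=N train.py ...`, which sets
    # the environment variables that init_process_group reads). Illustrative only, not
    # part of the original training flow:
    #
    #     torch.distributed.init_process_group(backend='nccl', init_method='env://')
    #     model = torch.nn.parallel.DistributedDataParallel(
    #         model, device_ids=[args.local_rank], find_unused_parameters=True)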
    # optionally resume from a checkpoint
    resume_epoch = None
    if args.auto_resume:
        checkpoint = torch.load(args.experiment_dir)
        model.load_state_dict(checkpoint['child_model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        resume_epoch = checkpoint['epoch']

    # create learning rate scheduler
    lr_scheduler, num_epochs = create_supernet_scheduler(
        optimizer, args.epochs, args.num_gpu, args.batch_size, args.lr)

    start_epoch = resume_epoch if resume_epoch is not None else 0
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logger.info('Scheduled epochs: %d', num_epochs)

    # imagenet train dataset
    train_dir = os.path.join(args.data_dir, 'train')
    if not os.path.exists(train_dir):
        logger.info('Training folder does not exist at: %s', train_dir)
        sys.exit()

    dataset_train = Dataset(train_dir)
    loader_train = create_loader(
        dataset_train,
        input_size=(3, args.image_size, args.image_size),
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        re_prob=args.re_prob,
        re_mode=args.re_mode,
        color_jitter=args.color_jitter,
        interpolation='random',
        num_workers=args.workers,
        distributed=False,
        collate_fn=None,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD)

    # imagenet validation dataset
    eval_dir = os.path.join(args.data_dir, 'val')
    if not os.path.isdir(eval_dir):
        logger.info('Validation folder does not exist at: %s', eval_dir)
        sys.exit()

    dataset_eval = Dataset(eval_dir)
    loader_eval = create_loader(
        dataset_eval,
        input_size=(3, args.image_size, args.image_size),
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        num_workers=args.workers,
        distributed=False,
        crop_pct=DEFAULT_CROP_PCT,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        interpolation=args.interpolation)

    # whether to use label smoothing
    if args.smoothing > 0.:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    mutator = RandomMutator(model)

    _callbacks = [LRSchedulerCallback(lr_scheduler)]
    if len(args.checkpoint_dir) > 1:
        _callbacks.append(ModelCheckpoint(args.checkpoint_dir))

    trainer = CreamSupernetTrainer(
        args.best_selected_space_path, model, train_loss_fn,
        validate_loss_fn, optimizer, num_epochs, loader_train, loader_eval,
        result_path=args.result_path,
        mutator=mutator,
        batch_size=args.batch_size,
        log_frequency=args.log_interval,
        meta_sta_epoch=args.meta_sta_epoch,
        update_iter=args.update_iter,
        slices=args.slice_num,
        pool_size=args.pool_size,
        pick_method=args.pick_method,
        choice_num=choice_num,
        sta_num=sta_num,
        acc_gap=args.acc_gap,
        flops_dict=flops_dict,
        flops_fixed=flops_fixed,
        local_rank=args.local_rank,
        callbacks=_callbacks)

    trainer.train()


if __name__ == '__main__':
    main()
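# Example invocation (illustrative values only; all other flags keep the defaults
# defined in parse_args). Since parse_args() appends "/imagenet" to --data_dir,
# ImageNet is expected at ./data/imagenet/{train,val} here:
#
#     python train.py --data_dir ./data --batch_size 32 --epochs 120 \
#         --flops_maximum 600 --checkpoint_dir ./checkpoints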