import sys sys.path.append('..'+ '/' + '..') import os import logging import pickle import shutil import random import math import time import datetime import argparse import distutils.util import numpy as np import json import torch from torch import nn from torch import optim from torch.utils.data import DataLoader import torch.nn.functional as Func from macro import GeneralNetwork from micro import MicroNetwork import datasets from utils import accuracy, reward_accuracy from pytorch.fixed import apply_fixed_architecture from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint logger = logging.getLogger("enas-retrain") # TODO: def set_random_seed(seed): logger.info("set random seed for data reading: {}".format(seed)) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) random.seed(seed) torch.manual_seed_all(seed) if FLAGS.is_cuda: torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True # TODO: parser args def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--data_dir", type=str, default="./data", help="Directory containing the dataset and embedding file. (default: %(default)s)") parser.add_argument("--search_space_path", type=str, default='./search_space.json', help="search_space directory") parser.add_argument( "--selected_space_path", type=str, default="./selected_space.json", # required=True, help="Architecture json file. (default: %(default)s)") parser.add_argument("--result_path", type=str, default='./result.json', help="res directory") parser.add_argument('--trial_id', type=int, default=0, metavar='N', help='trial_id,start from 0') parser.add_argument( "--output_dir", type=str, default="./output", help="The output directory. (default: %(default)s)") parser.add_argument( "--best_checkpoint_dir", type=str, default="best_checkpoint", help="Path for saved checkpoints. (default: %(default)s)") parser.add_argument("--search_for", choices=["macro", "micro"], default="micro") parser.add_argument( "--batch_size", type=int, default=128, help="Number of samples each batch for training. (default: %(default)s)") parser.add_argument( "--eval_batch_size", type=int, default=128, help="Number of samples each batch for evaluation. (default: %(default)s)") parser.add_argument( "--class_num", type=int, default=10, help="The number of categories. (default: %(default)s)") parser.add_argument( "--epochs", type=int, default=10, help="The number of training epochs. (default: %(default)s)") parser.add_argument( "--child_lr", type=float, default=0.02, help="The initial learning rate. (default: %(default)s)") parser.add_argument( "--is_cuda", type=distutils.util.strtobool, default=True, help="Specify the device type. (default: %(default)s)") parser.add_argument( "--load_checkpoint", type=distutils.util.strtobool, default=False, help="Whether to load checkpoint. (default: %(default)s)") parser.add_argument( "--log_every", type=int, default=50, help="How many steps to log. (default: %(default)s)") parser.add_argument( "--eval_every_epochs", type=int, default=1, help="How many epochs to eval. (default: %(default)s)") parser.add_argument( "--child_grad_bound", type=float, default=5.0, help="The threshold for gradient clipping. (default: %(default)s)") # parser.add_argument( "--child_lr_decay_scheme", type=str, default="cosine", help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove parser.add_argument( "--child_lr_T_0", type=int, default=10, help="The length of one cycle. (default: %(default)s)") # todo: use for parser.add_argument( "--child_lr_T_mul", type=int, default=2, help="The multiplication factor per cycle. (default: %(default)s)") # todo: use for parser.add_argument( "--child_l2_reg", type=float, default=3e-6, help="Weight decay factor. (default: %(default)s)") parser.add_argument( "--child_lr_max", type=float, default=0.002, help="The max learning rate. (default: %(default)s)") parser.add_argument( "--child_lr_min", type=float, default=0.001, help="The min learning rate. (default: %(default)s)") parser.add_argument( "--multi_path", type=distutils.util.strtobool, default=False, help="Search for multiple path in the architecture. (default: %(default)s)") # todo: use for parser.add_argument( "--is_mask", type=distutils.util.strtobool, default=True, help="Apply mask. (default: %(default)s)") global FLAGS FLAGS = parser.parse_args() def print_user_flags(FLAGS, line_limit=80): log_strings = "\n" + "-" * line_limit + "\n" for flag_name in sorted(vars(FLAGS)): value = "{}".format(getattr(FLAGS, flag_name)) log_string = flag_name log_string += "." * (line_limit - len(flag_name) - len(value)) log_string += value log_strings = log_strings + log_string log_strings = log_strings + "\n" log_strings += "-" * line_limit logger.info(log_strings) def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None): if eval_set == "test": assert test_dataloader is not None dataloader = test_dataloader elif eval_set == "valid": assert valid_dataloader is not None dataloader = valid_dataloader else: raise NotImplementedError("Unknown eval_set '{}'".format(eval_set)) tot_acc = 0 tot = 0 losses = [] with torch.no_grad(): # save memory for batch in dataloader: x, y = batch x, y = to_device(x, device), to_device(y, device) logits = child_model(x) if isinstance(logits, tuple): logits, aux_logits = logits aux_loss = criterion(aux_logits, y) else: aux_loss = 0. loss = criterion(logits, y) loss = loss + aux_weight * aux_loss # loss = loss.mean() preds = logits.argmax(dim=1).long() acc = torch.eq(preds, y.long()).long().sum().item() losses.append(loss) tot_acc += acc tot += len(y) losses = torch.tensor(losses) loss = losses.mean() if tot > 0: final_acc = float(tot_acc) / tot else: final_acc = 0 logger.info("Error in calculating final_acc") return final_acc, loss # TODO: learning rate scheduler def update_lr( optimizer, epoch, l2_reg=1e-4, lr_warmup_val=None, lr_init=0.1, lr_decay_scheme="cosine", lr_max=0.002, lr_min=0.000000001, lr_T_0=4, lr_T_mul=1, sync_replicas=False, num_aggregate=None, num_replicas=None): if lr_decay_scheme == "cosine": assert lr_max is not None, "Need lr_max to use lr_cosine" assert lr_min is not None, "Need lr_min to use lr_cosine" assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine" assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine" T_i = lr_T_0 t_epoch = epoch last_reset = 0 while True: t_epoch -= T_i if t_epoch < 0: break last_reset += T_i T_i *= lr_T_mul T_curr = epoch - last_reset def _update(): rate = T_curr / T_i * 3.1415926 lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate)) return lr learning_rate = _update() else: raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme)) #update lr in optimizer for params_group in optimizer.param_groups: params_group['lr'] = learning_rate return learning_rate def train(device, output_dir='./output'): workers = 4 data = 'cifar10' data_dir = FLAGS.data_dir output_dir = FLAGS.output_dir checkpoint_dir = FLAGS.best_checkpoint_dir batch_size = FLAGS.batch_size eval_batch_size = FLAGS.eval_batch_size class_num = FLAGS.class_num epochs = FLAGS.epochs child_lr = FLAGS.child_lr is_cuda = FLAGS.is_cuda load_checkpoint = FLAGS.load_checkpoint log_every = FLAGS.log_every eval_every_epochs = FLAGS.eval_every_epochs child_grad_bound = FLAGS.child_grad_bound child_l2_reg = FLAGS.child_l2_reg logger.info("Build dataloader") dataset_train, dataset_valid = datasets.get_dataset("cifar10") n_train = len(dataset_train) split = n_train // 10 indices = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, sampler=train_sampler, num_workers=workers) valid_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, sampler=valid_sampler, num_workers=workers) test_dataloader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, num_workers=workers) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(child_model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4, nesterov=True) # optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg) # TODO lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.001) # move model to CPU/GPU device child_model.to(device) criterion.to(device) logger.info('Start training') start_time = time.time() step = 0 # save path if not os.path.exists(output_dir): os.mkdir(output_dir) # model_save_path = os.path.join(output_dir, "model.pth") # best_model_save_path = os.path.join(output_dir, "best_model.pth") best_acc = 0 start_epoch = 0 # TODO: load checkpoints # train for epoch in range(start_epoch, epochs): lr = update_lr(optimizer, epoch, l2_reg= 1e-4, lr_warmup_val=None, lr_init=FLAGS.child_lr, lr_decay_scheme=FLAGS.child_lr_decay_scheme, lr_max=0.05, lr_min=0.001, lr_T_0=10, lr_T_mul=2) child_model.train() for batch in train_dataloader: step += 1 x, y = batch x, y = to_device(x, device), to_device(y, device) logits = child_model(x) if isinstance(logits, tuple): logits, aux_logits = logits aux_loss = criterion(aux_logits, y) else: aux_loss = 0. acc = accuracy(logits, y) loss = criterion(logits, y) loss = loss + aux_weight * aux_loss optimizer.zero_grad() loss.backward() grad_norm = 0 trainable_params = child_model.parameters() for param in trainable_params: nn.utils.clip_grad_norm_(param, child_grad_bound) # clip grad optimizer.step() if step % log_every == 0: curr_time = time.time() log_string = "" log_string += "epoch={:<6d}".format(epoch) log_string += "ch_step={:<6d}".format(step) log_string += " loss={:<8.6f}".format(loss) log_string += " lr={:<8.4f}".format(lr) log_string += " |g|={:<8.4f}".format(grad_norm) log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60) logger.info(log_string) epoch += 1 save_state = { 'step': step, 'epoch': epoch, 'child_model_state_dict': child_model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()} # print(' Epoch {:<3d} loss: {:<.2f} '.format(epoch, loss)) # torch.save(save_state, model_save_path) child_model.eval() logger.info("Epoch {}: Eval".format(epoch)) eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader) logger.info( "ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss)) if eval_acc > best_acc: best_acc = eval_acc logger.info("Save best model") # save_state = { # 'step': step, # 'epoch': epoch, # 'child_model_state_dict': child_model.state_dict(), # 'optimizer_state_dict': optimizer.state_dict()} # torch.save(save_state, best_model_save_path) save_best_checkpoint(checkpoint_dir, child_model, optimizer, epoch) result['accuracy'].append('Epoch {} acc: {:<6.4f}'.format(epoch, eval_acc,)) acc_l.append(eval_acc) print(result['accuracy'][-1]) print('max acc %.4f at epoch: %i'%(max(acc_l), np.argmax(np.array(acc_l)))) print('Time cost: %.4f hours'%( float(time.time() - start_time) /3600. )) return result # macro = True parse_args() child_fixed_arc = FLAGS.selected_space_path # './macro_seletced_space' search_for = FLAGS.search_for # 设置随机种子 torch.manual_seed(FLAGS.trial_id) torch.cuda.manual_seed_all(FLAGS.trial_id) np.random.seed(FLAGS.trial_id) random.seed(FLAGS.trial_id) aux_weight = 0.4 result = {'accuracy':[]} acc_l = [] # decode human readable search space to model def convert_selected_space_format(): # with open('./macro_selected_space.json') as js: with open(child_fixed_arc) as js: selected_space = json.load(js) ops = selected_space['op_list'] selected_space.pop('op_list') new_selected_space = {} for key, value in selected_space.items(): # for macro if FLAGS.search_for == 'macro': new_key = key.split('_')[-1] # for micro elif FLAGS.search_for == 'micro': new_key = key if len(value) > 1 or len(value)==0: new_value = value elif len(value) > 0 and value[0] in ops: new_value = ops.index(value[0]) else: new_value = value[0] new_selected_space[new_key] = new_value return new_selected_space fixed_arc = convert_selected_space_format() # TODO : macro search or micro search if FLAGS.search_for == 'macro': child_model = GeneralNetwork() elif FLAGS.search_for == 'micro': child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) apply_fixed_architecture(child_model,fixed_arc) def dump_global_result(res_path,global_result, sort_keys = False): with open(res_path, "w") as ss_file: json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) def main(): os.environ['CUDA_VISIBLE_DEVICES'] = '4' # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") device = torch.device("cuda" if FLAGS.is_cuda else "cpu") train(device) dump_global_result('result_retrain.json', result['accuracy']) if __name__ == "__main__": main()