import sys from utils import accuracy import torch from torch import nn from torch.utils.data import DataLoader import datasets import time import logging import os import argparse import distutils.util import numpy as np import json import random sys.path.append('..'+ '/' + '..') # import custom packages from macro import GeneralNetwork from micro import MicroNetwork from pytorch.fixed import apply_fixed_architecture from pytorch.retrainer import Retrainer from pytorch.utils import AverageMeterGroup, to_device, save_best_checkpoint, mkdirs class EnasRetrainer(Retrainer): """ ENAS retrainer. Parameters ---------- model : nn.Module PyTorch model to be trained. data_dir : dataset path The path of the dataset. best_checkpoint_dir: 'best_checkpoint.pth' The directory for saving model. batch_size : int Batch size. eval_batch_size : int Batch size. num_epochs : int Number of epochs planned for training. lr : float Learning rate. is_cuda: Boolean Whether to use GPU for training. log_every : int Step count per logging. child_grad_bound : float Gradient bound. child_l2_reg: float L2 regression. eval_every_epochs: int Evaluate every epochs. logger: logging. workers : int Workers for data loading. device : torch.device ``torch.device("cpu")`` or ``torch.device("cuda")``. aux_weight : float Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. """ def __init__(self,model,data_dir = './data',best_checkpoint_dir = './best_checkpoint', batch_size = 1024, eval_batch_size = 1024,num_epochs = 2,lr = 0.02,is_cuda = 'True', log_every = 40,child_grad_bound = 0.5, child_l2_reg=3e-6, eval_every_epochs=2, logger = logging.getLogger("enas-retrain"), result_path='./'): self.aux_weight = 0.4 self.device = torch.device("cuda:0" ) self.workers = 4 self.child_model = model self.data_dir = data_dir self.best_checkpoint_dir = best_checkpoint_dir self.batch_size = batch_size self.eval_batch_size = eval_batch_size self.num_epochs = num_epochs self.lr = lr self.is_cuda = is_cuda self.log_every = log_every self.child_grad_bound = child_grad_bound self.child_l2_reg = child_l2_reg self.eval_every_epochs = eval_every_epochs self.logger = logger self.optimizer = torch.optim.SGD(self.child_model.parameters(), self.lr, momentum=0.9, weight_decay=1.0E-4, nesterov=True) self.criterion = nn.CrossEntropyLoss() self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.num_epochs, eta_min=0.001) # load dataset self.init_dataloader() self.child_model.to(self.device) self.result_path = result_path with open(self.result_path, "w") as file: file.write('') def train(self): """ Train ``num_epochs``. Trigger callbacks at the start and the end of each epoch. Parameters ---------- validate : bool If ``true``, will do validation every epoch. """ self.logger.info('** Start training **') self.start_time = time.time() for epoch in range(self.num_epochs): self.train_one_epoch(epoch) self.child_model.eval() # if epoch / self.eval_every_epochs == 0: self.logger.info("Epoch {}: Eval".format(epoch)) self.validate_one_epoch(epoch) self.lr_scheduler.step() # print('** saving model **') self.logger.info("** Save best model **") # save_state = { # 'epoch': epoch, # 'child_model_state_dict': self.child_model.state_dict(), # 'optimizer_state_dict': self.optimizer.state_dict()} # torch.save(save_state, self.best_checkpoint_dir) save_best_checkpoint(self.best_checkpoint_dir, self.child_model, self.optimizer, epoch) def validate(self): """ Do one validation. Validate one epoch. """ pass def export(self, file): """ dump the architecture to ``file``. Parameters ---------- file : str File path to export to. Expected to be a JSON. """ pass def checkpoint(self): """ Override to dump a checkpoint. """ pass def init_dataloader(self): self.logger.info("Build dataloader") self.dataset_train, self.dataset_valid = datasets.get_dataset("cifar10", self.data_dir) n_train = len(self.dataset_train) split = n_train // 10 indices = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) self.train_loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.batch_size, sampler=train_sampler, num_workers=self.workers) self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.eval_batch_size, sampler=valid_sampler, num_workers=self.workers) self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, batch_size=self.batch_size, num_workers=self.workers) # self.train_loader = cycle(self.train_loader) # self.valid_loader = cycle(self.valid_loader) def train_one_epoch(self,epoch): """ Train one epoch. Parameters ---------- epoch : int Epoch number starting from 0. """ tot_acc = 0 tot = 0 losses = [] step = 0 self.child_model.train() meters = AverageMeterGroup() for batch in self.train_loader: step += 1 x, y = batch x, y = to_device(x, self.device), to_device(y, self.device) logits = self.child_model(x) if isinstance(logits, tuple): logits, aux_logits = logits aux_loss = self.criterion(aux_logits, y) else: aux_loss = 0. acc = accuracy(logits, y) loss = self.criterion(logits, y) loss = loss + self.aux_weight * aux_loss self.optimizer.zero_grad() loss.backward() grad_norm = 0 trainable_params = self.child_model.parameters() # assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients." # # compute the gradient norm value # grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999) # for param in trainable_params: # nn.utils.clip_grad_norm_(param, self.child_grad_bound) # clip grad # print(param_ == param) if self.child_grad_bound is not None: grad_norm = nn.utils.clip_grad_norm_(trainable_params, self.child_grad_bound) trainable_params = grad_norm self.optimizer.step() tot_acc += acc['acc1'] tot += 1 losses.append(loss) acc["loss"] = loss.item() meters.update(acc) if step % self.log_every == 0: curr_time = time.time() log_string = "" log_string += "epoch={:<6d}".format(epoch) log_string += "ch_step={:<6d}".format(step) log_string += " loss={:<8.6f}".format(loss) log_string += " lr={:<8.4f}".format(self.optimizer.param_groups[0]['lr']) log_string += " |g|={:<8.4f}".format(grad_norm) log_string += " tr_acc={:<8.4f}/{:>3d}".format(acc['acc1'], logits.size()[0]) log_string += " mins={:<10.2f}".format(float(curr_time - self.start_time) / 60) self.logger.info(log_string) print("Model Epoch [%d/%d] %.3f mins %s \n " % (epoch + 1, self.num_epochs, float(time.time() - self.start_time) / 60, meters )) final_acc = float(tot_acc) / tot losses = torch.tensor(losses) loss = losses.mean() def validate_one_epoch(self,epoch): tot_acc = 0 tot = 0 losses = [] meters = AverageMeterGroup() with torch.no_grad(): # save memory meters = AverageMeterGroup() for batch in self.valid_loader: x, y = batch x, y = to_device(x, self.device), to_device(y, self.device) logits = self.child_model(x) if isinstance(logits, tuple): logits, aux_logits = logits aux_loss = self.criterion(aux_logits, y) else: aux_loss = 0. loss = self.criterion(logits, y) loss = loss + self.aux_weight * aux_loss # loss = loss.mean() preds = logits.argmax(dim=1).long() acc = torch.eq(preds, y.long()).long().sum().item() acc_v = accuracy(logits, y) losses.append(loss) tot_acc += acc tot += len(y) acc_v["loss"] = loss.item() meters.update(acc_v) losses = torch.tensor(losses) loss = losses.mean() if tot > 0: final_acc = float(tot_acc) / tot else: final_acc = 0 self.logger.info("Error in calculating final_acc") with open(self.result_path, "a") as file: file.write( str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) + '\n') # print("Model eval %.3fmins %s \n " % ( # float(time.time() - self.start_time) / 60, meters )) print({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": final_acc}}) self.logger.info( "ch_step= {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format( "test", final_acc, "test", loss)) logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', level=logging.INFO, filename='./retrain.log', filemode='a') logger = logging.getLogger("enas-retrain") def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--data_dir", type=str, default="./data", help="Directory containing the dataset and embedding file. (default: %(default)s)") parser.add_argument( "--model_selected_space_path", type=str, default="./model_selected_space.json", # required=True, help="Architecture json file. (default: %(default)s)") parser.add_argument("--result_path", type=str, default='./result.json', help="res directory") parser.add_argument("--search_space_path", type=str, default='./search_space.json', help="search_space directory") parser.add_argument("--log_path", type=str, default='output/log') parser.add_argument( "--best_selected_space_path", type=str, default="./best_selected_space.json", # required=True, help="Best architecture selected json file by experiment. (default: %(default)s)") parser.add_argument( "--best_checkpoint_dir", type=str, default="best_checkpoint", help="Path for saved checkpoints. (default: %(default)s)") parser.add_argument('--trial_id', type=int, default=0, metavar='N', help='trial_id,start from 0') parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") parser.add_argument( "--batch_size", type=int, default=128, help="Number of samples each batch for training. (default: %(default)s)") parser.add_argument( "--eval_batch_size", type=int, default=128, help="Number of samples each batch for evaluation. (default: %(default)s)") parser.add_argument( "--epochs", type=int, default=10, help="The number of training epochs. (default: %(default)s)") parser.add_argument( "--lr", type=float, default=0.02, help="The initial learning rate. (default: %(default)s)") parser.add_argument( "--is_cuda", type=distutils.util.strtobool, default=True, help="Specify the device type. (default: %(default)s)") parser.add_argument( "--load_checkpoint", type=distutils.util.strtobool, default=False, help="Whether to load checkpoint. (default: %(default)s)") parser.add_argument( "--log_every", type=int, default=50, help="How many steps to log. (default: %(default)s)") parser.add_argument( "--eval_every_epochs", type=int, default=1, help="How many epochs to eval. (default: %(default)s)") parser.add_argument( "--child_grad_bound", type=float, default=5.0, help="The threshold for gradient clipping. (default: %(default)s)") # parser.add_argument( "--child_l2_reg", type=float, default=3e-6, help="Weight decay factor. (default: %(default)s)") parser.add_argument( "--child_lr_decay_scheme", type=str, default="cosine", help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)") #todo: remove global FLAGS FLAGS = parser.parse_args() # decode human readable search space to model def convert_selected_space_format(child_fixed_arc): # with open('./macro_selected_space.json') as js: with open(child_fixed_arc) as js: selected_space = json.load(js) ops = selected_space['op_list'] selected_space.pop('op_list') new_selected_space = {} for key, value in selected_space.items(): # for macro if FLAGS.search_for == 'macro': new_key = key.split('_')[-1] # for micro elif FLAGS.search_for == 'micro': new_key = key if len(value) > 1 or len(value)==0: new_value = value elif len(value) > 0 and value[0] in ops: new_value = ops.index(value[0]) else: new_value = value[0] new_selected_space[new_key] = new_value return new_selected_space def set_random_seed(seed): logger.info("set random seed for data reading: {}".format(seed)) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) if FLAGS.is_cuda: torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True def main(): parse_args() child_fixed_arc = FLAGS.best_selected_space_path # './macro_seletced_space' search_for = FLAGS.search_for # set seed to result todo: trial ID set_random_seed(FLAGS.trial_id) mkdirs(FLAGS.result_path, FLAGS.log_path, FLAGS.best_checkpoint_dir) # define and load model logger.info('** ' + FLAGS.search_for + 'search **') fixed_arc = convert_selected_space_format(child_fixed_arc) # Model, macro search or micro search if FLAGS.search_for == 'macro': child_model = GeneralNetwork() elif FLAGS.search_for == 'micro': child_model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) apply_fixed_architecture(child_model, fixed_arc) # load model if FLAGS.load_checkpoint: print('** Load model **') logger.info('** Load model **') child_model.load_state_dict(torch.load(FLAGS.best_checkpoint_dir)['child_model_state_dict']) retrainer = EnasRetrainer(model=child_model, data_dir = FLAGS.data_dir, best_checkpoint_dir=FLAGS.best_checkpoint_dir, batch_size=FLAGS.batch_size, eval_batch_size=FLAGS.eval_batch_size, num_epochs=FLAGS.epochs, lr=FLAGS.lr, is_cuda=FLAGS.is_cuda, log_every=FLAGS.log_every, child_grad_bound=FLAGS.child_grad_bound, child_l2_reg=FLAGS.child_l2_reg, eval_every_epochs=FLAGS.eval_every_epochs, logger=logger, result_path=FLAGS.result_path, ) t1 = time.time() retrainer.train() print('cost time for retrain: ' , time.time() - t1) if __name__ == "__main__": main()