from itertools import cycle import os import sys sys.path.append('..'+ '/' + '..') import numpy as np import random import logging import time from argparse import ArgumentParser from collections import OrderedDict import json import torch import torch.nn as nn import torch.optim as optim # import custom libraries import datasets from pytorch.trainer import Trainer from pytorch.utils import AverageMeterGroup, to_device, mkdirs from pytorch.mutables import LayerChoice, InputChoice, MutableScope from macro import GeneralNetwork from micro import MicroNetwork # from trainer import EnasTrainer from mutator import EnasMutator from pytorch.callbacks import (ArchitectureCheckpoint, LRSchedulerCallback) from utils import accuracy, reward_accuracy torch.cuda.set_device(0) logging.basicConfig(format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s', level=logging.INFO, filename='./train.log', filemode='a') logger = logging.getLogger('enas_train') class EnasTrainer(Trainer): """ ENAS trainer. Parameters ---------- model : nn.Module PyTorch model to be trained. loss : callable Receives logits and ground truth label, return a loss tensor. metrics : callable Receives logits and ground truth label, return a dict of metrics. reward_function : callable Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward. optimizer : Optimizer The optimizer used for optimizing the model. num_epochs : int Number of epochs planned for training. dataset_train : Dataset Dataset for training. Will be split for training weights and architecture weights. dataset_valid : Dataset Dataset for testing. mutator : EnasMutator Use when customizing your own mutator or a mutator with customized parameters. batch_size : int Batch size. workers : int Workers for data loading. device : torch.device ``torch.device("cpu")`` or ``torch.device("cuda")``. log_frequency : int Step count per logging. callbacks : list of Callback list of callbacks to trigger at events. entropy_weight : float Weight of sample entropy loss. skip_weight : float Weight of skip penalty loss. baseline_decay : float Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``. child_steps : int How many mini-batches for model training per epoch. mutator_lr : float Learning rate for RL controller. mutator_steps_aggregate : int Number of steps that will be aggregated into one mini-batch for RL controller. mutator_steps : int Number of mini-batches for each epoch of RL controller learning. aux_weight : float Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss. test_arc_per_epoch : int How many architectures are chosen for direct test after each epoch. """ def __init__(self, model, loss, metrics, reward_function, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500, mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4, test_arc_per_epoch=1,child_model_path = './', result_path='./'): super().__init__(model, mutator if mutator is not None else EnasMutator(model), loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks) self.reward_function = reward_function self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) self.batch_size = batch_size self.workers = workers self.entropy_weight = entropy_weight self.skip_weight = skip_weight self.baseline_decay = baseline_decay self.baseline = 0. self.mutator_steps_aggregate = mutator_steps_aggregate self.mutator_steps = mutator_steps self.child_steps = child_steps self.aux_weight = aux_weight self.test_arc_per_epoch = test_arc_per_epoch self.child_model_path = child_model_path # saving the child model self.init_dataloader() # self.result = {'accuracy':[], # 'cost_time':0} self.result_path = result_path with open(self.result_path, "w") as file: file.write('') def init_dataloader(self): n_train = len(self.dataset_train) split = n_train // 10 indices = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:]) self.train_loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.batch_size, sampler=train_sampler, num_workers=self.workers) self.valid_loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.batch_size, sampler=valid_sampler, num_workers=self.workers) self.test_loader = torch.utils.data.DataLoader(self.dataset_valid, batch_size=self.batch_size, num_workers=self.workers) self.train_loader = cycle(self.train_loader) self.valid_loader = cycle(self.valid_loader) def train_one_epoch(self, epoch): # Sample model and train self.model.train() self.mutator.eval() meters = AverageMeterGroup() for step in range(1, self.child_steps + 1): x, y = next(self.train_loader) x, y = to_device(x, self.device), to_device(y, self.device) self.optimizer.zero_grad() with torch.no_grad(): self.mutator.reset() # self._write_graph_status() logits = self.model(x) if isinstance(logits, tuple): logits, aux_logits = logits aux_loss = self.loss(aux_logits, y) else: aux_loss = 0. metrics = self.metrics(logits, y) loss = self.loss(logits, y) loss = loss + self.aux_weight * aux_loss loss.backward() nn.utils.clip_grad_norm_(self.model.parameters(), 5.) self.optimizer.step() metrics["loss"] = loss.item() meters.update(metrics) if self.log_frequency is not None and step % self.log_frequency == 0: logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs, step, self.child_steps, meters) # Train sampler (mutator) self.model.eval() self.mutator.train() meters = AverageMeterGroup() for mutator_step in range(1, self.mutator_steps + 1): self.mutator_optim.zero_grad() for step in range(1, self.mutator_steps_aggregate + 1): x, y = next(self.valid_loader) x, y = to_device(x, self.device), to_device(y, self.device) self.mutator.reset() with torch.no_grad(): logits = self.model(x) # self._write_graph_status() metrics = self.metrics(logits, y) reward = self.reward_function(logits, y) if self.entropy_weight: reward += self.entropy_weight * self.mutator.sample_entropy.item() self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay) loss = self.mutator.sample_log_prob * (reward - self.baseline) if self.skip_weight: loss += self.skip_weight * self.mutator.sample_skip_penalty metrics["reward"] = reward metrics["loss"] = loss.item() metrics["ent"] = self.mutator.sample_entropy.item() metrics["log_prob"] = self.mutator.sample_log_prob.item() metrics["baseline"] = self.baseline metrics["skip"] = self.mutator.sample_skip_penalty loss /= self.mutator_steps_aggregate loss.backward() meters.update(metrics) cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate if self.log_frequency is not None and cur_step % self.log_frequency == 0: logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs, mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate, meters) nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.) self.mutator_optim.step() def validate_one_epoch(self, epoch): with torch.no_grad(): accuracy = 0 for arc_id in range(self.test_arc_per_epoch): meters = AverageMeterGroup() count, acc_this_round = 0,0 for x, y in self.test_loader: x, y = to_device(x, self.device), to_device(y, self.device) self.mutator.reset() child_model = self.export_child_model() # self._generate_child_model(epoch, # count, # arc_id, # child_model, # self.child_model_path) logits = self.model(x) if isinstance(logits, tuple): logits, _ = logits metrics = self.metrics(logits, y) loss = self.loss(logits, y) metrics["loss"] = loss.item() meters.update(metrics) count += 1 acc_this_round += metrics['acc1'] logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s", epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch, meters.summary()) acc_this_round /= count accuracy += acc_this_round # logger.info({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) print({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) with open(self.result_path, "a") as file: file.write(str({"type": "Accuracy", "result": {"sequence": epoch, "category": "epoch", "value": meters.get_last_acc()}}) + '\n') # self.result['accuracy'].append(accuracy / self.test_arc_per_epoch) # export child_model def export_child_model(self, selected_space=False): if selected_space: sampled = self.mutator.sample_final() else: sampled = self.mutator._cache result = OrderedDict() cur_layer_id = None for mutable in self.mutator.mutables: if not isinstance(mutable, (LayerChoice, InputChoice)): cur_layer_id = mutable.key # not supported as built-in continue choosed_ops_idx = self.mutator._convert_mutable_decision_to_human_readable(mutable, sampled[mutable.key]) if not isinstance(choosed_ops_idx, list): choosed_ops_idx = [choosed_ops_idx] if isinstance(mutable, LayerChoice): if 'op_list' not in result: result['op_list'] = [str(i) for i in mutable] choosed_ops = [str(mutable[idx]) for idx in choosed_ops_idx] else: choosed_ops = choosed_ops_idx if 'node' in cur_layer_id: result[mutable.key] = choosed_ops else: result[cur_layer_id + '_' + mutable.key] = choosed_ops return result def _generate_child_model(self, validation_epoch, model_idx, validation_step, child_model, file_path): # create child_models folder # parent_path = os.path.join(file_path, 'child_model') parent_path = file_path if not os.path.exists(parent_path): os.mkdir(parent_path) # create secondary directory secondary_path = os.path.join(parent_path, 'validation_epoch_{}'.format(validation_epoch)) if not os.path.exists(secondary_path): os.mkdir(secondary_path) # create third directory folder_path = os.path.join(secondary_path, 'validation_step_{}'.format(validation_step)) if not os.path.exists(folder_path): os.mkdir(folder_path) # save sampled child_model for validation saved_path = os.path.join(folder_path, "child_model_%02d.json" % model_idx) with open(saved_path, "w") as ss_file: json.dump(child_model, ss_file, indent=2) # save search space as search_space.json def save_nas_search_space(mutator,file_path): result = OrderedDict() cur_layer_idx = None for mutable in mutator.mutables.traverse(): if not isinstance(mutable,(LayerChoice, InputChoice)): cur_layer_idx = mutable.key + '_' continue # macro if 'layer' in cur_layer_idx: if isinstance(mutable, LayerChoice): if 'op_list' not in result: result['op_list'] = [str(i) for i in mutable] result[cur_layer_idx + mutable.key] = 'op_list' else: result[cur_layer_idx + mutable.key] = {'skip_connection': False if mutable.n_chosen else True, 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', 'choose_from': mutable.choose_from if mutable.choose_from else ''} # micro elif 'node' in cur_layer_idx: if isinstance(mutable,LayerChoice): if 'op_list' not in result: result['op_list'] = [str(i) for i in mutable] result[mutable.key] = 'op_list' else: result[mutable.key] = {'skip_connection':False if mutable.n_chosen else True, 'n_chosen': mutable.n_chosen if mutable.n_chosen else '', 'choose_from': mutable.choose_from if mutable.choose_from else ''} dump_global_result(file_path,result) # def dump_global_result(args,global_result): # with open(args['result_path'], "w") as ss_file: # json.dump(global_result, ss_file, sort_keys=True, indent=2) def dump_global_result(res_path,global_result, sort_keys = False): with open(res_path, "w") as ss_file: json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2) if __name__ == "__main__": parser = ArgumentParser("enas") parser.add_argument( "--data_dir", type=str, default="./data", help="Directory containing the dataset and embedding file. (default: %(default)s)") parser.add_argument("--model_selected_space_path", type=str, default='./model_selected_space.json', help="sapce_path_out directory") parser.add_argument("--result_path", type=str, default='./model_result.json', help="res directory") parser.add_argument("--search_space_path", type=str, default='./search_space.json', help="search_space directory") parser.add_argument("--best_selected_space_path", type=str, default='./model_selected_space.json', help="Best sapce_path_out directory of experiment") parser.add_argument('--trial_id', type=int, default=0, metavar='N', help='trial_id,start from 0') parser.add_argument('--lr', type=float, default=0.005, metavar='N', help='learning rate') parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)") parser.add_argument("--batch_size", default=128, type=int) parser.add_argument("--log_frequency", default=10, type=int) parser.add_argument("--search_for", choices=["macro", "micro"], default="macro") args = parser.parse_args() mkdirs(args.result_path, args.search_space_path, args.best_selected_space_path) # 设置随机种子 torch.manual_seed(args.trial_id) torch.cuda.manual_seed_all(args.trial_id) np.random.seed(args.trial_id) random.seed(args.trial_id) # use deterministic instead of nondeterministic algorithm # make sure exact results can be reproduced everytime. torch.backends.cudnn.deterministic = True dataset_train, dataset_valid = datasets.get_dataset("cifar10", args.data_dir) if args.search_for == "macro": model = GeneralNetwork() num_epochs = args.epochs or 310 mutator = None mutator = EnasMutator(model) elif args.search_for == "micro": model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True) num_epochs = args.epochs or 150 mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True) else: raise AssertionError # 储存整个网络结构 # args.search_spach_path = None#str(args.search_for) + str(args.search_space_path) # print( args.search_space_path, args.search_for ) save_nas_search_space(mutator, args.search_space_path) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001) trainer = EnasTrainer(model, loss=criterion, metrics=accuracy, reward_function=reward_accuracy, optimizer=optimizer, callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./"+args.search_for+"_checkpoints")], batch_size=args.batch_size, num_epochs=num_epochs, dataset_train=dataset_train, dataset_valid=dataset_valid, log_frequency=args.log_frequency, mutator=mutator, child_steps=2, mutator_steps=2, child_model_path='./'+args.search_for+'_child_model', result_path=args.result_path) logger.info(trainer.metrics) t1 = time.time() trainer.train() # trainer.result["cost_time"] = time.time() - t1 # dump_global_result(args.result_path,trainer.result) selected_model = trainer.export_child_model(selected_space = True) dump_global_result(args.best_selected_space_path,selected_model)