# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import sys
sys.path.append('../..')

import json
import logging
import random
import time
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from trainer import EnasTrainer
from mutator import EnasMutator
from pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from pytorch.mutables import LayerChoice, InputChoice
from utils import accuracy, reward_accuracy

# NOTE: hard-coded to GPU 4; adjust for your machine.
torch.cuda.set_device(4)
logger = logging.getLogger('tadl-enas')


# Dump the search space (as seen by the mutator) to a JSON file,
# e.g. search_space.json.
def save_nas_search_space(mutator, file_path):
    result = OrderedDict()
    cur_layer_idx = None
    for mutable in mutator.mutables.traverse():
        # Non-choice mutables (e.g. mutable scopes) mark the current layer/node;
        # their key becomes the prefix for the choices nested under them.
        if not isinstance(mutable, (LayerChoice, InputChoice)):
            cur_layer_idx = mutable.key + '_'
            continue
        # macro search space: keys are prefixed with the enclosing layer
        if 'layer' in cur_layer_idx:
            if isinstance(mutable, LayerChoice):
                if 'op_list' not in result:
                    result['op_list'] = [str(op) for op in mutable]
                result[cur_layer_idx + mutable.key] = 'op_list'
            else:  # InputChoice: an (optional) skip connection
                result[cur_layer_idx + mutable.key] = {
                    'skip_connection': not mutable.n_chosen,
                    'n_chosen': mutable.n_chosen if mutable.n_chosen else '',
                    'choose_from': mutable.choose_from if mutable.choose_from else ''}
        # micro search space: node keys are unique on their own, no prefix needed
        elif 'node' in cur_layer_idx:
            if isinstance(mutable, LayerChoice):
                if 'op_list' not in result:
                    result['op_list'] = [str(op) for op in mutable]
                result[mutable.key] = 'op_list'
            else:
                result[mutable.key] = {
                    'skip_connection': not mutable.n_chosen,
                    'n_chosen': mutable.n_chosen if mutable.n_chosen else '',
                    'choose_from': mutable.choose_from if mutable.choose_from else ''}
    dump_global_result(file_path, result)


def dump_global_result(res_path, global_result, sort_keys=False):
    with open(res_path, "w") as ss_file:
        json.dump(global_result, ss_file, sort_keys=sort_keys, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--search_space_path", type=str, default='./search_space.json',
                        help="path to dump the search space JSON")
    parser.add_argument("--selected_space_path", type=str, default='./selected_space.json',
                        help="path to dump the selected (exported) space JSON")
    parser.add_argument("--result_path", type=str, default='./result.json',
                        help="path to dump the training result JSON")
    parser.add_argument('--trial_id', type=int, default=0, metavar='N',
                        help='trial id, starting from 0; also used as the random seed')
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--search_for", choices=["macro", "micro"], default="macro")
    parser.add_argument("--epochs", default=None, type=int,
                        help="Number of epochs (default: macro 310, micro 150)")
    args = parser.parse_args()

    # Seed all random number generators with the trial id for reproducibility.
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    random.seed(args.trial_id)

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
    if args.search_for == "macro":
        model = GeneralNetwork()
        num_epochs = args.epochs or 310
        mutator = EnasMutator(model)
    elif args.search_for == "micro":
        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5,
                             dropout_rate=0.1, use_aux_heads=True)
        num_epochs = args.epochs or 150
        mutator = EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
    else:
        raise AssertionError("unknown search space: %s" % args.search_for)

    # Dump the whole network search space before training starts.
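    # For reference, the dumped JSON (illustrative, derived from the code above;
    # the key names below are hypothetical, actual keys come from mutable.key)
    # maps each LayerChoice key to "op_list" and each InputChoice key to a dict:
    #   {
    #     "op_list": ["ConvBranch(...)", "..."],
    #     "layer_1_LayerChoice1": "op_list",
    #     "layer_1_InputChoice2": {"skip_connection": true,
    #                              "n_chosen": "", "choose_from": ["..."]}
    #   }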
    save_nas_search_space(mutator, args.search_space_path)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs,
                                                              eta_min=0.001)

    trainer = EnasTrainer(model,
                          loss=criterion,
                          metrics=accuracy,
                          reward_function=reward_accuracy,
                          optimizer=optimizer,
                          callbacks=[LRSchedulerCallback(lr_scheduler)],
                          batch_size=args.batch_size,
                          num_epochs=num_epochs,
                          dataset_train=dataset_train,
                          dataset_valid=dataset_valid,
                          log_frequency=args.log_frequency,
                          mutator=mutator,
                          child_model_path='./' + args.search_for + '_child_model')
    logger.info(trainer.metrics)

    t1 = time.time()
    trainer.train()
    trainer.result["cost_time"] = time.time() - t1
    dump_global_result(args.result_path, trainer.result)

    # Export the architecture chosen by the trained controller.
    selected_model = trainer.export_child_model(selected_space=True)
    dump_global_result(args.selected_space_path, selected_model)
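# Example invocation (the script file name is illustrative, not from the repo):
#   python enas_search.py --search_for micro --epochs 150 --trial_id 0
# This writes search_space.json before training, then result.json and
# selected_space.json afterwards; the trainer receives ./micro_child_model
# as its child_model_path.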