# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""TextNAS architecture search entry point.

Runs ENAS-style search over the TextNAS search space on the SST dataset,
dumps the search space, the search result (accuracy/cost), and the finally
selected architecture to JSON files.
"""
import json
import logging
import os
import random
import sys
import time
from argparse import ArgumentParser
from collections import OrderedDict
from itertools import cycle

import numpy as np
import torch
import torch.nn as nn

sys.path.append("../..")
from pytorch.enas import EnasMutator, EnasTrainer
from pytorch.callbacks import LRSchedulerCallback
from pytorch.mutables import LayerChoice, InputChoice, MutableScope
from dataloader import read_data_sst
from model import Model
from utils import accuracy, dump_global_result

logger = logging.getLogger("nni.textnas")
logger.setLevel(logging.INFO)

# Pin the job to one GPU, but respect an externally supplied setting
# (the original unconditionally overwrote any value already exported).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")


def save_textnas_search_space(mutator, file_path):
    """Serialize the mutator's search space to ``file_path`` as JSON.

    Walks every mutable in the mutator: scope-level mutables contribute their
    key as a prefix for the choices nested under them; each ``LayerChoice``
    maps to the shared ``op_list`` of candidate operations; each
    ``InputChoice`` records whether it models a skip connection, how many
    inputs may be chosen, and the candidate keys it can choose from.

    :param mutator: mutator whose ``mutables.traverse()`` yields the space.
    :param file_path: destination path handed to ``dump_global_result``.
    """
    result = OrderedDict()
    cur_layer_idx = None
    for mutable in mutator.mutables.traverse():
        if not isinstance(mutable, (LayerChoice, InputChoice)):
            # Scope node (e.g. a MutableScope): remember its key so the
            # choices below are namespaced "<layer>_<choice>".
            cur_layer_idx = mutable.key
            continue
        if isinstance(mutable, LayerChoice):
            # All LayerChoices share one candidate-op list; store it once.
            if "op_list" not in result:
                result["op_list"] = [str(op) for op in mutable]
            result[cur_layer_idx + "_" + mutable.key] = "op_list"
        else:
            # InputChoice: n_chosen is None for an unrestricted (skip) choice.
            result[cur_layer_idx + "_" + mutable.key] = {
                "skip_connection": not mutable.n_chosen,
                "n_chosen": mutable.n_chosen if mutable.n_chosen else "",
                "choose_from": mutable.choose_from if mutable.choose_from else "",
            }
    dump_global_result(file_path, result)


class TextNASTrainer(EnasTrainer):
    """EnasTrainer variant fed by pre-built, pre-cycled data loaders.

    ``init_dataloader`` is a no-op because the loaders are constructed by the
    caller and passed in directly; ``result`` accumulates per-epoch accuracy
    and the total wall-clock cost of the search.
    """

    def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.result = {"accuracy": [], "cost_time": 0}

    def init_dataloader(self):
        # Loaders are injected via the constructor; nothing to build here.
        pass


if __name__ == "__main__":
    parser = ArgumentParser("textnas")
    parser.add_argument("--search_space_path", type=str, default="./search_space.json",
                        help="search_space directory")
    parser.add_argument("--selected_space_path", type=str, default="./selected_space.json",
                        help="selected space output directory")
    parser.add_argument("--result_path", type=str, default="./result.json",
                        help="res directory")
    parser.add_argument("--trial_id", type=int, default=0, metavar="N",
                        help="trial_id,start from 0")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=50, type=int)
    parser.add_argument("--epochs", default=2, type=int)
    parser.add_argument("--lr", default=5e-3, type=float)
    args = parser.parse_args()

    # Seed every RNG with the trial id so each trial is reproducible.
    torch.manual_seed(args.trial_id)
    torch.cuda.manual_seed_all(args.trial_id)
    np.random.seed(args.trial_id)
    random.seed(args.trial_id)
    # Use deterministic instead of nondeterministic algorithms so exact
    # results can be reproduced every time.
    torch.backends.cudnn.deterministic = True

    # Configure compute device and load the SST data.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               num_workers=4, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size,
                                               num_workers=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                              num_workers=4)
    # ENAS interleaves child/controller steps; cycle() lets the trainer draw
    # batches indefinitely without exhausting the loaders.
    train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)

    # Build the model with the pretrained word embeddings.
    model = Model(embedding)

    # The mutator samples architectures from the search space.
    mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")
    # Persist the full search space for later inspection.
    save_textnas_search_space(mutator, args.search_space_path)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs,
                                                              eta_min=1e-5)

    trainer = TextNASTrainer(model,
                             loss=criterion,
                             metrics=lambda output, target: {"acc": accuracy(output, target)},
                             reward_function=accuracy,
                             optimizer=optimizer,
                             callbacks=[LRSchedulerCallback(lr_scheduler)],
                             batch_size=args.batch_size,
                             num_epochs=args.epochs,
                             dataset_train=None,
                             dataset_valid=None,
                             train_loader=train_loader,
                             valid_loader=valid_loader,
                             test_loader=test_loader,
                             log_frequency=args.log_frequency,
                             mutator=mutator,
                             mutator_lr=2e-3,
                             mutator_steps=5,
                             mutator_steps_aggregate=1,
                             child_steps=50,
                             baseline_decay=0.99,
                             test_arc_per_epoch=10)
    logger.info(trainer.metrics)

    t1 = time.time()
    trainer.train()
    trainer.result["cost_time"] = time.time() - t1
    dump_global_result(args.result_path, trainer.result)

    selected_model = trainer.export_child_model(selected_space=True)
    dump_global_result(args.selected_space_path, selected_model)