You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

search.py 6.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import logging
  4. import os
  5. import random
  6. from argparse import ArgumentParser
  7. from itertools import cycle
  8. import numpy as np
  9. import torch
  10. import torch.nn as nn
  11. import sys
  12. sys.path.append("../..")
  13. from pytorch.enas import EnasMutator, EnasTrainer
  14. from pytorch.callbacks import LRSchedulerCallback
  15. from pytorch.mutables import LayerChoice, InputChoice, MutableScope
  16. from dataloader import read_data_sst
  17. from model import Model
  18. from utils import accuracy, dump_global_result
  19. from collections import OrderedDict
  20. import os
  21. import json
  22. import time
  23. logger = logging.getLogger("nni.textnas")
  24. logger.setLevel(logging.INFO)
  25. # For debugging mode
  26. # os.chdir('/home/yangyi/pytorch/textnas')
  27. os.environ["CUDA_VISIBLE_DEVICES"]='4'
  28. def save_textnas_search_space(mutator,file_path):
  29. result = OrderedDict()
  30. cur_layer_idx = None
  31. for mutable in mutator.mutables.traverse():
  32. if not isinstance(mutable,(LayerChoice, InputChoice)):
  33. cur_layer_idx = mutable.key
  34. continue
  35. if isinstance(mutable,LayerChoice):
  36. if 'op_list' not in result:
  37. result['op_list'] = [str(i) for i in mutable]
  38. result[cur_layer_idx+ '_'+ mutable.key] = 'op_list'
  39. else:
  40. result[cur_layer_idx+ '_'+ mutable.key] = {'skip_connection':False if mutable.n_chosen else True,
  41. 'n_chosen': mutable.n_chosen if mutable.n_chosen else '',
  42. 'choose_from': mutable.choose_from if mutable.choose_from else ''}
  43. dump_global_result(file_path,result)
  44. class TextNASTrainer(EnasTrainer):
  45. def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
  46. super().__init__(*args, **kwargs)
  47. self.train_loader = train_loader
  48. self.valid_loader = valid_loader
  49. self.test_loader = test_loader
  50. self.result = {'accuracy':[],
  51. 'cost_time':0}
  52. def init_dataloader(self):
  53. pass
  54. if __name__ == "__main__":
  55. parser = ArgumentParser("textnas")
  56. parser.add_argument("--search_space_path", type=str,
  57. default='./search_space.json', help="search_space directory")
  58. parser.add_argument("--selected_space_path", type=str,
  59. default='./selected_space.json', help="sapce_path_out directory")
  60. parser.add_argument("--result_path", type=str,
  61. default='./result.json', help="res directory")
  62. parser.add_argument('--trial_id', type=int, default=0, metavar='N',
  63. help='trial_id,start from 0')
  64. parser.add_argument("--batch-size", default=128, type=int)
  65. parser.add_argument("--log-frequency", default=50, type=int)
  66. parser.add_argument("--epochs", default=2, type=int)
  67. parser.add_argument("--lr", default=5e-3, type=float)
  68. args = parser.parse_args()
  69. # 设置随机种子
  70. torch.manual_seed(args.trial_id)
  71. torch.cuda.manual_seed_all(args.trial_id)
  72. np.random.seed(args.trial_id)
  73. random.seed(args.trial_id)
  74. # use deterministic instead of nondeterministic algorithm
  75. # make sure exact results can be reproduced everytime.
  76. torch.backends.cudnn.deterministic = True
  77. # 配置计算资源及load数据
  78. device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  79. train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
  80. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
  81. valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
  82. test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4)
  83. train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)
  84. # 导入模型以及预训练的词向量
  85. model = Model(embedding)
  86. # 实例化一个mutator, mutator主要是用于选择搜索空间的
  87. mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")
  88. # 储存整个网络结构
  89. save_textnas_search_space(mutator, args.search_space_path)
  90. criterion = nn.CrossEntropyLoss()
  91. # 实例化优化器
  92. optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
  93. # 实例化学习率变化器
  94. lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)
  95. # 实例话一个训练器
  96. trainer = TextNASTrainer(model,
  97. loss=criterion,
  98. metrics=lambda output, target: {"acc": accuracy(output, target)},
  99. reward_function=accuracy,
  100. optimizer=optimizer,
  101. callbacks=[LRSchedulerCallback(lr_scheduler)],
  102. batch_size=args.batch_size,
  103. num_epochs=args.epochs,
  104. dataset_train=None,
  105. dataset_valid=None,
  106. train_loader=train_loader,
  107. valid_loader=valid_loader,
  108. test_loader=test_loader,
  109. log_frequency=args.log_frequency,
  110. mutator=mutator,
  111. mutator_lr=2e-3,
  112. mutator_steps=5,
  113. mutator_steps_aggregate=1,
  114. child_steps=50,
  115. baseline_decay=0.99,
  116. test_arc_per_epoch=10)
  117. logger.info(trainer.metrics)
  118. t1 = time.time()
  119. trainer.train()
  120. trainer.result["cost_time"] = time.time() - t1
  121. dump_global_result(args.result_path,trainer.result)
  122. # os.makedirs("checkpoints", exist_ok=True)
  123. # for i in range(2):
  124. # trainer.export(os.path.join("checkpoints", "architecture_%02d.json" % i))
  125. selected_model = trainer.export_child_model(selected_space = True)
  126. dump_global_result(args.selected_space_path,selected_model)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。