You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

selector.py 2.1 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import sys
  2. sys.path.append('../..')
  3. from pytorch.selector import Selector
  4. from pytorch.utils import mkdirs
  5. import shutil
  6. import argparse
  7. import os
  8. import json
  class ClassicnasSelector(Selector):
      """Pick the best classic-NAS trial and export its selected search space.

      ``args`` is a dict of options (the ``__main__`` block builds it with
      ``vars(get_params())``); it must provide the keys ``'experiment_dir'``
      and ``'best_selected_space_path'``.
      """

      def __init__(self, args, single_candidate=True):
          # single_candidate is forwarded unchanged to the base Selector;
          # NOTE(review): its meaning is defined there, not visible here.
          super().__init__(single_candidate)
          self.args = args

      def fit(self):
          """Select the trial with the highest reported accuracy and export it.

          Scans ``<experiment_dir>/train/<trial_id>/result/result.json`` (one
          JSON object per line, each carrying ``["result"]["value"]``), takes
          each trial's maximum value, and copies the winning trial's
          ``model_selected_space/model_selected_space.json`` to
          ``args['best_selected_space_path']``. On ties the trial whose id sorts
          first wins.

          Raises:
              FileNotFoundError: if ``<experiment_dir>/train`` is missing or no
                  trial in it reported any result value.
          """
          train_dir = os.path.join(self.args['experiment_dir'], 'train')
          # Start below any real metric so trials reporting <= 0 are still eligible.
          best_accuracy = float('-inf')
          best_trial_id = None
          # os.listdir order is arbitrary; sort so tie-breaking is deterministic.
          for trial_id in sorted(os.listdir(train_dir)):
              result_path = os.path.join(train_dir, trial_id, 'result', 'result.json')
              if not os.path.isfile(result_path):
                  # Stray files, or trials that never wrote a result file.
                  continue
              trial_accuracy = self._max_trial_accuracy(result_path)
              print(trial_accuracy)
              if trial_accuracy > best_accuracy:
                  best_accuracy = trial_accuracy
                  best_trial_id = trial_id

          if best_trial_id is None:
              # Fail with a clear message instead of copying from an empty path.
              raise FileNotFoundError(f"no trial results found under {train_dir!r}")

          print('best trial id:', best_trial_id)
          best_selected_space = os.path.join(
              train_dir, best_trial_id, 'model_selected_space', 'model_selected_space.json')
          shutil.copyfile(best_selected_space, self.args['best_selected_space_path'])

      @staticmethod
      def _max_trial_accuracy(result_path):
          """Return the largest ``["result"]["value"]`` in a JSON-lines file (``-inf`` if none)."""
          best = float('-inf')
          with open(result_path, 'r', encoding='utf-8') as f:
              for line in f:
                  if not line.strip():
                      # Tolerate blank lines (e.g. a trailing newline).
                      continue
                  best = max(best, json.loads(line)["result"]["value"])
          return best
  def get_params(argv=None):
      """Parse the command-line options for the classic-NAS selector.

      Args:
          argv: Optional list of argument strings to parse. ``None`` (the
              default) makes argparse read ``sys.argv[1:]``, as before.

      Returns:
          argparse.Namespace with ``experiment_dir`` and
          ``best_selected_space_path`` attributes.
      """
      parser = argparse.ArgumentParser(
          description='Select the best classic NAS trial and export its selected space')
      parser.add_argument("--experiment_dir", type=str,
                          default='./experiment_dir',
                          help="experiment directory; trials are read from its 'train' subdirectory")
      parser.add_argument("--best_selected_space_path", type=str,
                          default='./best_selected_space.json',
                          help="output path for the best trial's selected space JSON")
      # parse_known_args ignores unrecognised options instead of erroring
      # (presumably because a launcher passes extra flags — confirm).
      args, _ = parser.parse_known_args(argv)
      return args
  if __name__ == "__main__":
      # The selector reads its options by key, so hand it a plain dict.
      cli_args = vars(get_params())
      # NOTE(review): mkdirs is imported from pytorch.utils; presumably it
      # ensures the output file's parent directory exists — confirm.
      mkdirs(cli_args['best_selected_space_path'])
      ClassicnasSelector(cli_args, single_candidate=False).fit()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。