You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

evolve.py 3.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # import os
  2. # debugger_path = os.path.abspath("./")
  3. # os.chdir(debugger_path)
  4. import json
  5. import argparse
  6. from evolution_tuner import EvolutionWithFlops
  7. from evaluator import Evaluator
  8. import sys
  9. sys.path.append("../")
  10. sys.path.append("../../")
  11. def load_search_space(path="./auto_gen_search_space.json"):
  12. with open(path) as f:
  13. search_space = json.load(f)
  14. return search_space
  15. def trial(args, trial_id, search_space):
  16. """
  17. search the best model by evoluationary algo
  18. """
  19. evolution_spos = EvolutionWithFlops(max_epochs=args.max_epoches,
  20. num_select=args.num_select,
  21. num_population=args.num_population,
  22. m_prob=args.m_prob,
  23. num_crossover=args.num_crossover,
  24. num_mutation=args.num_mutation,
  25. epoch=trial_id,
  26. )
  27. evolution_spos.update_search_space(search_space)
  28. if __name__ == "__main__":
  29. parser = argparse.ArgumentParser("search the net by evolution")
  30. parser.add_argument("--search_space_path", type=str, default="./auto_gen_search_space.json")
  31. parser.add_argument("--checkpoint", type=str, default="./data/checkpoint-150000.pth.tar") # ./data/checkpoint-150000.pth.tar
  32. parser.add_argument("--num_select", type=int, default=2) # 10
  33. parser.add_argument("--num_population", type=int, default=4) # 50
  34. parser.add_argument("--workers", type=int, default=1) # 线程数
  35. parser.add_argument("--num_crossover", type=int, default=2) # 25
  36. parser.add_argument("--num_mutation", type=int, default=2) # 25
  37. parser.add_argument("--max_epoches", type=int, default=3) # 20
  38. parser.add_argument("--trial_id", type=int, default=1)
  39. parser.add_argument("--m_prob", type=float, default=0.1)
  40. parser.add_argument("--imagenet-dir", type=str, default="/mnt/local/hanjiayi/imagenet") # ./data/imagenet
  41. parser.add_argument("--spos-preprocessing", default=True,
  42. help="When true, image values will range from 0 to 255 and use BGR "
  43. "(as in original repo).")
  44. parser.add_argument("--seed", type=int, default=42)
  45. parser.add_argument("--train-batch-size", type=int, default=128)
  46. parser.add_argument("--train-iters", type=int, default=200)
  47. parser.add_argument("--test-batch-size", type=int, default=512) # nni中为512,官方repo为200
  48. parser.add_argument("--log-frequency", type=int, default=10)
  49. parser.add_argument("--architecture", type=str, default="./architecture_final.json", help="load the file to retrain or eval")
  50. args = parser.parse_args()
  51. search_space = load_search_space(path=args.search_space_path)
  52. # evl = Evaluator()
  53. # if args.single_trial:
  54. # epoch = 0
  55. # print("*" * 50, "\n")
  56. # print("epoch {}{}".format(epoch, "\n"))
  57. # print("*" * 50, "\n")
  58. # trial(args, epoch, search_space=search_space)
  59. # else:
  60. # for epoch in range(2, args.max_epoches+2):
  61. # print("*"*50, "\n")
  62. # print("epoch {}{}".format(epoch, "\n"))
  63. # print("*"*50, "\n")
  64. # trial(args, epoch, search_space=search_space)
  65. trial(args, args.trial_id, search_space)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能