You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pdarts_train.py 4.2 kB

2 years ago
(line numbers 1–79)
  1. import sys
  2. sys.path.append('..'+ '/' + '..')
  3. import time
  4. import logging
  5. from argparse import ArgumentParser
  6. from pdartstrainer import PdartsTrainer
  7. from pytorch.utils import mkdirs, set_seed, init_logger, list_str2int
  8. logger = logging.getLogger(__name__)
  9. if __name__ == "__main__":
  10. parser = ArgumentParser("pdarts")
  11. parser.add_argument("--data_dir", type=str,
  12. default='../data/', help="search_space json file")
  13. parser.add_argument("--result_path", type=str,
  14. default='0/result.json', help="training result")
  15. parser.add_argument("--log_path", type=str,
  16. default='0/log', help="log for info")
  17. parser.add_argument("--search_space_path", type=str,
  18. default='./search_space.json', help="search space of PDARTS")
  19. parser.add_argument("--best_selected_space_path", type=str,
  20. default='./best_selected_space.json', help="final best selected space")
  21. parser.add_argument('--trial_id', type=int, default=0, help='for ensuring reproducibility ')
  22. parser.add_argument('--model_lr', type=float, default=0.025, help='learning rate for training model weights')
  23. parser.add_argument('--arch_lr', type=float, default=3e-4, help='learning rate for training architecture')
  24. parser.add_argument("--epochs", default=2, type=int)
  25. parser.add_argument("--pre_epochs", default=15, type=int)
  26. parser.add_argument("--batch_size", default=96, type=int)
  27. parser.add_argument("--init_layers", default=5, type=int)
  28. parser.add_argument('--add_layers', default=[0, 6, 12], nargs='+', type=int, help='add layers in each stage')
  29. parser.add_argument('--dropped_ops', default=[3, 2, 1], nargs='+', type=int, help='drop ops in each stage')
  30. parser.add_argument('--dropout_rates', default=[0.1, 0.4, 0.7], nargs='+', type=float, help='drop ops probability in each stage')
  31. # parser.add_argument('--add_layers', action='append', help='add layers in each stage')
  32. # parser.add_argument('--dropped_ops', action='append', help='drop ops in each stage')
  33. # parser.add_argument('--dropout_rates', action='append', help='drop ops probability in each stage')
  34. parser.add_argument("--channels", default=16, type=int)
  35. parser.add_argument("--log_frequency", default=50, type=int)
  36. parser.add_argument("--class_num", default=10, type=int)
  37. parser.add_argument("--unrolled", default=False, action="store_true")
  38. args = parser.parse_args()
  39. mkdirs(args.result_path, args.log_path, args.search_space_path, args.best_selected_space_path)
  40. init_logger(args.log_path, "info")
  41. set_seed(args.trial_id)
  42. # args.add_layers = list_str2int(args.add_layers)
  43. # args.dropped_ops = list_str2int(args.dropped_ops)
  44. # args.dropout_rates = list_str2int(args.dropout_rates)
  45. logger.info(args)
  46. logger.info("initializing pdarts trainer")
  47. trainer = PdartsTrainer(
  48. init_layers=args.init_layers,
  49. pdarts_num_layers=args.add_layers,
  50. pdarts_num_to_drop=args.dropped_ops,
  51. pdarts_dropout_rates=args.dropout_rates,
  52. num_epochs=args.epochs,
  53. num_pre_epochs=args.pre_epochs,
  54. model_lr=args.model_lr,
  55. arch_lr=args.arch_lr,
  56. batch_size=args.batch_size,
  57. class_num=args.class_num,
  58. channels=args.channels,
  59. result_path=args.result_path,
  60. log_frequency=args.log_frequency,
  61. unrolled=args.unrolled,
  62. data_dir = args.data_dir,
  63. search_space_path=args.search_space_path,
  64. best_selected_space_path=args.best_selected_space_path
  65. )
  66. logger.info("training")
  67. start_time = time.time()
  68. trainer.train(validate=True)
  69. # result = trainer.result
  70. cost_time = time.time() - start_time
  71. # 后端在终端过滤,{"type": "Cost_time", "result": {"value": "* s"}}
  72. logger.info({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}})
  73. with open(args.result_path, "a") as file:
  74. file.write(str({"type": "Cost_time", "result": {"value": str(cost_time) + ' s'}}))

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具，在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势，目前已为产学研等各领域近千家单位及个人提供AI应用赋能。