You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README.md 1.2 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Network Morphism
  2. The implementation of the Network Morphism algorithm is based on
  3. [Auto-Keras: An Efficient Neural Architecture Search System](https://arxiv.org/pdf/1806.10282.pdf)
  4. Train stage
  5. ```
  6. python network_morphism_train.py
  7. --trial_id 0
  8. --experiment_dir 'tadl'
  9. --log_path 'tadl/train/0/log'
  10. --data_dir '../data/'
  11. --result_path 'trial_id/result.json'
  12. --log_path 'trial_id/log'
  13. --search_space_path 'experiment_id/search_space.json'
  14. --best_selected_space_path 'experiment_id/best_selected_space.json'
  15. --lr 0.001 --epochs 100 --batch_size 32 --opt 'SGD'
  16. ```
  17. select stage
  18. ```
  19. python network_morphism_select.py
  20. ```
  21. retrain stage
  22. ```
  23. python network_morphism_retrain.py
  24. --data_dir '../data/'
  25. --experiment_dir 'tadl'
  26. --result_path 'trial_id/result.json'
  27. --log_path 'trial_id/log'
  28. --best_selected_space_path 'experiment_id/best_selected_space.json'
  29. --best_checkpoint_dir 'experiment_id/'
  30. --trial_id 0 --batch_size 32 --opt 'SGD' --epochs 100 --lr 0.001
  31. ```
  32. The best model searched achieved 88.1% on CIFAR-10 dataset after 100 trials.
  33. Dependencies:
  34. ```
  35. Python = 3.6.13
  36. pytorch = 1.8.0
  37. torchvision = 0.9.0
  38. scipy = 1.5.2
  39. scikit-learn = 0.24.1
  40. ```

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能