You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

readme_spos.md 3.2 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Single Path One-Shot(SPOS)
  2. ## **简介**
  3. 该方法由[Single Path One-Shot Neural Architecture Search with Uniform Sampling](https://arxiv.org/abs/1904.00420)
  4. 中提出,主体思想可以分为两个部分,分别是Single Path和One-shot。其中One-Shot指,前期训练一个超网络,
  5. 后期对超网络不断进行采样或剪枝等等的方法来获得最终的子网络。而Single Path指,在对于训练好的超网络,每一个模型都是超网络的一条路径。
  6. 该算法整体来看即:将网络的层级结构视为一条路径,路径的节点即每个神经层,每个节点有多种选择(多种神经层),对每个节点进行采样得到一个确定的神经层,
  7. 并连接每个节点成为一个路径,该路径即最终采样得到的子网络。
  8. 本实例参照microsoft nni中的spos repo实现了spos的超网训练、子网络的进化搜索、最终选取网络的重训练。
  9. ## 使用介绍
  10. - 模型的训练用到了NVIDIA dali工具,需要提前[安装](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html)
  11. - 模型的训练使用imagenet数据集,需要提前准备
  12. - 模型的flops计算需要用到一个flops查找表,可以在[megvii](https://onedrive.live.com/?authkey=%21ADesvSdfsq%5FcN48&id=E7CA2ABE6D98E66F%21106&cid=E7CA2ABE6D98E66F)
  13. 下载。同时这里还可以下载到官方提供的supernet模型,以及最终重训练的模型等等。
  14. ### **目录结构**
  15. 可以将imagenet数据放在```./data```目录下,标准的数据处理方式可以参考[这里](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4)
  16. imagenet文件准备好之后,训练集和测试集应分别包含1000个子文件夹。
  17. 将文件准备齐全之后,目录结构应类似如下:
  18. ```
  19. spos
  20. ├── architecture_final.json
  21. ├── blocks.py
  22. ├── config_search.yml
  23. ├── data
  24. │ ├── imagenet
  25. │ │ ├── train
  26. │ │ └── val
  27. │ └── op_flops_dict.pkl
  28. ├── dataloader.py
  29. ├── network.py
  30. ├── readme.md
  31. ├── scratch.py
  32. ├── supernet.py
  33. ├── tester.py
  34. ├── evolution_tuner.py
  35. └── utils.py
  36. ```
  37. ### **超网络的训练**
  38. ```python supernet.py```
  39. - 如果不需要训练整个超网络,可以试用上述地址中下载的supernet网络,并将其放在```./data```目录下
  40. - 训练完成之后,checkpoint会到处在```./checkpoints```路径下
  41. — 为了和[官方repo](https://github.com/megvii-model/SinglePathOneShot) 保持一致,数据的通道使用BGR模式,同时数据的输入范围保持在[0,255].
  42. ### **子网络的进化搜索**
  43. 首先准备搜索空间
  44. ```python tester.py --mode gen```
  45. 然后进行基于进化算法的搜索
  46. ```python search.py```
  47. - 每次进化都会选出若干最优,其数目定义在dali_loader.py中,最终的准确率保存在```./acc```,路径下
  48. - 进化的模型结构(仅包含结构的json文件)保存在```./checkpoints```路径下
  49. - 模型结构的映射关系保存在```./id2cand```路径下
  50. ### **最终模型的重训练**
  51. ```python scartch.py```
  52. today

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能