You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

API_prune.md 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. 1. 通道剪枝算子
  2. =========
  3. ## 1.1 "bn"剪枝算子
  4. - `get_pruneThre_bn():`卷积层对应的BN层的gamma参数作为缩放因子,获得剪枝对应阈值
  5. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L120)
  6. - **返回**:剪枝对应的阈值
  7. - `get_removeIndex_bn(a, thre):`根据阈值获得当前卷积层需要剪枝的通道index
  8. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L182)
  9. - **参数**:
  10. - **a**:当前卷积层的参数
  11. - **thre**:`get_pruneThre_bn()`返回的阈值
  12. 1.2 "conv_avg"剪枝算子
  13. ---------
  14. - `get_pruneThre_conv_avg():`卷积层参数的平均值作为缩放因子,获得剪枝对应阈值
  15. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L54)
  16. - **返回**:剪枝对应的阈值
  17. - `get_removeIndex_conv_avg(a, shape, thre):`根据阈值获得当前卷积层需要剪枝的通道index
  18. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L187)
  19. - **参数**:
  20. - **a**:当前卷积层的参数
  21. - **shape**:当前卷积层的shape信息
  22. - **thre**:`get_pruneThre_conv_avg()`返回的阈值
  23. ## 1.3 "conv_max"剪枝算子
  24. - 同"conv_avg"剪枝算子
  25. ## 1.4 "conv_all"剪枝算子
  26. - 同"conv_avg"剪枝算子
  27. 1.5 "random"剪枝算子
  28. ---------
  29. - `get_removeIndex_conv_avg(shape):`随机选择需要剪枝的通道index
  30. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L220)
  31. - **参数**:
  32. - **shape**:当前卷积层的shape信息
  33. 1.6 "dnn"剪枝算子
  34. ---------
  35. - `get_pruneThre_fc():`全连接层的神经元的参数的平均值作为缩放因子,获得剪枝对应阈值
  36. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#137)
  37. - **返回**:剪枝对应的阈值
  38. - `get_removeIndex_fc(a, shape, thre):`根据阈值获得当前全连接层需要剪枝的神经元index
  39. - [源代码](../model_compress/ChannelSlimming/prune/util/prune_algorithm.py#L171)
  40. - **参数**:
  41. - **a**:当前全连接层的参数
  42. - **shape**:当前全连接层的shape信息
  43. - **thre**:`get_pruneThre_fc()`返回的阈值
  44. 2. 模型调用算子
  45. =========
  46. ## 2.1 pruneDnn.py
  47. - DNN模型剪枝,可调用1.6剪枝算子
  48. - [文件](../model_compress/ChannelSlimming/prune/pruneDnn.py)
  49. ## 2.2 pruneLenet.py
  50. - CNN模型的lenet模型剪枝,可调用1.1-1.5剪枝算子
  51. - [文件](../model_compress/ChannelSlimming/prune/pruneLenet.py)
  52. ## 2.3 pruneAlexnet.py
  53. - CNN模型的lenet模型剪枝,可调用1.1-1.5剪枝算子
  54. - [文件](../model_compress/ChannelSlimming/prune/pruneAlexnet.py)
  55. ## 2.4 pruneVggnet.py
  56. - CNN模型的lenet模型剪枝,可调用1.1-1.5剪枝算子
  57. - [文件](../model_compress/ChannelSlimming/prune/pruneVggnet.py)
  58. ## 2.5 pruneResnet.py
  59. - CNN模型的lenet模型剪枝,可调用1.1-1.5剪枝算子
  60. - [文件](../model_compress/ChannelSlimming/prune/pruneResnet.py)

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已在产学研等各领域近千家单位及个人提供AI应用赋能