You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algos.h 12 kB


  1. #pragma once
  2. #include "src/arm_common/conv_bias/opr_impl.h"
  3. namespace megdnn {
  4. namespace arm_common {
  5. class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase {
  6. public:
  7. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  8. const char* name() const override { return "S8STRD1"; }
  9. bool usable(
  10. const NCBKernSizeParam& param,
  11. AlgoSelectionStrategy algo_selection_strategy) const override;
  12. size_t get_workspace(const NCBKernSizeParam& param) const override;
  13. virtual SmallVector<NCBKern> dispatch_kerns(
  14. const NCBKernSizeParam& param) const override;
  15. bool is_preferred(const NCBKernSizeParam& param) const override;
  16. ConvAlgoTypePack get_algo_type() const override {
  17. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  18. }
  19. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_S8)
  20. };
  21. class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
  22. public:
  23. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  24. const char* name() const override { return "S8STRD2"; }
  25. bool usable(
  26. const NCBKernSizeParam& param,
  27. AlgoSelectionStrategy algo_selection_strategy) const override;
  28. size_t get_workspace(const NCBKernSizeParam& param) const override;
  29. virtual SmallVector<NCBKern> dispatch_kerns(
  30. const NCBKernSizeParam& param) const override;
  31. ConvAlgoTypePack get_algo_type() const override {
  32. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  33. }
  34. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_S8)
  35. };
  36. class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
  37. public:
  38. AlgoS8DirectNCHW44() {}
  39. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  40. const char* name() const override { return "S8_NCHW44_DIRECT"; }
  41. bool usable(
  42. const NCBKernSizeParam& param,
  43. AlgoSelectionStrategy algo_selection_strategy) const override;
  44. size_t get_workspace(const NCBKernSizeParam& param) const override;
  45. virtual SmallVector<NCBKern> dispatch_kerns(
  46. const NCBKernSizeParam& param) const override;
  47. bool is_preferred(const NCBKernSizeParam& param) const override;
  48. ConvAlgoTypePack get_algo_type() const override {
  49. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  50. }
  51. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44)
  52. };
  53. class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
  54. public:
  55. AlgoS8DirectNCHWNCHW44() {}
  56. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  57. const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
  58. bool usable(
  59. const NCBKernSizeParam& param,
  60. AlgoSelectionStrategy algo_selection_strategy) const override;
  61. size_t get_workspace(const NCBKernSizeParam& param) const override;
  62. virtual SmallVector<NCBKern> dispatch_kerns(
  63. const NCBKernSizeParam& param) const override;
  64. bool is_preferred(const NCBKernSizeParam& param) const override;
  65. ConvAlgoTypePack get_algo_type() const override {
  66. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  67. }
  68. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_S8)
  69. };
  70. class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
  71. public:
  72. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  73. const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; }
  74. bool usable(
  75. const NCBKernSizeParam& param,
  76. AlgoSelectionStrategy algo_selection_strategy) const override;
  77. size_t get_workspace(const NCBKernSizeParam& param) const override;
  78. virtual SmallVector<NCBKern> dispatch_kerns(
  79. const NCBKernSizeParam& param) const override;
  80. ConvAlgoTypePack get_algo_type() const override {
  81. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  82. }
  83. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD1_NCHW44_S8)
  84. };
  85. class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
  86. public:
  87. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  88. const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; }
  89. bool usable(
  90. const NCBKernSizeParam& param,
  91. AlgoSelectionStrategy algo_selection_strategy) const override;
  92. size_t get_workspace(const NCBKernSizeParam& param) const override;
  93. virtual SmallVector<NCBKern> dispatch_kerns(
  94. const NCBKernSizeParam& param) const override;
  95. ConvAlgoTypePack get_algo_type() const override {
  96. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  97. }
  98. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8)
  99. };
  100. #if MGB_ENABLE_DOT
  101. class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
  102. public:
  103. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  104. const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
  105. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  106. const override;
  107. size_t get_workspace(const NCBKernSizeParam&) const override;
  108. virtual SmallVector<NCBKern> dispatch_kerns(
  109. const NCBKernSizeParam& param) const override;
  110. ConvAlgoTypePack get_algo_type() const override {
  111. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  112. }
  113. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW_NCHW44_DOT_S8)
  114. };
  115. class ConvBiasImpl::AlgoDotS8DirectChanWiseLarge final : public AlgoBase {
  116. public:
  117. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  118. const char* name() const override { return "ARMDOTS8_DIRECT_CHANWISE_LARGE"; }
  119. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  120. const override;
  121. size_t get_workspace(const NCBKernSizeParam&) const override;
  122. virtual SmallVector<NCBKern> dispatch_kerns(
  123. const NCBKernSizeParam& param) const override;
  124. ConvAlgoTypePack get_algo_type() const override {
  125. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  126. }
  127. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_DIRECT_CHANWISE_LARGE_S8)
  128. };
  129. class ConvBiasImpl::AlgoDotS8Im2colChanWiseLarge final : public AlgoBase {
  130. public:
  131. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  132. const char* name() const override { return "ARMDOTS8_IM2COL_CHANWISE_LARGE"; }
  133. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  134. const override;
  135. size_t get_workspace(const NCBKernSizeParam&) const override;
  136. virtual SmallVector<NCBKern> dispatch_kerns(
  137. const NCBKernSizeParam& param) const override;
  138. ConvAlgoTypePack get_algo_type() const override {
  139. return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
  140. }
  141. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DOT_IM2COL_CHANWISE_LARGE_S8)
  142. };
  143. class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
  144. public:
  145. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  146. const char* name() const override { return "ARMDOTS8STRD1"; }
  147. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  148. const override;
  149. size_t get_workspace(const NCBKernSizeParam&) const override;
  150. virtual SmallVector<NCBKern> dispatch_kerns(
  151. const NCBKernSizeParam& param) const override;
  152. ConvAlgoTypePack get_algo_type() const override {
  153. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  154. }
  155. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_S8)
  156. };
  157. class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
  158. public:
  159. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  160. const char* name() const override { return "ARMDOTS8STRD2"; }
  161. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  162. const override;
  163. size_t get_workspace(const NCBKernSizeParam&) const override;
  164. virtual SmallVector<NCBKern> dispatch_kerns(
  165. const NCBKernSizeParam& param) const override;
  166. ConvAlgoTypePack get_algo_type() const override {
  167. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  168. }
  169. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_S8)
  170. };
  171. class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
  172. public:
  173. AlgoDotS8Direct_NCHW44() {}
  174. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  175. const char* name() const override { return "ARMDOTS8DIRECT_NCHW44"; }
  176. bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy)
  177. const override;
  178. size_t get_workspace(const NCBKernSizeParam&) const override;
  179. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override;
  180. bool is_preferred(const NCBKernSizeParam& param) const override;
  181. ConvAlgoTypePack get_algo_type() const override {
  182. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  183. }
  184. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_NCHW44_DOT_S8)
  185. };
  186. #endif
  187. class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase {
  188. public:
  189. AlgoS8WinogradF23_8x8(
  190. fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
  191. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  192. const char* name() const override {
  193. if (m_name.empty()) {
  194. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  195. m_matmul_algo->name(), {8, 2, m_tile_size, 3});
  196. }
  197. return m_name.c_str();
  198. }
  199. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  200. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
  201. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8)
  202. };
  203. //=======================input int8 compute fp32 output int8============
  204. class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase {
  205. public:
  206. AlgoS8CF32WinogradF23_4x4_NCHW44(
  207. fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
  208. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  209. const char* name() const override {
  210. if (m_name.empty()) {
  211. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  212. m_matmul_algo->name(), {4, 2, m_tile_size, 3},
  213. param::ConvBias::Format::NCHW44);
  214. }
  215. return m_name.c_str();
  216. }
  217. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  218. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
  219. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32)
  220. };
  221. //=======================input int8 compute int16 output int8============
  222. class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase {
  223. public:
  224. AlgoS8WinogradF23_8x8_NCHW44(
  225. fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
  226. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  227. const char* name() const override {
  228. if (m_name.empty()) {
  229. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  230. m_matmul_algo->name(), {8, 2, m_tile_size, 3},
  231. param::ConvBias::Format::NCHW44);
  232. }
  233. return m_name.c_str();
  234. }
  235. AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
  236. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
  237. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8)
  238. };
  239. } // namespace arm_common
  240. } // namespace megdnn
  241. // vim: syntax=cpp.doxygen