You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algos.h 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. /**
  2. * \file dnn/src/arm_common/conv_bias/int8/algos.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "src/arm_common/conv_bias/opr_impl.h"
  14. namespace megdnn {
  15. namespace arm_common {
  16. class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase {
  17. bool m_large_group;
  18. public:
  19. AlgoS8DirectStride1(bool large_group) : m_large_group(large_group) {}
  20. bool is_reproducible() const override { return true; }
  21. const char* name() const override {
  22. return m_large_group ? "S8STRD1_LARGE_GROUP" : "S8STRD1_SMALL_GROUP";
  23. }
  24. bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
  25. AlgoSelectionStrategy algo_selection_strategy) const override;
  26. size_t get_workspace(fallback::ConvBiasImpl*,
  27. const NCBKernSizeParam& param) const override;
  28. virtual SmallVector<NCBKern> dispatch_kerns(
  29. fallback::ConvBiasImpl* opr,
  30. const NCBKernSizeParam& param) const override;
  31. bool is_preferred(megdnn::fallback::ConvBiasImpl*,
  32. const NCBKernSizeParam& param) const override;
  33. };
  34. class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
  35. bool m_large_group;
  36. public:
  37. AlgoS8DirectStride2(bool large_group) : m_large_group(large_group) {}
  38. bool is_reproducible() const override { return true; }
  39. const char* name() const override {
  40. return m_large_group ? "S8STRD2_LARGE_GROUP" : "S8STRD2_SMALL_GROUP";
  41. }
  42. bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
  43. AlgoSelectionStrategy algo_selection_strategy) const override;
  44. size_t get_workspace(fallback::ConvBiasImpl*,
  45. const NCBKernSizeParam& param) const override;
  46. virtual SmallVector<NCBKern> dispatch_kerns(
  47. fallback::ConvBiasImpl* opr,
  48. const NCBKernSizeParam& param) const override;
  49. };
  50. class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
  51. public:
  52. AlgoS8DirectNCHW44() {}
  53. bool is_reproducible() const override { return true; }
  54. const char* name() const override { return "S8_NCHW44_DIRECT"; }
  55. bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
  56. AlgoSelectionStrategy algo_selection_strategy) const override;
  57. size_t get_workspace(fallback::ConvBiasImpl*,
  58. const NCBKernSizeParam& param) const override;
  59. virtual SmallVector<NCBKern> dispatch_kerns(
  60. fallback::ConvBiasImpl* opr,
  61. const NCBKernSizeParam& param) const override;
  62. bool is_preferred(megdnn::fallback::ConvBiasImpl*,
  63. const NCBKernSizeParam& param) const override;
  64. };
  65. class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
  66. public:
  67. AlgoS8DirectNCHWNCHW44() {}
  68. bool is_reproducible() const override { return true; }
  69. const char* name() const override { return "S8_CONV_NCHW_NCHW44"; }
  70. bool usable(fallback::ConvBiasImpl*, const NCBKernSizeParam& param,
  71. AlgoSelectionStrategy algo_selection_strategy) const override;
  72. size_t get_workspace(fallback::ConvBiasImpl*,
  73. const NCBKernSizeParam& param) const override;
  74. virtual SmallVector<NCBKern> dispatch_kerns(
  75. fallback::ConvBiasImpl* opr,
  76. const NCBKernSizeParam& param) const override;
  77. bool is_preferred(megdnn::fallback::ConvBiasImpl*,
  78. const NCBKernSizeParam& param) const override;
  79. };
  80. class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
  81. public:
  82. bool is_reproducible() const override { return true; }
  83. const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; }
  84. bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
  85. AlgoSelectionStrategy algo_selection_strategy) const override;
  86. size_t get_workspace(fallback::ConvBiasImpl*,
  87. const NCBKernSizeParam& param) const override;
  88. virtual SmallVector<NCBKern> dispatch_kerns(
  89. fallback::ConvBiasImpl* opr,
  90. const NCBKernSizeParam& param) const override;
  91. };
  92. class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
  93. public:
  94. bool is_reproducible() const override { return true; }
  95. const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; }
  96. bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
  97. AlgoSelectionStrategy algo_selection_strategy) const override;
  98. size_t get_workspace(fallback::ConvBiasImpl*,
  99. const NCBKernSizeParam& param) const override;
  100. virtual SmallVector<NCBKern> dispatch_kerns(
  101. fallback::ConvBiasImpl* opr,
  102. const NCBKernSizeParam& param) const override;
  103. };
  104. #if __ARM_FEATURE_DOTPROD
  105. class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
  106. public:
  107. bool is_reproducible() const override { return true; }
  108. const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; }
  109. bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
  110. AlgoSelectionStrategy algo_selection_strategy) const override;
  111. size_t get_workspace(FallbackConvBiasImpl*,
  112. const NCBKernSizeParam&) const override;
  113. virtual SmallVector<NCBKern> dispatch_kerns(
  114. fallback::ConvBiasImpl* opr,
  115. const NCBKernSizeParam& param) const override;
  116. };
  117. class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
  118. bool m_large_group;
  119. public:
  120. AlgoDotS8DirectStride1(bool large_group) : m_large_group(large_group) {}
  121. bool is_reproducible() const override { return true; }
  122. const char* name() const override {
  123. return m_large_group ? "ARMDOTS8STRD1_LARGE_GROUP"
  124. : "ARMDOTS8STRD1_SMALL_GROUP";
  125. }
  126. bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
  127. AlgoSelectionStrategy algo_selection_strategy) const override;
  128. size_t get_workspace(FallbackConvBiasImpl*,
  129. const NCBKernSizeParam&) const override;
  130. virtual SmallVector<NCBKern> dispatch_kerns(
  131. fallback::ConvBiasImpl* opr,
  132. const NCBKernSizeParam& param) const override;
  133. };
  134. class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
  135. bool m_large_group;
  136. public:
  137. AlgoDotS8DirectStride2(bool large_group) : m_large_group(large_group) {}
  138. bool is_reproducible() const override { return true; }
  139. const char* name() const override {
  140. return m_large_group ? "ARMDOTS8STRD2_LARGE_GROUP"
  141. : "ARMDOTS8STRD2_SMALL_GROUP";
  142. }
  143. bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
  144. AlgoSelectionStrategy algo_selection_strategy) const override;
  145. size_t get_workspace(FallbackConvBiasImpl*,
  146. const NCBKernSizeParam&) const override;
  147. virtual SmallVector<NCBKern> dispatch_kerns(
  148. fallback::ConvBiasImpl* opr,
  149. const NCBKernSizeParam& param) const override;
  150. };
  151. class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
  152. public:
  153. AlgoDotS8Direct_NCHW44() {}
  154. bool is_reproducible() const override { return true; }
  155. const char* name() const override {
  156. return "ARMDOTS8DIRECT_NCHW44";
  157. }
  158. bool usable(FallbackConvBiasImpl*, const NCBKernSizeParam&,
  159. AlgoSelectionStrategy algo_selection_strategy) const override;
  160. size_t get_workspace(FallbackConvBiasImpl*,
  161. const NCBKernSizeParam&) const override;
  162. SmallVector<NCBKern> dispatch_kerns(
  163. fallback::ConvBiasImpl* opr,
  164. const NCBKernSizeParam& param) const override;
  165. bool is_preferred(megdnn::fallback::ConvBiasImpl*,
  166. const NCBKernSizeParam& param) const override;
  167. };
  168. #endif
  169. class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase {
  170. public:
  171. AlgoS8WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
  172. uint32_t tile_size)
  173. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  174. const char* name() const override {
  175. if (m_name.empty()) {
  176. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  177. m_matmul_algo->name(), {8, 2, m_tile_size});
  178. }
  179. return m_name.c_str();
  180. }
  181. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
  182. };
  183. //=======================input int8 compute fp32 output int8============
  184. class ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44 final : public AlgoBase {
  185. public:
  186. AlgoS8CF32WinogradF23_4x4_NCHW44(
  187. fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size)
  188. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  189. const char* name() const override {
  190. if (m_name.empty()) {
  191. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  192. m_matmul_algo->name(), {4, 2, m_tile_size},
  193. param::ConvBias::Format::NCHW44);
  194. }
  195. return m_name.c_str();
  196. }
  197. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
  198. };
  199. //=======================input int8 compute int16 output int8============
  200. class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase {
  201. public:
  202. AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
  203. uint32_t tile_size)
  204. : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}
  205. const char* name() const override {
  206. if (m_name.empty()) {
  207. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  208. m_matmul_algo->name(), {8, 2, m_tile_size},
  209. param::ConvBias::Format::NCHW44);
  210. }
  211. return m_name.c_str();
  212. }
  213. MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
  214. };
  215. } // namespace arm_common
  216. } // namespace megdnn
  217. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。