You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_impl.h 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #pragma once
  2. #include <unordered_map>
  3. #include "megdnn/opr_param_defs.h"
  4. #include "megdnn/oprs/base.h"
  5. #include "src/common/algo_base.h"
  6. #include "src/common/utils.h"
  7. #include "src/naive/matrix_mul/opr_impl.h"
  8. namespace megdnn {
//! Packed (data type, tensor format) key used when selecting matmul algorithms.
//! Both fields are declared as 32-bit bitfields so the pair packs into a
//! fixed 64-bit layout; keep the widths unchanged (ABI-sensitive).
struct AlgoTypePack {
    detail::AlgoDataType data_type : 32;
    param::MatrixMul::Format format : 32;
};
  13. namespace fallback {
  14. class MatrixMulImpl : public naive::MatrixMulForwardImpl {
  15. public:
  16. using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
  17. using AlgoDataType = detail::AlgoDataType;
  18. bool is_thread_safe() const override { return true; }
  19. size_t get_workspace_in_bytes(
  20. const TensorLayout&, const TensorLayout&, const TensorLayout&) override;
  21. void exec(
  22. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  23. _megdnn_workspace workspace) override;
  24. struct KernSizeParam {
  25. DType A_type, B_type, C_type;
  26. size_t M, N, K;
  27. size_t LDA, LDB, LDC;
  28. bool trA, trB;
  29. Param::ComputeMode compute_mode;
  30. Param::Format format;
  31. //! get the data type category of the param for select the algo
  32. AlgoDataType deduce_algo_data_type() const;
  33. };
  34. struct KernParam : public KernSizeParam {
  35. RefPtr A_ptr;
  36. RefPtr B_ptr;
  37. RefPtr C_ptr;
  38. void* workspace_ptr = nullptr;
  39. size_t workspace_size = 0;
  40. template <typename T>
  41. inline const T* A() const {
  42. // A_type.assert_is_compatible_ctype<T>();
  43. return static_cast<const T*>(A_ptr.get_ptr());
  44. }
  45. template <typename T>
  46. inline const T* B() const {
  47. // B_type.assert_is_compatible_ctype<T>();
  48. return static_cast<const T*>(B_ptr.get_ptr());
  49. }
  50. template <typename T>
  51. inline T* C() const {
  52. // C_type.assert_is_compatible_ctype<T>();
  53. return static_cast<T*>(C_ptr.get_ptr());
  54. }
  55. template <typename T>
  56. inline T* workspace() const {
  57. return static_cast<T*>(workspace_ptr);
  58. }
  59. };
  60. typedef void (*kern_t)(const KernParam&);
  61. typedef void (*kern_naked_t)(
  62. const KernParam&, const void* a_panel, const void* b_panel);
  63. class AlgoBase : public Algorithm {
  64. protected:
  65. virtual ~AlgoBase() = default;
  66. bool can_be_treated_as_int8x8x32(const KernSizeParam& param) const {
  67. return param.A_type.enumv() == param.B_type.enumv() &&
  68. (param.A_type.enumv() == DTypeEnum::Int8 ||
  69. param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  70. (param.C_type.enumv() == DTypeEnum::Int32 ||
  71. param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  72. param.compute_mode == Param::ComputeMode::DEFAULT &&
  73. param.format == param::MatrixMul::Format::DEFAULT;
  74. }
  75. bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const {
  76. return param.A_type.enumv() == param.B_type.enumv() &&
  77. (param.A_type.enumv() == DTypeEnum::Int8 ||
  78. param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  79. (param.C_type.enumv() == DTypeEnum::Int16 ||
  80. param.C_type.enumv() == DTypeEnum::QuantizedS16);
  81. }
  82. public:
  83. AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; }
  84. enum class AlgoType : uint32_t {
  85. //! fallback
  86. FB_F32K8x12x1 = 1 << 0,
  87. FB_GEMV,
  88. FB_NAIVE,
  89. FB_GI_F32_GEMV_MK4,
  90. FB_GI_F32_MK4_4x8,
  91. #if MEGDNN_X86
  92. //! x86
  93. X86_F32_BLAS = 1 << 8,
  94. X86_F32_MKL_PACKA,
  95. X86_INT8X8X32_AVX2_2X4X16,
  96. X86_INT8X8X32_AVX2_4X16X2,
  97. X86_INT8X8X16_AVX2,
  98. X86_INT8X8X16_SSE,
  99. X86_INT8X8X32_SSE_4X8X2,
  100. X86_F32_MK8_8X8,
  101. X86_F32_6x16,
  102. X86_INT8X8X32_VNNI,
  103. X86_INT8X8X32_MKLDNN,
  104. #elif MEGDNN_AARCH64 || MEGDNN_ARMV7
  105. ARM_COMMON_INT8X8X16 = 1 << 8,
  106. ARM_COMMON_INT8X8X32_GEMV,
  107. ARM_COMMON_INT8X8X32_GEMV_MK4,
  108. ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
  109. ARM_COMMON_INT8X8X32_GEVM_DOT,
  110. ARM_COMMON_INT8X8X32_GEVM_N32K4_DOT,
  111. ARM_COMMON_F16_GEMV,
  112. ARM_COMMON_GEVM,
  113. #if MEGDNN_AARCH64
  114. AARCH64_F32_K8X12X1 = 1 << 16,
  115. AARCH64_F32_MK4_K8X12X1,
  116. AARCH64_F32_K4X16X1,
  117. AARCH64_F32_MK4_4x16,
  118. AARCH64_F32_GEMV,
  119. AARCH64_F16_K8X24X1,
  120. AARCH64_F16_MK8_8X8,
  121. AARCH64_INT8X8X32_K8X12X4_DOTPROD,
  122. AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD,
  123. AARCH64_INT8X8X32_MK4_4X4X16,
  124. AARCH64_INT8X8X32_K4X4X16,
  125. AARCH64_INT8X8X32_K8X8X8,
  126. AARCH64_INT8X8X16_K8X8X8,
  127. AARCH64_INT8X8X16_K4X4X16,
  128. AARCH64_INT8X8X16_MK4_16X12X4,
  129. AARCH64_INT8X8X16_MK4_K8X8X8,
  130. AARCH64_INT8X8X16_MK4_4X4X8,
  131. AARCH64_INT16X16X32_K12X8X1,
  132. AARCH64_INT16X16X32_MK8_8X8,
  133. AARCH64_QUINT8_K8X8X4_DOTPROD,
  134. AARCH64_QUINT8_GEMV_DOTPROD,
  135. AARCH64_QUINT8_K8X8X8,
  136. AARCH64_INT4X4X16_K8X8X8,
  137. #else
  138. ARMV7_F32 = 1 << 16,
  139. ARMV7_F32_MK4_PACK_4X12,
  140. ARMV7_F32_MK4_4x8,
  141. ARMV7_F16_K4X16X1,
  142. ARMV7_F16_MK8_4X8,
  143. ARMV7_INT8_K6X8X4,
  144. ARMV7_QUINT8_K4X8X4,
  145. ARMV7_INT8_MK4_8X4X4_DOTPROD,
  146. ARMV7_F32_GEMV,
  147. ARMV7_INT8X8X32_K4X2X16,
  148. ARMV7_INT8X8X32_K4X8X8,
  149. ARMV7_QUINT8_K4X8X8,
  150. ARMV7_INT8X8X16_K4X2X16,
  151. ARMV7_INT8X8X16_K4X8X8,
  152. ARMV7_INT8X8X16_MK4_K8X8X4,
  153. ARMV7_INT16X16X32_K12X4X1,
  154. ARMV7_INT16X16X32_MK8_4X8,
  155. ARMV7_INT8X8X32_MK4_4X2X16,
  156. ARMV7_INT8X8X16_K8X8X4
  157. #endif
  158. #endif
  159. };
  160. enum class AlgoSet : uint32_t {
  161. ALGO_TYPE_GEMM = 0,
  162. ALGO_TYPE_GEMV = 1,
  163. ALGO_TYPE_GEVM = 2,
  164. };
  165. enum class PackMode : uint32_t {
  166. DEFAULT = 0,
  167. NO_PACK = 1,
  168. ONLY_PACKA = 2,
  169. };
  170. struct InnerBlockSize {
  171. size_t m, n, k;
  172. };
  173. struct MatmulDescription {
  174. PackMode packmode;
  175. InnerBlockSize innerblocksize;
  176. AlgoTypePack algo_type;
  177. size_t packa_type_size;
  178. };
  179. virtual bool usable(const KernSizeParam&) const = 0;
  180. virtual bool preferred(const KernSizeParam&) const { return true; }
  181. virtual size_t get_workspace(const KernSizeParam&) const = 0;
  182. virtual kern_t get_kern(const KernSizeParam&) const = 0;
  183. virtual kern_naked_t get_kern_naked(const KernSizeParam&) const {
  184. megdnn_assert(0);
  185. };
  186. virtual AlgoSet algoset() const { return AlgoSet::ALGO_TYPE_GEMM; }
  187. virtual PackMode packmode() const { return PackMode::DEFAULT; }
  188. virtual void pack_A(const KernParam&, void*, size_t, size_t) const {
  189. megdnn_assert(0);
  190. };
  191. virtual void pack_B(const KernParam&, void*, size_t, size_t) const {
  192. megdnn_assert(0);
  193. };
  194. virtual WorkspaceBundle get_bundle(const KernSizeParam&) const {
  195. megdnn_assert(0);
  196. };
  197. virtual InnerBlockSize get_inner_block_size() const { megdnn_assert(0); };
  198. bool preferred_attribute(
  199. const KernSizeParam& param,
  200. const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
  201. const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
  202. return contain_attribute_all(positive_attr) &&
  203. !contain_attribute_any(negative_attr) && preferred(param);
  204. };
  205. virtual MatmulDescription matmul_description() const = 0;
  206. using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
  207. };
  208. private:
  209. class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
  210. class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
  211. class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
  212. class AlgoGemv;
  213. class AlgoNaive;
  214. class AlgoPack;
  215. //! maintain all the algos of in the opr of fallback
  216. static const AlgoPack& algo_pack();
  217. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  218. public:
  219. /**
  220. * \brief get all the algorithm for the opr.
  221. */
  222. virtual SmallVector<AlgoBase*> get_all_packed_algo();
  223. /**
  224. * \brief select algo according to input algo type
  225. */
  226. SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type);
  227. protected:
  228. KernSizeParam make_kern_size_param(
  229. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C);
  230. KernParam make_kern_param(
  231. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  232. _megdnn_workspace workspace);
  233. std::vector<Algorithm*> get_all_algorithms(
  234. const TensorLayout& A, const TensorLayout& B,
  235. const TensorLayout& C) override;
  236. std::vector<Algorithm*> get_all_algorithms_safe(
  237. const TensorLayout& A, const TensorLayout& B,
  238. const TensorLayout& C) override;
  239. Algorithm* get_algorithm_heuristic(
  240. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  241. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  242. const AlgoAttribute& negative_attr) override;
  243. };
  244. } // namespace fallback
  245. } // namespace megdnn
  246. // vim: syntax=cpp.doxygen