You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algos.cpp 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. /**
  2. * \file dnn/src/x86/matrix_mul/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/utils.h"
  13. #include "src/fallback/matrix_mul/gemm_impl.h"
  14. #include "src/x86/matrix_mul/algos.h"
  15. #include "src/x86/matrix_mul/f32/strategy.h"
  16. #include "src/x86/matrix_mul/int8/strategy.h"
  17. #include "midout.h"
  18. MIDOUT_DECL(megdnn_x86_matmul_kern)
  19. MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
  20. MIDOUT_DECL(megdnn_x86_matmul_kern_mkldnn)
  21. using namespace megdnn;
  22. using namespace x86;
  23. /* ===================== F32 Blas algo ===================== */
namespace {
//! Plain float32 GEMM forwarded to the linked BLAS (MKL or OpenBLAS):
//! C = A * B with alpha = 1, beta = 0 in row-major layout.
void f32_blas_kern(const MatrixMulImpl::KernParam& kern_param) {
#if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
    auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
    bool trA = kern_param.trA, trB = kern_param.trB;
    const auto Aptr = kern_param.A<dt_float32>(),
               Bptr = kern_param.B<dt_float32>();
    auto Cptr = kern_param.C<dt_float32>();
    auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
    // avoid the slow denormal floating-point path before calling BLAS
    disable_denorm();
    cblas_sgemm(CblasRowMajor, trA ? CblasTrans : CblasNoTrans,
                trB ? CblasTrans : CblasNoTrans, m, n, k, 1.0f, Aptr, Atrd,
                Bptr, Btrd, 0.0f, Cptr, Ctrd);
#else
    megdnn_throw("a blas library is required");
#endif
}
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
//! GEMM variant consuming an A panel already packed by cblas_sgemm_pack();
//! b_panel is unused — B is read in its plain row-major layout.
void f32_blas_kern_only_packA(const MatrixMulImpl::KernParam& kern_param,
                              const void* a_panel, const void* b_panel) {
    MEGDNN_MARK_USED_VAR(b_panel);
    auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
    const auto Bptr = kern_param.B<dt_float32>();
    auto Cptr = kern_param.C<dt_float32>();
    auto Atrd = kern_param.LDA, Btrd = kern_param.LDB, Ctrd = kern_param.LDC;
    disable_denorm();
    cblas_sgemm_compute(CblasRowMajor, CblasPacked, CblasNoTrans, m, n, k,
                        static_cast<const float*>(a_panel), Atrd, Bptr, Btrd,
                        0.0f, Cptr, Ctrd);
}
#endif
}  // anonymous namespace
  56. bool MatrixMulImpl::AlgoF32Blas::usable(
  57. const KernSizeParam& kern_size_param) const {
  58. #if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
  59. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  60. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  61. kern_size_param.B_type == kern_size_param.A_type &&
  62. kern_size_param.C_type == kern_size_param.A_type &&
  63. kern_size_param.A_type == dtype::Float32() &&
  64. preferred(kern_size_param);
  65. #else
  66. return false;
  67. #endif
  68. }
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Blas::get_kern(
        const KernSizeParam&) const {
    //! size-independent: always dispatch to the plain BLAS sgemm wrapper
    return f32_blas_kern;
}
  73. /* ===================== AlgoF32BlasPackA====================== */
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
//! Packed-A MKL algo: same applicability constraints as AlgoF32Blas.
bool MatrixMulImpl::AlgoF32MKLPackA::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.B_type == kern_size_param.A_type &&
           kern_size_param.C_type == kern_size_param.A_type &&
           kern_size_param.A_type == dtype::Float32() &&
           preferred(kern_size_param);
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MKLPackA::get_kern(
        const KernSizeParam&) const {
    // plain (non-packed) fallback kernel
    return f32_blas_kern;
}
MatrixMulImpl::kern_naked_t MatrixMulImpl::AlgoF32MKLPackA::get_kern_naked(
        const KernSizeParam&) const {
    // kernel used when the caller supplies a pre-packed A panel
    return f32_blas_kern_only_packA;
}
WorkspaceBundle MatrixMulImpl::AlgoF32MKLPackA::get_bundle(
        const KernSizeParam& param) const {
    auto M = param.M;
    auto N = param.N;
    auto K = param.K;
    // only A is packed, so the bundle carries a single MKL-sized buffer
    size_t a_size = cblas_sgemm_pack_get_size(CblasAMatrix, M, N, K);
    return {nullptr, {a_size, 0, 0}};
}
//! Pack A into `out` with MKL's internal layout; index/stride are unused
//! because the whole matrix is packed in one call.
void MatrixMulImpl::AlgoF32MKLPackA::pack_A(const KernParam& kern_param,
                                            void* out, size_t index,
                                            size_t stride) const {
    MEGDNN_MARK_USED_VAR(stride);
    MEGDNN_MARK_USED_VAR(index);
    auto m = kern_param.M, n = kern_param.N, k = kern_param.K;
    const auto Aptr = kern_param.A<dt_float32>();
    auto Atrd = kern_param.LDA;
    disable_denorm();
    cblas_sgemm_pack(CblasRowMajor, CblasAMatrix, CblasNoTrans, m, n, k, 1.0f,
                     Aptr, Atrd, static_cast<float*>(out));
}
#endif
  113. /* ===================== Int8 Vnni algo ===================== */
#if MEGDNN_X86_WITH_VNNI
#define ALIGN_SIZE 64
namespace {
//! int8 x int8 -> int32 GEMM using the 12x32x4 VNNI strategy through the
//! generic GemmInterleaved driver; panels aligned to ALIGN_SIZE bytes.
void int8x8x32_kern_vnni(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_vnni, midout_iv(0)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        x86::matmul::gemm_int8_vnni_12x32x4 strategy(M, N, K, A_type, B_type,
                                                     C_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_int8_vnni_12x32x4>(
                M, N, K, trA, trB, strategy, ALIGN_SIZE)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
//! workspace needed by int8x8x32_kern_vnni for the same size parameters;
//! must mirror the GemmInterleaved construction above exactly
size_t get_kern_workspace(MatrixMulImpl::KernSizeParam kern_size_param) {
    auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
    auto trA = kern_size_param.trA, trB = kern_size_param.trB;
    auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
         C_type = kern_size_param.C_type;
    x86::matmul::gemm_int8_vnni_12x32x4 strategy(M, N, K, A_type, B_type,
                                                 C_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_int8_vnni_12x32x4>(
                   M, N, K, trA, trB, strategy, ALIGN_SIZE)
            .get_workspace_size();
}
}  // namespace
  149. bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
  150. const KernSizeParam& kern_size_param) const {
  151. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  152. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  153. kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
  154. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  155. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
  156. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  157. kern_size_param.format == Param::Format::DEFAULT &&
  158. preferred(kern_size_param) && is_supported(SIMDType::VNNI);
  159. }
size_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_workspace(
        const KernSizeParam& kern_size_param) const {
    // delegate to the helper shared with the kernel in the anonymous namespace
    return get_kern_workspace(kern_size_param);
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_kern_vnni;
}
  168. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
  169. AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
  170. x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32,
  171. dt_uint8AlgoDataType::QINT8X8X32, DEFAULT);
  172. #endif
  173. /* ===================== Int8 mkldnn algo ===================== */
#if MEGDNN_X86_WITH_MKL_DNN
namespace {
//! int8 x int8 -> int32 GEMM forwarded to MKL-DNN's mkldnn_gemm_s8s8s32.
//! Zero points (ao/bo) are fixed to 0; offsetC = 'F' applies the single
//! fixed offset `co` (also 0) to the whole C matrix.
void int8x8x32_kern_mkldnn(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_mkldnn, midout_iv(0)) {
        const char transA = kern_param.trA ? 'T' : 'N';
        const char transB = kern_param.trB ? 'T' : 'N';
        const char offsetC = 'F';
        const int64_t M = static_cast<int64_t>(kern_param.M);
        const int64_t N = static_cast<int64_t>(kern_param.N);
        const int64_t K = static_cast<int64_t>(kern_param.K);
        const int64_t LDA = static_cast<int64_t>(kern_param.LDA);
        const int64_t LDB = static_cast<int64_t>(kern_param.LDB);
        const int64_t LDC = static_cast<int64_t>(kern_param.LDC);
        const float alpha = 1.0f, beta = 0.0f;
        const int8_t ao = 0, bo = 0;
        const int32_t co = 0;
        const int8_t* A_ptr = static_cast<const int8_t*>(kern_param.A_ptr);
        const int8_t* B_ptr = static_cast<const int8_t*>(kern_param.B_ptr);
        int32_t* C_ptr = static_cast<int32_t*>(kern_param.C_ptr);
        auto status = mkldnn_gemm_s8s8s32(transA, transB, offsetC, M, N, K,
                                          alpha, A_ptr, LDA, ao, B_ptr, LDB, bo,
                                          beta, C_ptr, LDC, &co);
        megdnn_assert(status == mkldnn_success,
                      "mkldnn_gemm_s8s8s32 compute error!!!");
    }
    MIDOUT_END();
}
}  // namespace
  203. bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable(
  204. const KernSizeParam& kern_size_param) const {
  205. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  206. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  207. kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
  208. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  209. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
  210. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  211. kern_size_param.format == Param::Format::DEFAULT &&
  212. is_supported(SIMDType::VNNI) && preferred(kern_size_param);
  213. }
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Mkldnn::get_kern(
        const KernSizeParam&) const {
    //! size-independent dispatch to the MKL-DNN wrapper above
    return int8x8x32_kern_mkldnn;
}
#endif
namespace {
//! int8 x int8 -> int32 GEMM with the 2x4x16 AVX2 strategy, driven by the
//! generic GemmInterleaved helper (64-byte cacheline alignment).
void gemm_s8s8s32_avx2_2x4x16(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_2x4x16, midout_iv(0)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<dt_int8>();
        const auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int32>();
        x86::matmul::gemm_avx2_s8s8s32_2x4x16 strategy(m, n, k, a_type, b_type,
                                                       c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s32_2x4x16>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
  247. void gemm_s8s8s32_avx2_4x16x2(const MatrixMulImpl::KernParam& kern_param) {
  248. MEGDNN_MARK_USED_VAR(kern_param);
  249. MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_4x16x2, midout_iv(0)) {
  250. constexpr int cacheline = 64;
  251. const size_t m = kern_param.M;
  252. const size_t n = kern_param.N;
  253. const size_t k = kern_param.K;
  254. const bool trans_a = kern_param.trA;
  255. const bool trans_b = kern_param.trB;
  256. const size_t lda = kern_param.LDA;
  257. const size_t ldb = kern_param.LDB;
  258. const size_t ldc = kern_param.LDC;
  259. auto a_type = kern_param.A_type;
  260. auto b_type = kern_param.B_type;
  261. auto c_type = kern_param.C_type;
  262. const auto a_ptr = kern_param.A<dt_int8>();
  263. const auto b_ptr = kern_param.B<dt_int8>();
  264. auto c_ptr = kern_param.C<dt_int32>();
  265. x86::matmul::gemm_avx2_s8s8s32_4x16x2 strategy(m, n, k, a_type, b_type,
  266. c_type);
  267. megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s32_4x16x2>(
  268. m, n, k, trans_a, trans_b, strategy, cacheline)
  269. .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
  270. kern_param.workspace_ptr);
  271. }
  272. MIDOUT_END();
  273. }
//! int8 x int8 -> int32 GEMM with the 4x8x2 SSE strategy; arguments are
//! forwarded directly from kern_param without local aliasing.
void gemm_s8s8s32_sse_4x8x2(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(0)) {
        constexpr int cacheline = 64;
        x86::matmul::gemm_sse_s8s8s32_4x8x2 strategy(
                kern_param.M, kern_param.N, kern_param.K, kern_param.A_type,
                kern_param.B_type, kern_param.C_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s32_4x8x2>(
                kern_param.M, kern_param.N, kern_param.K, kern_param.trA,
                kern_param.trB, strategy, cacheline)
                .execute(kern_param.A<dt_int8>(), kern_param.LDA,
                         kern_param.B<dt_int8>(), kern_param.LDB,
                         kern_param.C<dt_int32>(), kern_param.LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
//! float32 GEMM with the packed 6x16 AVX2 strategy
void gemm_f32_avx2_6x16(const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_6x16x2, midout_iv(0)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<float>();
        const auto b_ptr = kern_param.B<float>();
        auto c_ptr = kern_param.C<float>();
        x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
                                                   c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_pack_6x16_avx2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // namespace
  319. /*************************AlgoInt8x8x16AVX2********************/
//! int8 x int8 -> int16 GEMM via the 4x16x2 AVX2 strategy; member kernel
//! returned by AlgoInt8x8x16AVX2::get_kern below.
void MatrixMulImpl::AlgoInt8x8x16AVX2::gemm_s8s8s16_avx2_4x16x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_avx2_4x16x2, midout_iv(1)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<dt_int8>();
        const auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int16>();
        x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
                                                       c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s16_avx2_4x16x2;
}
  352. bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
  353. const KernSizeParam& kern_size_param) const {
  354. bool is_ab_same =
  355. kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
  356. bool is_type_ok =
  357. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  358. kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
  359. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  360. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
  361. bool is_mode_ok =
  362. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  363. kern_size_param.format == Param::Format::DEFAULT &&
  364. is_supported(SIMDType::AVX2);
  365. bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
  366. return is_param_ok;
  367. }
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
    // always preferred once usable() passes
    return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_s8s8s16_avx2_4x16x2 so the
    // reported size matches what execute() consumes
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_avx2_s8s8s16_4x16x2 strategy(m, n, k, a_type, b_type,
                                                   c_type);
    return megdnn::matmul::GemmInterleaved<
                   x86::matmul::gemm_avx2_s8s8s16_4x16x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (AVX2, s8s8s16)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16,
        AlgoDataType::INT8X8X16, DEFAULT);
  393. /*************************AlgoInt8x8x16SSE********************/
//! int8 x int8 -> int16 GEMM via the 4x8x2 SSE strategy; member kernel
//! returned by AlgoInt8x8x16SSE::get_kern below.
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
        const MatrixMulImpl::KernParam& kern_param) {
    MEGDNN_MARK_USED_VAR(kern_param);
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_sse_4x8x2, midout_iv(2)) {
        constexpr int cacheline = 64;
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        const size_t lda = kern_param.LDA;
        const size_t ldb = kern_param.LDB;
        const size_t ldc = kern_param.LDC;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        const auto a_ptr = kern_param.A<dt_int8>();
        const auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int16>();
        x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
                                                     c_type);
        megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                m, n, k, trans_a, trans_b, strategy, cacheline)
                .execute(a_ptr, lda, b_ptr, ldb, c_ptr, ldc,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16SSE::get_kern(
        const KernSizeParam&) const {
    return gemm_s8s8s16_sse_4x8x2;
}
  426. bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
  427. const KernSizeParam& kern_size_param) const {
  428. bool is_ab_same =
  429. kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv();
  430. bool is_type_ok =
  431. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  432. kern_size_param.C_type.enumv() == DTypeEnum::Int16) ||
  433. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  434. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
  435. bool is_mode_ok =
  436. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  437. kern_size_param.format == Param::Format::DEFAULT &&
  438. is_supported(SIMDType::SSE4_1);
  439. bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
  440. return is_param_ok;
  441. }
bool MatrixMulImpl::AlgoInt8x8x16SSE::preferred(const KernSizeParam&) const {
    // always preferred once usable() passes
    return true;
}
size_t MatrixMulImpl::AlgoInt8x8x16SSE::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_s8s8s16_sse_4x8x2
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_sse_s8s8s16_4x8x2 strategy(m, n, k, a_type, b_type,
                                                 c_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s16_4x8x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (SSE, s8s8s16)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
                                            megdnn_x86_matmul_kern,
                                            "AlgoInt8x8x16SSE"_hash,
                                            x86::matmul::gemm_sse_s8s8s16_4x8x2,
                                            dt_int8, dt_int16, dt_int16,
                                            AlgoDataType::INT8X8X16, DEFAULT);
  468. /*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
        const KernSizeParam&) const {
    //! file-local AVX2 4x16x2 s8s8s32 kernel
    return gemm_s8s8s32_avx2_4x16x2;
}
  473. bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable(
  474. const KernSizeParam& kern_size_param) const {
  475. bool is_param_ok =
  476. kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  477. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  478. kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
  479. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  480. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
  481. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  482. kern_size_param.format == Param::Format::DEFAULT &&
  483. is_supported(SIMDType::AVX2);
  484. return is_param_ok;
  485. }
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_s8s8s32_avx2_4x16x2
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_avx2_s8s8s32_4x16x2 strategy(m, n, k, a_type, b_type,
                                                   c_type);
    return megdnn::matmul::GemmInterleaved<
                   x86::matmul::gemm_avx2_s8s8s32_4x16x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (AVX2, s8s8s32)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
        "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
        dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT);
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
        const KernSizeParam&) const {
    //! file-local AVX2 2x4x16 s8s8s32 kernel
    return gemm_s8s8s32_avx2_2x4x16;
}
  512. bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable(
  513. const KernSizeParam& kern_size_param) const {
  514. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  515. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  516. kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
  517. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  518. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
  519. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  520. kern_size_param.format == Param::Format::DEFAULT &&
  521. is_supported(SIMDType::AVX2);
  522. }
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_s8s8s32_avx2_2x4x16
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_avx2_s8s8s32_2x4x16 strategy(m, n, k, a_type, b_type,
                                                   c_type);
    return megdnn::matmul::GemmInterleaved<
                   x86::matmul::gemm_avx2_s8s8s32_2x4x16>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (AVX2, s8s8s32)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
                                     megdnn_x86_matmul_kern,
                                     "AlgoInt8x8x32AVX2M2N4K16"_hash,
                                     x86::matmul::gemm_avx2_s8s8s32_2x4x16,
                                     dt_int8, dt_int32,
                                     AlgoDataType::QINT8X8X32, DEFAULT);
  547. /*************************AlgoInt8x8x32SSEM4N8K2********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
        const KernSizeParam&) const {
    //! file-local SSE 4x8x2 s8s8s32 kernel
    return gemm_s8s8s32_sse_4x8x2;
}
  552. bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable(
  553. const KernSizeParam& kern_size_param) const {
  554. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  555. ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
  556. kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
  557. (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
  558. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
  559. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  560. kern_size_param.format == Param::Format::DEFAULT &&
  561. is_supported(SIMDType::SSE4_1);
  562. }
size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_s8s8s32_sse_4x8x2
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::gemm_sse_s8s8s32_4x8x2 strategy(m, n, k, a_type, b_type,
                                                 c_type);
    return megdnn::matmul::GemmInterleaved<x86::matmul::gemm_sse_s8s8s32_4x8x2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (SSE, s8s8s32)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
                                            megdnn_x86_matmul_kern,
                                            "AlgoInt8x8x32SSEM4N8K2"_hash,
                                            x86::matmul::gemm_sse_s8s8s32_4x8x2,
                                            dt_int8, dt_int32, dt_int16,
                                            AlgoDataType::QINT8X8X32, DEFAULT);
  586. /*************************AlgoF32MK8_8x8********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
        const KernSizeParam&) const {
    //! no-pack kernel: the MK8 layout is already blocked, so GemmInterleaved
    //! is instantiated with packing disabled (second template arg = false)
    auto f32_kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_x86_matmul_kern_mk8_8x8, midout_iv(0)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            x86::matmul::sgemm_nopack_8x8_avx2 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<x86::matmul::sgemm_nopack_8x8_avx2,
                                            false>(M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk8_8x8;
}
  610. bool MatrixMulImpl::AlgoF32MK8_8x8::usable(
  611. const KernSizeParam& kern_size_param) const {
  612. constexpr static size_t MB = 8;
  613. constexpr static size_t KB = 8;
  614. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  615. kern_size_param.B_type.enumv() == kern_size_param.A_type.enumv() &&
  616. kern_size_param.C_type.enumv() == kern_size_param.A_type.enumv() &&
  617. kern_size_param.A_type.enumv() == DTypeEnum::Float32 &&
  618. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  619. !kern_size_param.trA && !kern_size_param.trB &&
  620. kern_size_param.M % MB == 0 && kern_size_param.K % KB == 0 &&
  621. is_supported(SIMDType::FMA);
  622. }
size_t MatrixMulImpl::AlgoF32MK8_8x8::get_workspace(
        const KernSizeParam& kern_param) const {
    MIDOUT_BEGIN(megdnn_x86_matmul_kern_mk8_8x8, midout_iv(0)) {
        const size_t m = kern_param.M;
        const size_t n = kern_param.N;
        const size_t k = kern_param.K;
        const bool trans_a = kern_param.trA;
        const bool trans_b = kern_param.trB;
        auto a_type = kern_param.A_type;
        auto b_type = kern_param.B_type;
        auto c_type = kern_param.C_type;
        // same no-pack GemmInterleaved configuration as the kernel lambda
        x86::matmul::sgemm_nopack_8x8_avx2 strategy(a_type, b_type, c_type);
        return megdnn::matmul::GemmInterleaved<
                       x86::matmul::sgemm_nopack_8x8_avx2, false>(
                       m, n, k, trans_a, trans_b, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // NOTE(review): there is no return statement after MIDOUT_END(); this
    // relies on the MIDOUT block always being entered in this build
    // configuration -- confirm against midout's macro expansion
}
  642. /*************************AlgoFloatAVX2M6N16********************/
MatrixMulImpl::kern_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_kern(
        const KernSizeParam&) const {
    //! file-local packed 6x16 AVX2 f32 kernel
    return gemm_f32_avx2_6x16;
}
  647. bool MatrixMulImpl::AlgoFloatAVX2M6N16::usable(
  648. const KernSizeParam& kern_size_param) const {
  649. bool is_param_ok =
  650. kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  651. ((kern_size_param.A_type.enumv() == DTypeEnum::Float32 &&
  652. kern_size_param.C_type.enumv() == DTypeEnum::Float32)) &&
  653. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  654. kern_size_param.format == Param::Format::DEFAULT &&
  655. is_supported(SIMDType::AVX2);
  656. return is_param_ok;
  657. }
size_t MatrixMulImpl::AlgoFloatAVX2M6N16::get_workspace(
        const KernSizeParam& kern_param) const {
    // mirror the GemmInterleaved setup of gemm_f32_avx2_6x16
    constexpr int cacheline = 64;
    const size_t m = kern_param.M;
    const size_t n = kern_param.N;
    const size_t k = kern_param.K;
    const bool trans_a = kern_param.trA;
    const bool trans_b = kern_param.trB;
    auto a_type = kern_param.A_type;
    auto b_type = kern_param.B_type;
    auto c_type = kern_param.C_type;
    x86::matmul::sgemm_pack_6x16_avx2 strategy(m, n, k, a_type, b_type,
                                               c_type);
    return megdnn::matmul::GemmInterleaved<
                   x86::matmul::sgemm_pack_6x16_avx2>(
                   m, n, k, trans_a, trans_b, strategy, cacheline)
            .get_workspace_size();
}
// expose this matmul to the im2col conv path (AVX2, f32)
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoFloatAVX2M6N16, megdnn_x86_matmul_kern,
        "AlgoFloatAVX2M6N16"_hash, x86::matmul::sgemm_pack_6x16_avx2,
        float, float, float, AlgoDataType::FLOAT32, DEFAULT);
  680. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台