You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

algos.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. /**
  2. * \file dnn/src/arm_common/matrix_mul/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/arm_common/matrix_mul/algos.h"
  12. #include "src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h"
  13. #include "src/arm_common/matrix_mul/fp16/hgemv.h"
  14. #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
  15. #include "src/arm_common/matrix_mul/int8/gemv.h"
  16. #include "midout.h"
  17. MIDOUT_DECL(megdnn_arm_hgemv)
  18. MIDOUT_DECL(megdnn_arm_exec_int8816)
  19. using namespace megdnn;
  20. using namespace arm_common;
  21. /* ===================== Int8x8x16 algo ===================== */
  22. namespace {
  23. WorkspaceBundle get_workspace_bundle_int_8x8x16(
  24. const MatrixMulImpl::KernSizeParam& kern_size_param) {
  25. auto M = kern_size_param.M, K = kern_size_param.K, N = kern_size_param.N;
  26. // Use 8x8 tile
  27. return WorkspaceBundle(nullptr, {(M + 8) * K * sizeof(int8_t),
  28. K * (N + 8) * sizeof(int8_t)});
  29. }
  30. void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) {
  31. MIDOUT_BEGIN(megdnn_arm_exec_int8816, void) {
  32. auto bundle = get_workspace_bundle_int_8x8x16(kern_param);
  33. bundle.set(kern_param.workspace_ptr);
  34. auto w0 = static_cast<int8_t*>(bundle.get(0));
  35. auto w1 = static_cast<int8_t*>(bundle.get(1));
  36. size_t M = kern_param.M;
  37. size_t N = kern_param.N;
  38. size_t K = kern_param.K;
  39. size_t LDB = kern_param.LDB;
  40. exec_gemm_int8_int8_int16(
  41. kern_param.A<dt_int8>(), kern_param.B<dt_int8>(),
  42. kern_param.C<dt_int16>(), M, K, N, LDB, w0, w1);
  43. }
  44. MIDOUT_END();
  45. }
  46. } // anonymous namespace
  47. bool MatrixMulImpl::AlgoInt8x8x16::usable(
  48. const KernSizeParam& kern_size_param) const {
  49. return kern_size_param.A_type == dtype::Int8() &&
  50. kern_size_param.B_type == dtype::Int8() &&
  51. kern_size_param.C_type == dtype::Int16() &&
  52. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  53. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  54. !kern_size_param.trA && !kern_size_param.trB;
  55. }
  56. size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace(
  57. const KernSizeParam& kern_size_param) const {
  58. auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
  59. return wbundle.total_size_in_bytes();
  60. }
  61. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
  62. const KernSizeParam&) const {
  63. return exec_int_8x8x16;
  64. }
  65. /* ===================== Int8x8x32 Gemv algo ===================== */
  66. namespace {
  67. void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
  68. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  69. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  70. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  71. auto Cptr = kern_param.C<dt_int32>();
  72. gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
  73. }
  74. } // anonymous namespace
  75. bool MatrixMulImpl::AlgoInt8x8x32Gemv::usable(
  76. const KernSizeParam& kern_size_param) const {
  77. auto N = kern_size_param.N, LDB = kern_size_param.LDB;
  78. return can_be_treated_as_int8x8x32(kern_size_param) &&
  79. !kern_size_param.trA && !kern_size_param.trB && (N == 1 && LDB == 1);
  80. }
  81. bool MatrixMulImpl::AlgoInt8x8x32Gemv::preferred(
  82. const KernSizeParam& kern_size_param) const {
  83. auto N = kern_size_param.N, LDB = kern_size_param.LDB;
  84. return N == 1 && LDB == 1;
  85. }
  86. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
  87. const KernSizeParam&) const {
  88. return int8x8x32_gemv_kern;
  89. }
  90. /* ===================== Int8x8x32 Gemv MK4 algo ===================== */
  91. namespace {
  92. void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
  93. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  94. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  95. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  96. auto Cptr = kern_param.C<dt_int32>();
  97. gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
  98. }
  99. } // anonymous namespace
  100. bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable(
  101. const KernSizeParam& kern_size_param) const {
  102. auto M = kern_size_param.M;
  103. auto N = kern_size_param.N;
  104. auto K = kern_size_param.K;
  105. auto LDB = kern_size_param.LDB;
  106. bool is_dtype_ok =
  107. kern_size_param.A_type == kern_size_param.B_type &&
  108. (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  109. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  110. (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
  111. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
  112. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  113. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  114. is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
  115. M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
  116. }
  117. bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred(
  118. const KernSizeParam& kern_size_param) const {
  119. MEGDNN_MARK_USED_VAR(kern_size_param);
  120. return true;
  121. }
  122. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
  123. const KernSizeParam&) const {
  124. return int8x8x32_gemv_mk4_kern;
  125. }
  126. #if __ARM_FEATURE_DOTPROD
  127. /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
  128. namespace {
  129. void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
  130. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  131. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  132. const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
  133. auto Cptr = kern_param.C<dt_int32>();
  134. gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
  135. }
  136. } // anonymous namespace
  137. bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
  138. const KernSizeParam& kern_size_param) const {
  139. auto M = kern_size_param.M;
  140. auto N = kern_size_param.N;
  141. auto K = kern_size_param.K;
  142. auto LDB = kern_size_param.LDB;
  143. bool is_dtype_ok =
  144. kern_size_param.A_type == kern_size_param.B_type &&
  145. (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  146. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  147. (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
  148. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
  149. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  150. kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
  151. is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
  152. M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
  153. }
  154. bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred(
  155. const KernSizeParam& kern_size_param) const {
  156. return true;
  157. }
  158. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
  159. const KernSizeParam&) const {
  160. return int8x8x32_gemv_mk4_dot_kern;
  161. }
  162. #endif
  163. /* ===================== F32 Gemv algo ===================== */
  164. namespace {
  165. void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
  166. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  167. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  168. const auto Aptr = kern_param.A<dt_float32>(),
  169. Bptr = kern_param.B<dt_float32>();
  170. auto Cptr = kern_param.C<dt_float32>();
  171. gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
  172. }
  173. } // anonymous namespace
  174. bool MatrixMulImpl::AlgoF32Gemv::usable(
  175. const KernSizeParam& kern_size_param) const {
  176. // enumerate the M, N, K, only usable when preferred
  177. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  178. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  179. kern_size_param.B_type == kern_size_param.A_type &&
  180. kern_size_param.C_type == kern_size_param.A_type &&
  181. kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
  182. !kern_size_param.trB && preferred(kern_size_param);
  183. }
  184. bool MatrixMulImpl::AlgoF32Gemv::preferred(
  185. const KernSizeParam& kern_size_param) const {
  186. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K,
  187. LDB = kern_size_param.LDB;
  188. return M < 8 || (M == 8 && K <= 2) || (N == 1 && LDB == 1);
  189. }
  190. MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
  191. const KernSizeParam&) const {
  192. return f32_gemv_kern;
  193. }
  194. /* ================== F32 Gemv MK4 algo ================== */
  195. namespace {
  196. void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
  197. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  198. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  199. const auto Aptr = kern_param.A<dt_float32>(),
  200. Bptr = kern_param.B<dt_float32>();
  201. auto Cptr = kern_param.C<dt_float32>();
  202. gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
  203. }
  204. } // anonymous namespace
  205. bool MatrixMulImpl::AlgoF32GemvMK4::usable(
  206. const KernSizeParam& kern_size_param) const {
  207. // enumerate the M, N, K, only usable when preferred
  208. auto M = kern_size_param.M;
  209. auto N = kern_size_param.N;
  210. auto K = kern_size_param.K;
  211. auto LDB = kern_size_param.LDB;
  212. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  213. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  214. kern_size_param.B_type == kern_size_param.A_type &&
  215. kern_size_param.C_type == kern_size_param.A_type &&
  216. kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
  217. !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 &&
  218. LDB == 4;
  219. }
  220. bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
  221. const KernSizeParam& kern_size_param) const {
  222. MEGDNN_MARK_USED_VAR(kern_size_param);
  223. return true;
  224. }
  225. MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
  226. const KernSizeParam&) const {
  227. return f32_gemv_mk4_kern;
  228. }
  229. /* ===================== F32 Gevm algo ===================== */
  230. namespace {
  231. template <typename stype, typename dtype>
  232. void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) {
  233. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  234. auto LDB = kern_param.LDB;
  235. const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
  236. auto Cptr = kern_param.C<dtype>();
  237. megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
  238. }
  239. } // anonymous namespace
  240. bool MatrixMulImpl::AlgoGevm::usable(
  241. const KernSizeParam& kern_size_param) const {
  242. // enumerate the M, N, K, only usable when preferred
  243. bool fp32_ok =
  244. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  245. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  246. kern_size_param.B_type == kern_size_param.A_type &&
  247. kern_size_param.C_type == kern_size_param.A_type &&
  248. kern_size_param.A_type == dtype::Float32();
  249. bool fp16_ok = false;
  250. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  251. fp16_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  252. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  253. kern_size_param.B_type == kern_size_param.A_type &&
  254. kern_size_param.C_type == kern_size_param.A_type &&
  255. kern_size_param.A_type == dtype::Float16();
  256. #endif
  257. bool int8_ok = can_be_treated_as_int8x8x32(kern_size_param);
  258. return (fp32_ok || fp16_ok || int8_ok) && preferred(kern_size_param);
  259. }
  260. bool MatrixMulImpl::AlgoGevm::preferred(
  261. const KernSizeParam& kern_size_param) const {
  262. auto M = kern_size_param.M;
  263. return kern_size_param.trB && M == 1;
  264. }
  265. MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
  266. const KernSizeParam& kern_size_param) const {
  267. if (kern_size_param.A_type == dtype::Float32()) {
  268. return gevm_like_kern<dt_float32, dt_float32>;
  269. } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  270. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
  271. return gevm_like_kern<dt_int8, dt_int32>;
  272. }
  273. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  274. else if (kern_size_param.A_type == dtype::Float16()) {
  275. return gevm_like_kern<__fp16, __fp16>;
  276. }
  277. #endif
  278. else {
  279. megdnn_assert(
  280. false, "no avaliable kern got A_type: %s B_type: %s C_type: %s",
  281. kern_size_param.A_type.name(), kern_size_param.B_type.name(),
  282. kern_size_param.C_type.name());
  283. }
  284. }
  285. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  286. /* ===================== F16 Gemv algo ===================== */
  287. namespace {
  288. void f16_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
  289. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  290. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  291. const auto Aptr = kern_param.A<dt_float16>(),
  292. Bptr = kern_param.B<dt_float16>();
  293. auto Cptr = kern_param.C<dt_float16>();
  294. MIDOUT_BEGIN(megdnn_arm_hgemv, void) {
  295. arm_common::gemv_like(reinterpret_cast<const __fp16*>(Aptr),
  296. reinterpret_cast<const __fp16*>(Bptr),
  297. reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA,
  298. LDB, LDC);
  299. }
  300. MIDOUT_END();
  301. }
  302. } // anonymous namespace
  303. bool MatrixMulImpl::AlgoF16Gemv::usable(
  304. const KernSizeParam& kern_size_param) const {
  305. // enumerate the M, N, K, only usable when preferred
  306. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  307. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  308. kern_size_param.B_type == kern_size_param.A_type &&
  309. kern_size_param.C_type == kern_size_param.A_type &&
  310. kern_size_param.A_type == dtype::Float16() && !kern_size_param.trA &&
  311. !kern_size_param.trB && preferred(kern_size_param);
  312. }
  313. bool MatrixMulImpl::AlgoF16Gemv::preferred(
  314. const KernSizeParam& kern_size_param) const {
  315. auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K,
  316. LDB = kern_size_param.LDB;
  317. return M <= 4 || (M == 8 && K <= 2) || (N == 1 && LDB == 1);
  318. }
  319. MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16Gemv::get_kern(
  320. const KernSizeParam&) const {
  321. return f16_gemv_kern;
  322. }
  323. #endif
  324. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台