You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

algos.cpp 53 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/aarch64/matrix_mul/algos.h"
  12. #include "src/aarch64/matrix_mul/fp16/strategy.h"
  13. #include "src/aarch64/matrix_mul/fp32/strategy.h"
  14. #include "src/aarch64/matrix_mul/int16/strategy.h"
  15. #include "src/aarch64/matrix_mul/int8/strategy.h"
  16. #include "src/aarch64/matrix_mul/int8_dot/gemv.h"
  17. #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
  18. #include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
  19. #include "src/aarch64/matrix_mul/quint8/strategy.h"
  20. #include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
  21. #include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
  22. #include "src/common/utils.h"
  23. #include "src/fallback/matrix_mul/gemm_impl.h"
  24. #include "midout.h"
  25. MIDOUT_DECL(megdnn_aarch64_matmul_kern)
  26. using namespace megdnn;
  27. using namespace aarch64;
  28. /* ===================== F32K8X12X1 algo ===================== */
  29. bool MatrixMulImpl::AlgoF32K8x12x1::usable(
  30. const KernSizeParam& kern_size_param) const {
  31. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  32. kern_size_param.B_type == kern_size_param.A_type &&
  33. kern_size_param.C_type == kern_size_param.A_type &&
  34. kern_size_param.A_type == dtype::Float32() &&
  35. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  36. }
// Scratch-buffer size GemmInterleaved needs (packed A/B panels) for the
// sgemm_8x12 strategy at this problem size.
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// Returns a stateless kernel functor; all sizes/strides are re-read from the
// KernParam at execution time, so the KernSizeParam argument is unused.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32K8x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type,
                                                 C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_8x12;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
                                     "AlgoF32K8x12x1Impl"_hash,
                                     aarch64::matmul::sgemm_8x12, float, float);
  81. /* ===================== F32_MK4_8X12X1 algo ===================== */
  82. bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
  83. const KernSizeParam& kern_size_param) const {
  84. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  85. kern_size_param.B_type == kern_size_param.A_type &&
  86. kern_size_param.C_type == kern_size_param.A_type &&
  87. kern_size_param.A_type == dtype::Float32() &&
  88. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  89. !kern_size_param.trA && !kern_size_param.trB &&
  90. kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
  91. }
// Workspace for the packed panels of the MK4 fp32 sgemm_mk4_8x12 strategy.
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
                                                 C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// Stateless kernel functor; problem geometry is re-read from the KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
                                                     C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_8x12;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoF32MK4_8x12x1Impl"_hash,
                                     aarch64::matmul::sgemm_mk4_8x12, float,
                                     float);
  139. /* ===================== F32K4X16X1 algo ===================== */
  140. bool MatrixMulImpl::AlgoF32K4x16x1::usable(
  141. const KernSizeParam& kern_size_param) const {
  142. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  143. kern_size_param.B_type == kern_size_param.A_type &&
  144. kern_size_param.C_type == kern_size_param.A_type &&
  145. kern_size_param.A_type == dtype::Float32() &&
  146. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  147. }
// Workspace for the packed panels of the sgemm_4x16 strategy.
size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// Stateless kernel functor; problem geometry is re-read from the KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32K4x16x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type,
                                                 C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_4x16;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern,
                                     "AlgoF32K4x16x1Impl"_hash,
                                     aarch64::matmul::sgemm_4x16, float, float);
  192. /* ===================== F32MK4_4x16 algo ===================== */
  193. bool MatrixMulImpl::AlgoF32MK4_4x16::usable(
  194. const KernSizeParam& kern_size_param) const {
  195. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  196. kern_size_param.C_type == dtype::Float32() &&
  197. kern_size_param.B_type == dtype::Float32() &&
  198. kern_size_param.A_type == dtype::Float32() &&
  199. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  200. !kern_size_param.trA && !kern_size_param.trB &&
  201. kern_size_param.N % 4 == 0;
  202. }
// Workspace for the no-pack 4x16 strategy (second template arg `false`
// selects the non-packing GemmInterleaved variant).
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::sgemm_nopack_4x16, false>(M, N, K, trA,
                                                                  trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// Stateless kernel functor; problem geometry is re-read from the KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern(
        const KernSizeParam&) const {
    auto f32_kern_mk4_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_nopack_4x16,
                                            false>(M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_4x16;
}
  244. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  245. /* ===================== F16 K8x24x1 algo ===================== */
namespace {
// Shared fp16 kernel body: reads all sizes/strides from the KernParam and
// runs the hgemm_8x24 strategy through GemmInterleaved.
void f16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("f16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_float16>(),
                   Bptr = kern_param.B<dt_float16>();
        auto Cptr = kern_param.C<dt_float16>();
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  266. bool MatrixMulImpl::AlgoF16K8x24x1::usable(
  267. const KernSizeParam& kern_size_param) const {
  268. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  269. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  270. kern_size_param.C_type == kern_size_param.A_type &&
  271. kern_size_param.B_type == kern_size_param.A_type &&
  272. kern_size_param.A_type == dtype::Float16();
  273. }
// Workspace for the packed panels of the hgemm_8x24 strategy.
size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// The fp16 kernel is a plain free function, so no lambda is needed here.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
        const KernSizeParam&) const {
    return f16_kern;
}

// NOTE(review): "Alog..." looks like a typo for "Algo...", but this string is
// only a hash identity — renaming it would change the hash, so it is kept.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern,
                                     "AlogF16K8x24x1Impl"_hash,
                                     aarch64::matmul::hgemm_8x24, dt_float16,
                                     dt_float16);
  298. /* ===================== F16_MK8_8x8 algo ===================== */
  299. bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
  300. const KernSizeParam& kern_size_param) const {
  301. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  302. kern_size_param.C_type == kern_size_param.A_type &&
  303. kern_size_param.B_type == kern_size_param.A_type &&
  304. kern_size_param.A_type == dtype::Float16() &&
  305. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  306. !kern_size_param.trA && !kern_size_param.trB &&
  307. kern_size_param.N % 4 == 0;
  308. }
// Workspace for the no-pack fp16 8x8 strategy (`false` = non-packing variant).
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_nopack_f16_8x8, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// Stateless kernel functor; problem geometry is re-read from the KernParam.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
        const KernSizeParam&) const {
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_float16>(),
                       Bptr = kern_param.B<dt_float16>();
            auto Cptr = kern_param.C<dt_float16>();
            aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type,
                                                          C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_f16_8x8, false>(M, N, K, trA,
                                                                 trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
  352. #endif
  353. #if __ARM_FEATURE_DOTPROD
  354. /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */
namespace {
// int8 -> int32 kernel body using the dot-product gemm_s8_8x12 strategy;
// all sizes/strides come from the KernParam.
void int8x8x32_k8x12x4_dotprod_kern(
        const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  377. bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable(
  378. const KernSizeParam& kern_size_param) const {
  379. return can_be_treated_as_int8x8x32(kern_size_param);
  380. }
// Workspace for the packed panels of the gemm_s8_8x12 strategy.
size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// The kernel is a plain free function (defined above in this file).
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_k8x12x4_dotprod_kern;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_s8_8x12, int8_t,
                                     int32_t);
  406. /* ===================== Int8x8x32 Gemv DotProd algo ===================== */
namespace {
// Direct int8 GEMV: no packing/workspace; forwards pointers and strides
// straight to the dot-product gemv implementation.
void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
    const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
    auto Cptr = kern_param.C<dt_int32>();
    aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
}  // anonymous namespace
  416. bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable(
  417. const KernSizeParam& kern_size_param) const {
  418. return can_be_treated_as_int8x8x32(kern_size_param) &&
  419. !kern_size_param.trA && !kern_size_param.trB &&
  420. kern_size_param.N == 1 && kern_size_param.LDB == 1;
  421. }
  422. bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred(
  423. const KernSizeParam& kern_size_param) const {
  424. auto N = kern_size_param.N, LDB = kern_size_param.LDB;
  425. return (N == 1 && LDB == 1);
  426. }
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern(
        const KernSizeParam&) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) {
        return int8x8x32_gemv_dotprod_kern;
    }
    MIDOUT_END();
    // Fallback return — presumably only reachable when the MIDOUT region
    // above is compiled out; confirm against midout.h semantics.
    return nullptr;
}
  436. /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
// int8 -> int32 MK4_DOT kernel body using the gemm_mk4_s8_8x12 strategy.
void int8x8x32_mk4_8x12x4_dotprod_kern(
        const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  460. bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
  461. const KernSizeParam& kern_size_param) const {
  462. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  463. (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  464. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  465. (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
  466. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  467. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  468. kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
  469. !kern_size_param.trA && !kern_size_param.trB;
  470. }
// Workspace for the packed panels of the gemm_mk4_s8_8x12 strategy.
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB,
                                                          strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// The kernel is a plain free function (defined above in this file).
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_mk4_8x12x4_dotprod_kern;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
                                     int32_t);
  499. #else
  500. /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
namespace {
// int8 -> int32 MK4 kernel body (non-dotprod build) using gemm_mk4_s8_4x4.
void int8x8x32_mk4_4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  523. bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable(
  524. const KernSizeParam& param) const {
  525. return param.A_type.enumv() == param.B_type.enumv() &&
  526. (param.A_type.enumv() == DTypeEnum::Int8 ||
  527. param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  528. (param.C_type.enumv() == DTypeEnum::Int32 ||
  529. param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  530. param.compute_mode == Param::ComputeMode::DEFAULT &&
  531. param.format == param::MatrixMul::Format::MK4 && !param.trA &&
  532. !param.trB && param.M % 4 == 0 && param.K % 4 == 0;
  533. }
  534. bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred(
  535. const KernSizeParam& kern_size_param) const {
  536. return kern_size_param.K > 16;
  537. }
// Workspace for the packed panels of the gemm_mk4_s8_4x4 strategy.
size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_mk4_s8_4x4>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

// The kernel is a plain free function (defined above in this file).
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_mk4_4x4x16_kern;
}

// Register the same strategy for the im2col conv path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32MK4_4x4x16Impl"_hash,
                                     aarch64::matmul::gemm_mk4_s8_4x4, int8_t,
                                     int32_t);
  565. /* ===================== Int8x8x32 K4x4x16 algo ===================== */
namespace {
// int8 -> int32 kernel body (non-dotprod build) using gemm_s8_4x4.
void int8x8x32_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  587. bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::usable(
  588. const KernSizeParam& kern_size_param) const {
  589. return can_be_treated_as_int8x8x32(kern_size_param);
  590. }
  591. bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::preferred(
  592. const KernSizeParam& kern_size_param) const {
  593. return kern_size_param.K > 16;
  594. }
  595. size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace(
  596. const KernSizeParam& kern_size_param) const {
  597. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  598. midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) {
  599. auto M = kern_size_param.M, N = kern_size_param.N,
  600. K = kern_size_param.K;
  601. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  602. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  603. C_type = kern_size_param.C_type;
  604. aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
  605. return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
  606. M, N, K, trA, trB, strategy)
  607. .get_workspace_size();
  608. }
  609. MIDOUT_END();
  610. }
  611. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern(
  612. const KernSizeParam&) const {
  613. return int8x8x32_k4x4x16_kern;
  614. }
//! NOTE(review): macro body not visible here; presumably emits the
//! pack-A/pack-B/kern glue that lets the im2col convolution backend reuse
//! the gemm_s8_4x4 strategy directly — confirm against the macro definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32K4x4x16Impl"_hash,
                                     aarch64::matmul::gemm_s8_4x4, int8_t,
                                     int32_t);
  620. /* ===================== Int8x8x32 K8x8x8 algo ===================== */
namespace {
//! Kernel wrapper for the int8x8x32 8x8x8 GEMM.
//! Same structure as the 4x4x16 wrapper, but with the gemm_s8_8x8 strategy
//! (wider tile, shallower K-blocking).
void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        //! packs operands into the workspace, then runs the micro-kernel
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  642. bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::usable(
  643. const KernSizeParam& kern_size_param) const {
  644. return can_be_treated_as_int8x8x32(kern_size_param);
  645. }
  646. bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::preferred(
  647. const KernSizeParam& kern_size_param) const {
  648. return kern_size_param.K <= 16;
  649. }
  650. size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace(
  651. const KernSizeParam& kern_size_param) const {
  652. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  653. midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) {
  654. auto M = kern_size_param.M, N = kern_size_param.N,
  655. K = kern_size_param.K;
  656. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  657. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  658. C_type = kern_size_param.C_type;
  659. aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  660. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8_8x8>(
  661. M, N, K, trA, trB, strategy)
  662. .get_workspace_size();
  663. }
  664. MIDOUT_END();
  665. }
  666. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern(
  667. const KernSizeParam&) const {
  668. return int8x8x32_k8x8x8_kern;
  669. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_s8_8x8 pack/kern hooks to the im2col convolution backend.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8_8x8, int8_t,
                                     int32_t);
  675. #endif
  676. /* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace {
//! Kernel wrapper for the int8 x int8 -> int16 8x8x8 GEMM.
//! Identical in shape to the int8x8x32 wrappers, but the output pointer is
//! dt_int16 and the strategy is gemm_s8x8x16_8x8.
void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  699. bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable(
  700. const KernSizeParam& kern_size_param) const {
  701. return can_be_treated_as_int8x8x16(kern_size_param);
  702. }
  703. bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred(
  704. const KernSizeParam& kern_size_param) const {
  705. return kern_size_param.K <= 16;
  706. }
  707. size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace(
  708. const KernSizeParam& kern_size_param) const {
  709. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  710. midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) {
  711. auto M = kern_size_param.M, N = kern_size_param.N,
  712. K = kern_size_param.K;
  713. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  714. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  715. C_type = kern_size_param.C_type;
  716. aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type,
  717. C_type);
  718. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
  719. M, N, K, trA, trB, strategy)
  720. .get_workspace_size();
  721. }
  722. MIDOUT_END();
  723. }
  724. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern(
  725. const KernSizeParam&) const {
  726. return int8x8x16_k8x8x8_kern;
  727. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_s8x8x16_8x8 pack/kern hooks to the im2col backend.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_8x8, int8_t,
                                     int16_t);
  733. /* ===================== Int8x8x16 K4x4x16 algo ===================== */
namespace {
//! Kernel wrapper for the int8 x int8 -> int16 4x4x16 GEMM, using the
//! gemm_s8x8x16_4x4 strategy via the generic GemmInterleaved driver.
void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_k4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  756. bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable(
  757. const KernSizeParam& kern_size_param) const {
  758. return can_be_treated_as_int8x8x16(kern_size_param);
  759. }
  760. bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred(
  761. const KernSizeParam& kern_size_param) const {
  762. MEGDNN_MARK_USED_VAR(kern_size_param);
  763. return true;
  764. }
  765. size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace(
  766. const KernSizeParam& kern_size_param) const {
  767. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  768. midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) {
  769. auto M = kern_size_param.M, N = kern_size_param.N,
  770. K = kern_size_param.K;
  771. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  772. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  773. C_type = kern_size_param.C_type;
  774. aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type,
  775. C_type);
  776. return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_4x4>(
  777. M, N, K, trA, trB, strategy)
  778. .get_workspace_size();
  779. }
  780. MIDOUT_END();
  781. }
  782. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern(
  783. const KernSizeParam&) const {
  784. return int8x8x16_k4x4x16_kern;
  785. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_s8x8x16_4x4 pack/kern hooks to the im2col backend.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16K4x4x16Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
                                     int16_t);
  791. /* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace {
//! Kernel wrapper for the int16 x int16 -> int32 12x8x1 GEMM, driven by
//! GemmInterleaved with the gemm_s16_12x8x1 strategy.
void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int16x16x32_k12x8x1_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int16>(),
                   Bptr = kern_param.B<dt_int16>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s16_12x8x1>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  814. bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::usable(
  815. const KernSizeParam& kern_size_param) const {
  816. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  817. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  818. kern_size_param.compute_mode ==
  819. param::MatrixMul::ComputeMode::DEFAULT &&
  820. kern_size_param.A_type.enumv() == DTypeEnum::Int16 &&
  821. kern_size_param.C_type.enumv() == DTypeEnum::Int32;
  822. }
  823. bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::preferred(
  824. const KernSizeParam& kern_size_param) const {
  825. MEGDNN_MARK_USED_VAR(kern_size_param);
  826. return true;
  827. }
  828. size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace(
  829. const KernSizeParam& kern_size_param) const {
  830. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  831. midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) {
  832. auto M = kern_size_param.M, N = kern_size_param.N,
  833. K = kern_size_param.K;
  834. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  835. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  836. C_type = kern_size_param.C_type;
  837. aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type,
  838. C_type);
  839. return megdnn::matmul::GemmInterleaved<matmul::gemm_s16_12x8x1>(
  840. M, N, K, trA, trB, strategy)
  841. .get_workspace_size();
  842. }
  843. MIDOUT_END();
  844. }
  845. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern(
  846. const KernSizeParam&) const {
  847. return int16x16x32_k12x8x1_kern;
  848. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_s16_12x8x1 pack/kern hooks to the im2col backend.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt16x16x32K12x8x1Impl"_hash,
                                     aarch64::matmul::gemm_s16_12x8x1, int16_t,
                                     int32_t);
  854. /* ===================== Int16x16x32MK8_8x8 algo ===================== */
  855. bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable(
  856. const KernSizeParam& kern_size_param) const {
  857. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  858. kern_size_param.C_type == dtype::Int32() &&
  859. kern_size_param.B_type == dtype::Int16() &&
  860. kern_size_param.A_type == dtype::Int16() &&
  861. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  862. !kern_size_param.trA && !kern_size_param.trB &&
  863. kern_size_param.N % 4 == 0;
  864. }
  865. size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
  866. const KernSizeParam& kern_size_param) const {
  867. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  868. midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) {
  869. auto M = kern_size_param.M, N = kern_size_param.N,
  870. K = kern_size_param.K;
  871. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  872. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  873. C_type = kern_size_param.C_type;
  874. aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type);
  875. return megdnn::matmul::GemmInterleaved<
  876. aarch64::matmul::gemm_nopack_s16_8x8, false>(
  877. M, N, K, trA, trB, strategy)
  878. .get_workspace_size();
  879. }
  880. MIDOUT_END();
  881. }
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
        const KernSizeParam&) const {
    //! Unlike the other algos, the kernel is a capture-less lambda defined
    //! inline (convertible to the plain function pointer kern_t) rather
    //! than a file-scope function in the anonymous namespace.
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_int16>(),
                       Bptr = kern_param.B<dt_int16>();
            auto Cptr = kern_param.C<dt_int32>();
            //! no-pack strategy: constructed from dtypes only
            aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type,
                                                          C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_s16_8x8, false>(M, N, K, trA,
                                                                 trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
  908. #if __ARM_FEATURE_DOTPROD
  909. /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */
namespace {
//! Kernel wrapper for the quint8 x quint8 -> int32 8x8x4 GEMM on the
//! dot-product-capable (__ARM_FEATURE_DOTPROD) build, using gemm_u8_8x8.
void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_uint8>(),
                   Bptr = kern_param.B<dt_uint8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  931. bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable(
  932. const KernSizeParam& kern_size_param) const {
  933. return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
  934. kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
  935. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
  936. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  937. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
  938. }
  939. size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
  940. const KernSizeParam& kern_size_param) const {
  941. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  942. midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) {
  943. auto M = kern_size_param.M, N = kern_size_param.N,
  944. K = kern_size_param.K;
  945. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  946. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  947. C_type = kern_size_param.C_type;
  948. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  949. return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  950. M, N, K, trA, trB, strategy)
  951. .get_workspace_size();
  952. }
  953. MIDOUT_END();
  954. }
  955. MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
  956. const KernSizeParam&) const {
  957. return quint8_k8x8x4_dotprod_kern;
  958. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_u8_8x8 (dotprod) pack/kern hooks to the im2col backend.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoQuint8K8x8x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_u8_8x8, uint8_t,
                                     int32_t);
  964. /* ===================== Quint8 Gemv DotProd algo ===================== */
namespace {
//! GEMV wrapper for quint8 x quint8 -> int32 on the dotprod build.
//! Bypasses GemmInterleaved and calls gemv_like_quint8 directly, passing
//! both operands' asymmetric zero points so the kernel can correct the
//! unsigned accumulation. No transpose flags or workspace are used.
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("quint8_gemv_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_uint8>(),
                   Bptr = kern_param.B<dt_uint8>();
        auto Cptr = kern_param.C<dt_int32>();
        auto A_type = kern_param.A_type, B_type = kern_param.B_type;
        aarch64::matmul::gemv_like_quint8(
                Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC,
                A_type.param<dtype::Quantized8Asymm>().zero_point,
                B_type.param<dtype::Quantized8Asymm>().zero_point);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  983. bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable(
  984. const KernSizeParam& kern_size_param) const {
  985. return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
  986. kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
  987. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
  988. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  989. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  990. !kern_size_param.trA && !kern_size_param.trB &&
  991. kern_size_param.N == 1 && kern_size_param.LDB == 1;
  992. }
  993. bool MatrixMulImpl::AlgoQuint8GemvDotProd::preferred(
  994. const KernSizeParam& kern_size_param) const {
  995. auto N = kern_size_param.N, LDB = kern_size_param.LDB;
  996. return (N == 1 && LDB == 1);
  997. }
  998. MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern(
  999. const KernSizeParam&) const {
  1000. return quint8_gemv_dotprod_kern;
  1001. }
  1002. #else
  1003. /* ===================== Quint8 K8x8x8 algo ===================== */
  1004. namespace {
  1005. void quint8_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  1006. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  1007. midout_iv("quint8_gemv_dotprod_kern"_hash)) {
  1008. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1009. auto trA = kern_param.trA, trB = kern_param.trB;
  1010. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1011. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1012. C_type = kern_param.C_type;
  1013. const auto Aptr = kern_param.A<dt_uint8>(),
  1014. Bptr = kern_param.B<dt_uint8>();
  1015. auto Cptr = kern_param.C<dt_int32>();
  1016. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  1017. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  1018. M, N, K, trA, trB, strategy)
  1019. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
  1020. kern_param.workspace_ptr);
  1021. }
  1022. MIDOUT_END();
  1023. }
  1024. } // anonymous namespace
  1025. bool MatrixMulImpl::AlgoQuint8K8x8x8::usable(
  1026. const KernSizeParam& kern_size_param) const {
  1027. return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
  1028. kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
  1029. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
  1030. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  1031. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
  1032. }
  1033. size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace(
  1034. const KernSizeParam& kern_size_param) const {
  1035. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  1036. midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) {
  1037. auto M = kern_size_param.M, N = kern_size_param.N,
  1038. K = kern_size_param.K;
  1039. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  1040. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  1041. C_type = kern_size_param.C_type;
  1042. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  1043. return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  1044. M, N, K, trA, trB, strategy)
  1045. .get_workspace_size();
  1046. }
  1047. MIDOUT_END();
  1048. }
  1049. MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern(
  1050. const KernSizeParam&) const {
  1051. return quint8_k8x8x8_kern;
  1052. }
//! NOTE(review): registration macro defined elsewhere; presumably exposes
//! the gemm_u8_8x8 pack/kern hooks to the im2col backend (non-dotprod build).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoQuint8K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_u8_8x8, uint8_t,
                                     int32_t);
  1058. #endif
  1059. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台