You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

algos.cpp 66 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469
/**
 * \file dnn/src/aarch64/matrix_mul/algos.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #include "src/aarch64/matrix_mul/algos.h"
  13. #include "src/aarch64/matrix_mul/fp16/strategy.h"
  14. #include "src/aarch64/matrix_mul/fp32/strategy.h"
  15. #include "src/aarch64/matrix_mul/int16/strategy.h"
  16. #include "src/aarch64/matrix_mul/int8/strategy.h"
  17. #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
  18. #include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
  19. #include "src/aarch64/matrix_mul/int4x4x16/strategy.h"
  20. #include "src/aarch64/matrix_mul/quint8/strategy.h"
  21. #include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
  22. #include "src/aarch64/matrix_mul/quint8_dot/strategy.h"
  23. #include "src/common/utils.h"
  24. #include "src/fallback/matrix_mul/gemm_impl.h"
  25. #if MGB_ENABLE_CPUINFO
  26. #include "cpuinfo.h"
  27. #endif
  28. #include "midout.h"
  29. MIDOUT_DECL(megdnn_aarch64_matmul_kern)
  30. using namespace megdnn;
  31. using namespace aarch64;
  32. /* ===================== F32K8X12X1 algo ===================== */
  33. bool MatrixMulImpl::AlgoF32K8x12x1::usable(
  34. const KernSizeParam& kern_size_param) const {
  35. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  36. kern_size_param.B_type == kern_size_param.A_type &&
  37. kern_size_param.C_type == kern_size_param.A_type &&
  38. kern_size_param.A_type == dtype::Float32() &&
  39. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  40. }
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Ask the interleaved-GEMM wrapper how much packing workspace the
    //! 8x12 fp32 strategy needs for this problem size.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
        const KernSizeParam&) const {
    //! Return a capture-less lambda; all sizes/strides are re-read from the
    //! KernParam at execution time, so one kernel pointer serves every shape.
    auto f32_kern_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32K8x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type,
                                                 C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_8x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_8x12;
}
//! Register the 8x12 fp32 strategy with the im2col convolution path
//! (plain DEFAULT layout, matching usable() above).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
                                     "AlgoF32K8x12x1Impl"_hash,
                                     aarch64::matmul::sgemm_8x12, float, float,
                                     AlgoDataType::FLOAT32, DEFAULT);
  87. /* ===================== F32_MK4_8X12X1 algo ===================== */
  88. bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
  89. const KernSizeParam& kern_size_param) const {
  90. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  91. kern_size_param.B_type == kern_size_param.A_type &&
  92. kern_size_param.C_type == kern_size_param.A_type &&
  93. kern_size_param.A_type == dtype::Float32() &&
  94. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  95. !kern_size_param.trA && !kern_size_param.trB &&
  96. kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
  97. }
size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the MK4-blocked 8x12 fp32 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
                                                 C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
        const KernSizeParam&) const {
    //! Capture-less lambda kernel; shape data comes from the KernParam at
    //! execution time.
    auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type,
                                                     C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_mk4_8x12>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_8x12;
}
//! Register the MK4-blocked fp32 8x12 strategy with the im2col path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoF32MK4_8x12x1Impl"_hash,
                                     aarch64::matmul::sgemm_mk4_8x12, float,
                                     float, AlgoDataType::FLOAT32, MK4);
  146. /* ===================== F32K4X16X1 algo ===================== */
  147. bool MatrixMulImpl::AlgoF32K4x16x1::usable(
  148. const KernSizeParam& kern_size_param) const {
  149. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  150. kern_size_param.B_type == kern_size_param.A_type &&
  151. kern_size_param.C_type == kern_size_param.A_type &&
  152. kern_size_param.A_type == dtype::Float32() &&
  153. kern_size_param.format == param::MatrixMul::Format::DEFAULT;
  154. }
size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the 4x16 fp32 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
        const KernSizeParam&) const {
    //! Capture-less lambda kernel for the 4x16 fp32 strategy.
    auto f32_kern_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32K4x16x1::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type,
                                                 C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_4x16>(
                    M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_4x16;
}
  197. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern,
  198. "AlgoF32K4x16x1Impl"_hash,
  199. aarch64::matmul::sgemm_4x16, float, float,
  200. AlgoDataType::FLOAT32, MK4);
  201. /* ===================== F32MK4_4x16 algo ===================== */
  202. bool MatrixMulImpl::AlgoF32MK4_4x16::usable(
  203. const KernSizeParam& kern_size_param) const {
  204. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  205. kern_size_param.C_type == dtype::Float32() &&
  206. kern_size_param.B_type == dtype::Float32() &&
  207. kern_size_param.A_type == dtype::Float32() &&
  208. kern_size_param.format == param::MatrixMul::Format::MK4 &&
  209. !kern_size_param.trA && !kern_size_param.trB;
  210. }
size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Workspace reported by the wrapper for the no-pack 4x16 strategy; the
    //! `false` template argument selects the wrapper's no-pack variant.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::sgemm_nopack_4x16, false>(M, N, K, trA,
                                                                  trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern(
        const KernSizeParam&) const {
    //! Capture-less lambda kernel for the no-pack 4x16 fp32 strategy.
    auto f32_kern_mk4_4x16 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<float>(),
                       Bptr = kern_param.B<float>();
            auto Cptr = kern_param.C<float>();
            aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type);
            megdnn::matmul::GemmInterleaved<aarch64::matmul::sgemm_nopack_4x16,
                                            false>(M, N, K, trA, trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return f32_kern_mk4_4x16;
}
  253. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  254. /* ===================== F16 K8x24x1 algo ===================== */
namespace {
//! Shared fp16 kernel body: reads all sizes/strides from `kern_param` and
//! runs the interleaved 8x24 hgemm strategy.
void f16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("f16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_float16>(),
                   Bptr = kern_param.B<dt_float16>();
        auto Cptr = kern_param.C<dt_float16>();
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  275. bool MatrixMulImpl::AlgoF16K8x24x1::usable(
  276. const KernSizeParam& kern_size_param) const {
  277. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  278. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  279. kern_size_param.C_type == kern_size_param.A_type &&
  280. kern_size_param.B_type == kern_size_param.A_type &&
  281. kern_size_param.A_type == dtype::Float16();
  282. }
size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the 8x24 fp16 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::hgemm_8x24>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
        const KernSizeParam&) const {
    //! The fp16 kernel re-reads all shape data at run time, so the same free
    //! function is returned for every problem size.
    return f16_kern;
}
//! Register the 8x24 fp16 strategy with the im2col path.
//! NOTE(review): the hash literal reads "Alog..." (sic); left unchanged
//! because editing it would alter the midout/profiling identifier.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern,
                                     "AlogF16K8x24x1Impl"_hash,
                                     aarch64::matmul::hgemm_8x24, dt_float16,
                                     dt_float16, AlgoDataType::FLOAT16,
                                     DEFAULT);
  309. /* ===================== F16_MK8_8x8 algo ===================== */
  310. bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
  311. const KernSizeParam& kern_size_param) const {
  312. return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  313. kern_size_param.C_type == kern_size_param.A_type &&
  314. kern_size_param.B_type == kern_size_param.A_type &&
  315. kern_size_param.A_type == dtype::Float16() &&
  316. kern_size_param.format == param::MatrixMul::Format::MK8 &&
  317. !kern_size_param.trA && !kern_size_param.trB;
  318. }
size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Workspace reported by the wrapper for the no-pack fp16 8x8 strategy
    //! (`false` selects the wrapper's no-pack variant).
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_nopack_f16_8x8, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
        const KernSizeParam&) const {
    //! Capture-less lambda kernel for the no-pack fp16 8x8 strategy.
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_float16>(),
                       Bptr = kern_param.B<dt_float16>();
            auto Cptr = kern_param.C<dt_float16>();
            aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type,
                                                          C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_f16_8x8, false>(M, N, K, trA,
                                                                 trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
  363. #endif
  364. #if __ARM_FEATURE_DOTPROD
  365. /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */
namespace {
//! int8 -> int32 kernel body using the dot-product 8x12 strategy; only
//! compiled when __ARM_FEATURE_DOTPROD is available.
void int8x8x32_k8x12x4_dotprod_kern(
        const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable(
        const KernSizeParam& kern_size_param) const {
    //! Delegates dtype/format admission to the shared int8x8x32 helper.
    return can_be_treated_as_int8x8x32(kern_size_param);
}
size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the dot-product 8x12 int8 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    //! Shape-independent kernel: the free function above re-reads all
    //! parameters at execution time.
    return int8x8x32_k8x12x4_dotprod_kern;
}
//! Register the dot-product int8 8x12 strategy with the im2col path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_s8_8x12, int8_t,
                                     int32_t, AlgoDataType::QINT8X8X32,
                                     DEFAULT);
  419. /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
//! int8 -> int32 kernel body for the MK4_DOT blocked layout, using the
//! dot-product 8x12 strategy.
void int8x8x32_mk4_8x12x4_dotprod_kern(
        const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_8x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  443. bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
  444. const KernSizeParam& kern_size_param) const {
  445. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  446. (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
  447. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  448. (kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
  449. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  450. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  451. kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
  452. !kern_size_param.trA && !kern_size_param.trB;
  453. }
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the MK4_DOT int8 8x12 strategy.
    MIDOUT_BEGIN(
            megdnn_aarch64_matmul_kern,
            midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB,
                                                          strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
        const KernSizeParam&) const {
    //! Shape-independent kernel; parameters are read at execution time.
    return int8x8x32_mk4_8x12x4_dotprod_kern;
}
//! Register the MK4_DOT int8 8x12 strategy with the im2col path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
                                     int32_t, AlgoDataType::QINT8X8X32,
                                     MK4_DOT);
  484. #else
  485. /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
namespace {
//! int8 -> int32 kernel body for the MK4 blocked layout, 4x4x16 strategy
//! (non-dotprod fallback path).
void int8x8x32_mk4_4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_mk4_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  508. bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable(
  509. const KernSizeParam& param) const {
  510. return param.A_type.enumv() == param.B_type.enumv() &&
  511. (param.A_type.enumv() == DTypeEnum::Int8 ||
  512. param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
  513. (param.C_type.enumv() == DTypeEnum::Int32 ||
  514. param.C_type.enumv() == DTypeEnum::QuantizedS32) &&
  515. param.compute_mode == Param::ComputeMode::DEFAULT &&
  516. param.format == param::MatrixMul::Format::MK4 && !param.trA &&
  517. !param.trB && param.M % 4 == 0 && param.K % 4 == 0;
  518. }
bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred(
        const KernSizeParam& kern_size_param) const {
    //! Heuristic threshold: prefer this kernel once the reduction depth K
    //! exceeds 16 (one 4x4x16 inner-kernel pass).
    return kern_size_param.K > 16;
}
size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the MK4 int8 4x4 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_mk4_s8_4x4>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern(
        const KernSizeParam&) const {
    //! Shape-independent kernel; parameters are read at execution time.
    return int8x8x32_mk4_4x4x16_kern;
}
//! Register the MK4 int8 4x4 strategy with the im2col path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32MK4_4x4x16Impl"_hash,
                                     aarch64::matmul::gemm_mk4_s8_4x4, int8_t,
                                     int32_t, AlgoDataType::QINT8X8X32,
                                     MK4);
  552. /* ===================== Int8x8x32 K4x4x16 algo ===================== */
namespace {
//! int8 -> int32 kernel body, plain layout, 4x4x16 strategy (non-dotprod
//! fallback path).
void int8x8x32_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::usable(
        const KernSizeParam& kern_size_param) const {
    //! Delegates dtype/format admission to the shared int8x8x32 helper.
    return can_be_treated_as_int8x8x32(kern_size_param);
}
bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::preferred(
        const KernSizeParam& kern_size_param) const {
    //! Heuristic threshold: prefer this kernel once the reduction depth K
    //! exceeds 16 (one 4x4x16 inner-kernel pass).
    return kern_size_param.K > 16;
}
size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    //! Packing workspace required by the plain int8 4x4 strategy.
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_4x4>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout elides the instrumented region above
    return 0;
}
  599. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern(
  600. const KernSizeParam&) const {
  601. return int8x8x32_k4x4x16_kern;
  602. }
  603. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16,
  604. megdnn_aarch64_matmul_kern,
  605. "AlgoInt8x8x32K4x4x16Impl"_hash,
  606. aarch64::matmul::gemm_s8_4x4, int8_t,
  607. int32_t, AlgoDataType::QINT8X8X32,
  608. DEFAULT);
/* ===================== Int8x8x32 K8x8x8 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int32 GEMM using the 8x8x8 strategy.
void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x32_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// Same usability predicate as the K4x4x16 variant above.
bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x32(kern_size_param);
}

// Preferred for shallow K (complement of AlgoInt8x8x32K4x4x16::preferred).
bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.K <= 16;
}

// Workspace = packing-buffer size for the 8x8x8 strategy.
size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        // NOTE: `matmul::gemm_s8_8x8` resolves to aarch64::matmul here (this
        // file is inside namespace aarch64), same type as in the kern above.
        return megdnn::matmul::GemmInterleaved<matmul::gemm_s8_8x8>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_k8x8x8_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x32K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8_8x8, int8_t,
                                     int32_t, AlgoDataType::QINT8X8X32,
                                     DEFAULT);
// NOTE(review): closes a preprocessor conditional opened before this chunk —
// presumably the non-dotprod branch guard; confirm against the file head.
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int16 GEMM using the 8x8x8 strategy.
void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// int8x8x16, plain row-major layout, default compute mode only.
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}

// Preferred for shallow K; the K4x4x16 variant covers the rest.
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.K <= 16;
}

// Workspace = packing-buffer size for the s8x8x16 8x8 strategy.
size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_8x8>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_k8x8x8_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_8x8, int8_t,
                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 K4x4x16 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int16 GEMM using the 4x4x16 strategy.
void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_k4x4x16_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_4x4>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// int8x8x16, plain row-major layout, default compute mode only.
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}

// Unconditionally preferred whenever usable.
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}

// Workspace = packing-buffer size for the s8x8x16 4x4 strategy.
size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type,
                                                   C_type);
        return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_4x4>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_k4x4x16_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16K4x4x16Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
/* ===================== Int8x8x16 MK4 16x12x4 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int16 GEMM, MK4 layout, 16x12 strategy
// tuned for Cortex-A53 (see preferred() below).
void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                             strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// MK4 layout requires non-transposed operands and M/K multiples of 4.
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}

// A53-tuned kernel: preferred only on little cores (A53/A55); without
// cpuinfo we cannot identify the core, so never prefer it.
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred(
        const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    return false;
#else
    auto arch = cpuinfo_get_current_core()->uarch;
    bool little_core = arch == cpuinfo_uarch_cortex_a53 ||
                       arch == cpuinfo_uarch_cortex_a55;
    return little_core;
#endif
}

// Workspace = packing-buffer size for the a53 MK4 strategy.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                           strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_16x12x4_kern;
}

// DETAIL variant: inner/output C types are spelled separately (int16_t both).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
        "AlgoInt8x8x16MK4_16x12x4Impl"_hash,
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t,
        AlgoDataType::INT8X8X16, MK4);
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int16 GEMM, MK4 layout, 4x4 strategy tuned
// for Cortex-A72 (big cores; see preferred() below).
void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                           strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// Same MK4 constraints as the 16x12x4 variant above.
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}

// Complement of the A53 variant: preferred on everything EXCEPT little
// cores (A53/A55); without cpuinfo we cannot tell, so never prefer.
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred(
        const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    return false;
#else
    auto arch = cpuinfo_get_current_core()->uarch;
    bool little_core = arch == cpuinfo_uarch_cortex_a53 ||
                       arch == cpuinfo_uarch_cortex_a55;
    return !little_core;
#endif
}

// Workspace = packing-buffer size for the a72 MK4 strategy.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_4x4x8_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16MK4_4x4x8_Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
                                     int8_t, int16_t, AlgoDataType::INT8X8X16,
                                     MK4);
/* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace {
// Kernel wrapper: packed int16 -> int32 GEMM using the 12x8x1 strategy.
void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int16x16x32_k12x8x1_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int16>(),
                   Bptr = kern_param.B<dt_int16>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s16_12x8x1>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// Int16 x Int16 -> Int32, default layout and compute mode.
bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode ==
                   param::MatrixMul::ComputeMode::DEFAULT &&
           kern_size_param.A_type.enumv() == DTypeEnum::Int16 &&
           kern_size_param.C_type.enumv() == DTypeEnum::Int32;
}

// Unconditionally preferred whenever usable.
bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::preferred(
        const KernSizeParam& kern_size_param) const {
    MEGDNN_MARK_USED_VAR(kern_size_param);
    return true;
}

// Workspace = packing-buffer size for the s16 12x8x1 strategy.
size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type,
                                                  C_type);
        return megdnn::matmul::GemmInterleaved<matmul::gemm_s16_12x8x1>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern(
        const KernSizeParam&) const {
    return int16x16x32_k12x8x1_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt16x16x32K12x8x1Impl"_hash,
                                     aarch64::matmul::gemm_s16_12x8x1, int16_t,
                                     int32_t, AlgoDataType::INT16X16X32,
                                     DEFAULT);
/* ===================== Int16x16x32MK8_8x8 algo ===================== */
// Int16 x Int16 -> Int32 in MK8 layout; operands must not be transposed.
// Exact dtype equality (not just enumv) is required here.
bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.C_type == dtype::Int32() &&
           kern_size_param.B_type == dtype::Int16() &&
           kern_size_param.A_type == dtype::Int16() &&
           kern_size_param.format == param::MatrixMul::Format::MK8 &&
           !kern_size_param.trA && !kern_size_param.trB;
}

// Workspace for the no-pack strategy; the strategy ctor takes only dtypes
// and GemmInterleaved is instantiated with with_pack=false.
size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       aarch64::matmul::gemm_nopack_s16_8x8, false>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

// Returns a capture-less lambda (convertible to kern_t) running the
// no-pack MK8 GEMM.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
        const KernSizeParam&) const {
    auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) {
        MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                     midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) {
            auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
            auto trA = kern_param.trA, trB = kern_param.trB;
            auto LDA = kern_param.LDA, LDB = kern_param.LDB,
                 LDC = kern_param.LDC;
            auto A_type = kern_param.A_type, B_type = kern_param.B_type,
                 C_type = kern_param.C_type;
            const auto Aptr = kern_param.A<dt_int16>(),
                       Bptr = kern_param.B<dt_int16>();
            auto Cptr = kern_param.C<dt_int32>();
            aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type,
                                                          C_type);
            megdnn::matmul::GemmInterleaved<
                    aarch64::matmul::gemm_nopack_s16_8x8, false>(M, N, K, trA,
                                                                 trB, strategy)
                    .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                             kern_param.workspace_ptr);
        }
        MIDOUT_END();
    };
    return kern_mk8_8x8;
}
#if __ARM_FEATURE_DOTPROD
/* ==================== Quint8 K8x8x4 Dotprod algo ==================== */
namespace {
// Kernel wrapper: packed uint8 -> int32 GEMM using the dot-product u8 8x8
// strategy (built only when the dotprod extension is available).
void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_uint8>(),
                   Bptr = kern_param.B<dt_uint8>();
        auto Cptr = kern_param.C<dt_int32>();
        aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// Quantized8Asymm x Quantized8Asymm -> QuantizedS32, default layout.
bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
}

// Workspace = packing-buffer size for the u8 8x8 dotprod strategy.
size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
        const KernSizeParam&) const {
    return quint8_k8x8x4_dotprod_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoQuint8K8x8x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_u8_8x8, uint8_t,
                                     int32_t, AlgoDataType::QUINT8X8X32,
                                     DEFAULT);
/* ===================== Quint8 Gemv DotProd algo ===================== */
namespace {
// GEMV fast path for N == 1: calls the hand-written gemv kernel directly,
// passing both asymmetric zero points; no packing workspace is needed
// (note this algo defines no get_workspace, so the base default applies).
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("quint8_gemv_dotprod_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_uint8>(),
                   Bptr = kern_param.B<dt_uint8>();
        auto Cptr = kern_param.C<dt_int32>();
        auto A_type = kern_param.A_type, B_type = kern_param.B_type;
        aarch64::matmul::gemv_like_quint8(
                Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC,
                A_type.param<dtype::Quantized8Asymm>().zero_point,
                B_type.param<dtype::Quantized8Asymm>().zero_point);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// Only for the vector case: N == 1 with contiguous B (LDB == 1), no
// transposition, quantized asymmetric inputs.
bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable(
        const KernSizeParam& kern_size_param) const {
    return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
           kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.N == 1 && kern_size_param.LDB == 1;
}

// Always preferred for the shapes it accepts (restates the N==1/LDB==1
// condition from usable()).
bool MatrixMulImpl::AlgoQuint8GemvDotProd::preferred(
        const KernSizeParam& kern_size_param) const {
    auto N = kern_size_param.N, LDB = kern_size_param.LDB;
    return (N == 1 && LDB == 1);
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern(
        const KernSizeParam&) const {
    return quint8_gemv_dotprod_kern;
}
  1146. #else
  1147. /* ===================== Quint8 K8x8x8 algo ===================== */
  1148. namespace {
  1149. void quint8_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
  1150. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  1151. midout_iv("quint8_gemv_dotprod_kern"_hash)) {
  1152. auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
  1153. auto trA = kern_param.trA, trB = kern_param.trB;
  1154. auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
  1155. auto A_type = kern_param.A_type, B_type = kern_param.B_type,
  1156. C_type = kern_param.C_type;
  1157. const auto Aptr = kern_param.A<dt_uint8>(),
  1158. Bptr = kern_param.B<dt_uint8>();
  1159. auto Cptr = kern_param.C<dt_int32>();
  1160. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  1161. megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  1162. M, N, K, trA, trB, strategy)
  1163. .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
  1164. kern_param.workspace_ptr);
  1165. }
  1166. MIDOUT_END();
  1167. }
  1168. } // anonymous namespace
  1169. bool MatrixMulImpl::AlgoQuint8K8x8x8::usable(
  1170. const KernSizeParam& kern_size_param) const {
  1171. return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
  1172. kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
  1173. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
  1174. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  1175. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT;
  1176. }
  1177. size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace(
  1178. const KernSizeParam& kern_size_param) const {
  1179. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  1180. midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) {
  1181. auto M = kern_size_param.M, N = kern_size_param.N,
  1182. K = kern_size_param.K;
  1183. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  1184. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  1185. C_type = kern_size_param.C_type;
  1186. aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
  1187. return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
  1188. M, N, K, trA, trB, strategy)
  1189. .get_workspace_size();
  1190. }
  1191. MIDOUT_END();
  1192. return 0;
  1193. }
  1194. MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern(
  1195. const KernSizeParam&) const {
  1196. return quint8_k8x8x8_kern;
  1197. }
  1198. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
  1199. megdnn_aarch64_matmul_kern,
  1200. "AlgoQuint8K8x8x8Impl"_hash,
  1201. aarch64::matmul::gemm_u8_8x8, uint8_t,
  1202. int32_t, AlgoDataType::QUINT8X8X32,
  1203. DEFAULT);
  1204. #endif
/* ===================== Int8x8x16 MK4 K8x8x8 algo ===================== */
namespace {
// Kernel wrapper: packed int8 -> int16 GEMM, MK4 layout, 8x8x8 strategy.
void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type,
                                                         B_type, C_type);
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB,
                                                         strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace

// MK4 layout: non-transposed operands, M and K multiples of 4.
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x16(kern_size_param) &&
           kern_size_param.format == param::MatrixMul::Format::MK4 &&
           kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}

// Unconditionally preferred whenever usable.
bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(
        const KernSizeParam&) const {
    return true;
}

// Workspace = packing-buffer size for the MK4 8x8x8 strategy.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type,
                                                         B_type, C_type);
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB,
                                                       strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    // Reached only when midout compiles the region out.
    return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_8x8x8_kern;
}

// Register pack/matmul entry points for the im2col conv_bias path.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt8x8x16MK4_K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
                                     int8_t, int16_t, AlgoDataType::INT8X8X16,
                                     MK4);
/* ===================== Int4x4x16 K8x8x8 algo ===================== */
namespace {
// Kernel entry point for the int4x4x16 GEMM: runs the s4 8x8x8 strategy
// through GemmInterleaved (pack A/B into the workspace, then compute).
// NOTE(review): the midout tag reads "k8x8x8" while the function is named
// "k8x8x16" -- the tag only labels profiling instrumentation, so behavior is
// unaffected; confirm which spelling was intended before renaming anything.
void int4x4x16_k8x8x16_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int4x4x16_k8x8x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        // 4-bit operands are addressed through int8 pointers; presumably the
        // strategy unpacks nibble pairs internally -- TODO confirm.
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();
        aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type,
                                                        C_type);
        megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s4x4x16_s4_8x8x8>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
}  // anonymous namespace
  1293. bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::usable(
  1294. const KernSizeParam& kern_size_param) const {
  1295. return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
  1296. kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS4 &&
  1297. kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16 &&
  1298. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  1299. kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
  1300. (kern_size_param.K & 1) == 0 && (kern_size_param.N & 1) == 0;
  1301. }
  1302. bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred(
  1303. const KernSizeParam& kern_size_param) const {
  1304. MEGDNN_MARK_USED_VAR(kern_size_param);
  1305. return true;
  1306. }
  1307. size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace(
  1308. const KernSizeParam& kern_size_param) const {
  1309. MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
  1310. midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) {
  1311. auto M = kern_size_param.M, N = kern_size_param.N,
  1312. K = kern_size_param.K;
  1313. auto trA = kern_size_param.trA, trB = kern_size_param.trB;
  1314. auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
  1315. C_type = kern_size_param.C_type;
  1316. aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type,
  1317. C_type);
  1318. return megdnn::matmul::GemmInterleaved<matmul::gemm_s4x4x16_s4_8x8x8>(
  1319. M, N, K, trA, trB, strategy)
  1320. .get_workspace_size();
  1321. }
  1322. MIDOUT_END();
  1323. }
  1324. MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern(
  1325. const KernSizeParam&) const {
  1326. return int4x4x16_k8x8x16_kern;
  1327. }
// Register the s4x4x16 strategy for the im2col convolution path (int8-typed
// storage carrying 4-bit data, int16 output, DEFAULT layout).
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8,
                                     megdnn_aarch64_matmul_kern,
                                     "AlgoInt4x4x16K8x8x8Impl"_hash,
                                     aarch64::matmul::gemm_s4x4x16_s4_8x8x8,
                                     int8_t, int16_t, AlgoDataType::INT4X4X16,
                                     DEFAULT);
  1334. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台