You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

kernel_8x8x8.h 50 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375
  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #if !(__ARM_FEATURE_DOTPROD)
  12. #include "src/aarch64/matrix_mul/asm/common.h"
  13. #include "src/arm_common/simd_macro/marm_neon.h"
  14. namespace megdnn {
  15. namespace aarch64 {
  16. namespace matmul_8x8x8 {
  17. /**
  18. * Overview of register layout:
  19. *
  20. * A 8x8x8 cell of Rhs is stored in 8bit in q26-q27
  21. * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7
  22. * A 8x8 block of accumulators is stored in 32bit in q8-q23
  23. *
  24. * +--------+--------+
  25. * |v26[0-8]|v27[0-8]|
  26. * Rhs +--------+--------+
  27. * Lhs | | |
  28. *
  29. * +--------+ - - - - +-----------------+
  30. * |v0[0-8]| | v8[0-4]| v9[0-4]|
  31. * |v1[0-8]| |v10[0-4]|v11[0-4]|
  32. * |v2[0-8]| |v12[0-4]|v13[0-4]|
  33. * |v3[0-8]| |v14[0-4]|v15[0-4]|
  34. * |v4[0-8]| |v16[0-4]|v17[0-4]|
  35. * |v5[0-8]| |v18[0-4]|v19[0-4]|
  36. * |v6[0-8]| |v20[0-4]|v21[0-4]|
  37. * |v7[0-8]| |v22[0-4]|v23[0-4]|
  38. * +--------+ - - - - +-----------------+
  39. *
  40. * Accumulator
  41. */
  42. static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
  43. int32_t* output, int LDC, bool is_first_k) {
  44. K /= 8;
  45. const int8_t* a_ptr = packA;
  46. const int8_t* b_ptr = packB;
  47. LDC = LDC * sizeof(int32_t);
  48. asm volatile(
  49. // load accumulator C
  50. "add x1, %[output], %x[LDC]\n"
  51. "add x2, x1, %x[LDC]\n"
  52. "add x3, x2, %x[LDC]\n"
  53. "add x4, x3, %x[LDC]\n"
  54. "add x5, x4, %x[LDC]\n"
  55. "add x6, x5, %x[LDC]\n"
  56. "add x7, x6, %x[LDC]\n"
  57. "cmp %w[is_first_k], #1\n"
  58. "beq 1f\n"
  59. "ldp q8, q9, [%[output]]\n"
  60. "ldp q10, q11, [x1]\n"
  61. "ldp q12, q13, [x2]\n"
  62. "ldp q14, q15, [x3]\n"
  63. "ldp q16, q17, [x4]\n"
  64. "ldp q18, q19, [x5]\n"
  65. "ldp q20, q21, [x6]\n"
  66. "ldp q22, q23, [x7]\n"
  67. "b 2f\n"
  68. "1:\n"
  69. "eor v8.16b, v8.16b, v8.16b\n"
  70. "eor v9.16b, v9.16b, v9.16b\n"
  71. "eor v10.16b, v10.16b, v10.16b\n"
  72. "eor v11.16b, v11.16b, v11.16b\n"
  73. "eor v12.16b, v12.16b, v12.16b\n"
  74. "eor v13.16b, v13.16b, v13.16b\n"
  75. "eor v14.16b, v14.16b, v14.16b\n"
  76. "eor v15.16b, v15.16b, v15.16b\n"
  77. "eor v16.16b, v16.16b, v16.16b\n"
  78. "eor v17.16b, v17.16b, v17.16b\n"
  79. "eor v18.16b, v18.16b, v18.16b\n"
  80. "eor v19.16b, v19.16b, v19.16b\n"
  81. "eor v20.16b, v20.16b, v20.16b\n"
  82. "eor v21.16b, v21.16b, v21.16b\n"
  83. "eor v22.16b, v22.16b, v22.16b\n"
  84. "eor v23.16b, v23.16b, v23.16b\n"
  85. "2: \n"
  86. "ld1 {v26.8b}, [%[b_ptr]], 8\n"
  87. "ld1 {v0.8b}, [%[a_ptr]], 8\n"
  88. "ld1 {v1.8b}, [%[a_ptr]], 8\n"
  89. "ld1 {v2.8b}, [%[a_ptr]], 8\n"
  90. "ld1 {v3.8b}, [%[a_ptr]], 8\n"
  91. "ld1 {v4.8b}, [%[a_ptr]], 8\n"
  92. "ld1 {v5.8b}, [%[a_ptr]], 8\n"
  93. "ld1 {v6.8b}, [%[a_ptr]], 8\n"
  94. "ld1 {v7.8b}, [%[a_ptr]], 8\n"
  95. "sshll v26.8h, v26.8b, #0\n"
  96. "sshll v0.8h, v0.8b, #0\n"
  97. "sshll v1.8h, v1.8b, #0\n"
  98. "sshll v2.8h, v2.8b, #0\n"
  99. "sshll v3.8h, v3.8b, #0\n"
  100. "sshll v4.8h, v4.8b, #0\n"
  101. "sshll v5.8h, v5.8b, #0\n"
  102. "sshll v6.8h, v6.8b, #0\n"
  103. "sshll v7.8h, v7.8b, #0\n"
  104. "ld1 {v27.8b}, [%[b_ptr]], 8\n"
  105. "smlal v8.4s, v26.4h, v0.h[0]\n"
  106. "smlal v10.4s, v26.4h, v1.h[0]\n"
  107. "smlal v12.4s, v26.4h, v2.h[0]\n"
  108. "smlal v14.4s, v26.4h, v3.h[0]\n"
  109. "smlal v16.4s, v26.4h, v4.h[0]\n"
  110. "smlal v18.4s, v26.4h, v5.h[0]\n"
  111. "smlal v20.4s, v26.4h, v6.h[0]\n"
  112. "smlal v22.4s, v26.4h, v7.h[0]\n"
  113. "sshll v27.8h, v27.8b, #0\n"
  114. "smlal2 v9.4s, v26.8h, v0.h[0]\n"
  115. "smlal2 v11.4s, v26.8h, v1.h[0]\n"
  116. "smlal2 v13.4s, v26.8h, v2.h[0]\n"
  117. "smlal2 v15.4s, v26.8h, v3.h[0]\n"
  118. "smlal2 v17.4s, v26.8h, v4.h[0]\n"
  119. "smlal2 v19.4s, v26.8h, v5.h[0]\n"
  120. "smlal2 v21.4s, v26.8h, v6.h[0]\n"
  121. "smlal2 v23.4s, v26.8h, v7.h[0]\n"
  122. "ld1 {v26.8b}, [%[b_ptr]], 8\n"
  123. "smlal v8.4s, v27.4h, v0.h[1]\n"
  124. "smlal v10.4s, v27.4h, v1.h[1]\n"
  125. "smlal v12.4s, v27.4h, v2.h[1]\n"
  126. "smlal v14.4s, v27.4h, v3.h[1]\n"
  127. "smlal v16.4s, v27.4h, v4.h[1]\n"
  128. "smlal v18.4s, v27.4h, v5.h[1]\n"
  129. "smlal v20.4s, v27.4h, v6.h[1]\n"
  130. "smlal v22.4s, v27.4h, v7.h[1]\n"
  131. "sshll v26.8h, v26.8b, #0\n"
  132. "smlal2 v9.4s, v27.8h, v0.h[1]\n"
  133. "smlal2 v11.4s, v27.8h, v1.h[1]\n"
  134. "smlal2 v13.4s, v27.8h, v2.h[1]\n"
  135. "smlal2 v15.4s, v27.8h, v3.h[1]\n"
  136. "smlal2 v17.4s, v27.8h, v4.h[1]\n"
  137. "smlal2 v19.4s, v27.8h, v5.h[1]\n"
  138. "smlal2 v21.4s, v27.8h, v6.h[1]\n"
  139. "smlal2 v23.4s, v27.8h, v7.h[1]\n"
  140. "ld1 {v27.8b}, [%[b_ptr]], 8\n"
  141. "smlal v8.4s, v26.4h, v0.h[2]\n"
  142. "smlal v10.4s, v26.4h, v1.h[2]\n"
  143. "smlal v12.4s, v26.4h, v2.h[2]\n"
  144. "smlal v14.4s, v26.4h, v3.h[2]\n"
  145. "smlal v16.4s, v26.4h, v4.h[2]\n"
  146. "smlal v18.4s, v26.4h, v5.h[2]\n"
  147. "smlal v20.4s, v26.4h, v6.h[2]\n"
  148. "smlal v22.4s, v26.4h, v7.h[2]\n"
  149. "sshll v27.8h, v27.8b, #0\n"
  150. "smlal2 v9.4s, v26.8h, v0.h[2]\n"
  151. "smlal2 v11.4s, v26.8h, v1.h[2]\n"
  152. "smlal2 v13.4s, v26.8h, v2.h[2]\n"
  153. "smlal2 v15.4s, v26.8h, v3.h[2]\n"
  154. "smlal2 v17.4s, v26.8h, v4.h[2]\n"
  155. "smlal2 v19.4s, v26.8h, v5.h[2]\n"
  156. "smlal2 v21.4s, v26.8h, v6.h[2]\n"
  157. "smlal2 v23.4s, v26.8h, v7.h[2]\n"
  158. "ld1 {v26.8b}, [%[b_ptr]], 8\n"
  159. "smlal v8.4s, v27.4h, v0.h[3]\n"
  160. "smlal v10.4s, v27.4h, v1.h[3]\n"
  161. "smlal v12.4s, v27.4h, v2.h[3]\n"
  162. "smlal v14.4s, v27.4h, v3.h[3]\n"
  163. "smlal v16.4s, v27.4h, v4.h[3]\n"
  164. "smlal v18.4s, v27.4h, v5.h[3]\n"
  165. "smlal v20.4s, v27.4h, v6.h[3]\n"
  166. "smlal v22.4s, v27.4h, v7.h[3]\n"
  167. "sshll v26.8h, v26.8b, #0\n"
  168. "smlal2 v9.4s, v27.8h, v0.h[3]\n"
  169. "smlal2 v11.4s, v27.8h, v1.h[3]\n"
  170. "smlal2 v13.4s, v27.8h, v2.h[3]\n"
  171. "smlal2 v15.4s, v27.8h, v3.h[3]\n"
  172. "smlal2 v17.4s, v27.8h, v4.h[3]\n"
  173. "smlal2 v19.4s, v27.8h, v5.h[3]\n"
  174. "smlal2 v21.4s, v27.8h, v6.h[3]\n"
  175. "smlal2 v23.4s, v27.8h, v7.h[3]\n"
  176. "ld1 {v27.8b}, [%[b_ptr]], 8\n"
  177. "smlal v8.4s, v26.4h, v0.h[4]\n"
  178. "smlal v10.4s, v26.4h, v1.h[4]\n"
  179. "smlal v12.4s, v26.4h, v2.h[4]\n"
  180. "smlal v14.4s, v26.4h, v3.h[4]\n"
  181. "smlal v16.4s, v26.4h, v4.h[4]\n"
  182. "smlal v18.4s, v26.4h, v5.h[4]\n"
  183. "smlal v20.4s, v26.4h, v6.h[4]\n"
  184. "smlal v22.4s, v26.4h, v7.h[4]\n"
  185. "sshll v27.8h, v27.8b, #0\n"
  186. "smlal2 v9.4s, v26.8h, v0.h[4]\n"
  187. "smlal2 v11.4s, v26.8h, v1.h[4]\n"
  188. "smlal2 v13.4s, v26.8h, v2.h[4]\n"
  189. "smlal2 v15.4s, v26.8h, v3.h[4]\n"
  190. "smlal2 v17.4s, v26.8h, v4.h[4]\n"
  191. "smlal2 v19.4s, v26.8h, v5.h[4]\n"
  192. "smlal2 v21.4s, v26.8h, v6.h[4]\n"
  193. "smlal2 v23.4s, v26.8h, v7.h[4]\n"
  194. "ld1 {v26.8b}, [%[b_ptr]], 8\n"
  195. "smlal v8.4s, v27.4h, v0.h[5]\n"
  196. "smlal v10.4s, v27.4h, v1.h[5]\n"
  197. "smlal v12.4s, v27.4h, v2.h[5]\n"
  198. "smlal v14.4s, v27.4h, v3.h[5]\n"
  199. "smlal v16.4s, v27.4h, v4.h[5]\n"
  200. "smlal v18.4s, v27.4h, v5.h[5]\n"
  201. "smlal v20.4s, v27.4h, v6.h[5]\n"
  202. "smlal v22.4s, v27.4h, v7.h[5]\n"
  203. "sshll v26.8h, v26.8b, #0\n"
  204. "smlal2 v9.4s, v27.8h, v0.h[5]\n"
  205. "smlal2 v11.4s, v27.8h, v1.h[5]\n"
  206. "smlal2 v13.4s, v27.8h, v2.h[5]\n"
  207. "smlal2 v15.4s, v27.8h, v3.h[5]\n"
  208. "smlal2 v17.4s, v27.8h, v4.h[5]\n"
  209. "smlal2 v19.4s, v27.8h, v5.h[5]\n"
  210. "smlal2 v21.4s, v27.8h, v6.h[5]\n"
  211. "smlal2 v23.4s, v27.8h, v7.h[5]\n"
  212. "ld1 {v27.8b}, [%[b_ptr]], 8\n"
  213. "smlal v8.4s, v26.4h, v0.h[6]\n"
  214. "smlal v10.4s, v26.4h, v1.h[6]\n"
  215. "smlal v12.4s, v26.4h, v2.h[6]\n"
  216. "smlal v14.4s, v26.4h, v3.h[6]\n"
  217. "smlal v16.4s, v26.4h, v4.h[6]\n"
  218. "smlal v18.4s, v26.4h, v5.h[6]\n"
  219. "smlal v20.4s, v26.4h, v6.h[6]\n"
  220. "smlal v22.4s, v26.4h, v7.h[6]\n"
  221. "sshll v27.8h, v27.8b, #0\n"
  222. "smlal2 v9.4s, v26.8h, v0.h[6]\n"
  223. "smlal2 v11.4s, v26.8h, v1.h[6]\n"
  224. "smlal2 v13.4s, v26.8h, v2.h[6]\n"
  225. "smlal2 v15.4s, v26.8h, v3.h[6]\n"
  226. "smlal2 v17.4s, v26.8h, v4.h[6]\n"
  227. "smlal2 v19.4s, v26.8h, v5.h[6]\n"
  228. "smlal2 v21.4s, v26.8h, v6.h[6]\n"
  229. "smlal2 v23.4s, v26.8h, v7.h[6]\n"
  230. "smlal v8.4s, v27.4h, v0.h[7]\n"
  231. "smlal v10.4s, v27.4h, v1.h[7]\n"
  232. "smlal v12.4s, v27.4h, v2.h[7]\n"
  233. "smlal v14.4s, v27.4h, v3.h[7]\n"
  234. "smlal v16.4s, v27.4h, v4.h[7]\n"
  235. "smlal v18.4s, v27.4h, v5.h[7]\n"
  236. "smlal v20.4s, v27.4h, v6.h[7]\n"
  237. "smlal v22.4s, v27.4h, v7.h[7]\n"
  238. "smlal2 v9.4s, v27.8h, v0.h[7]\n"
  239. "smlal2 v11.4s, v27.8h, v1.h[7]\n"
  240. "smlal2 v13.4s, v27.8h, v2.h[7]\n"
  241. "smlal2 v15.4s, v27.8h, v3.h[7]\n"
  242. "smlal2 v17.4s, v27.8h, v4.h[7]\n"
  243. "smlal2 v19.4s, v27.8h, v5.h[7]\n"
  244. "smlal2 v21.4s, v27.8h, v6.h[7]\n"
  245. "smlal2 v23.4s, v27.8h, v7.h[7]\n"
  246. "subs %w[K], %w[K], #1\n"
  247. "cbnz %w[K], 2b\n"
  248. "3:\n"
  249. "stp q8, q9, [%[output]]\n"
  250. "stp q10, q11, [x1]\n"
  251. "stp q12, q13, [x2]\n"
  252. "stp q14, q15, [x3]\n"
  253. "stp q16, q17, [x4]\n"
  254. "stp q18, q19, [x5]\n"
  255. "stp q20, q21, [x6]\n"
  256. "stp q22, q23, [x7]\n"
  257. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
  258. [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
  259. [output] "+r"(output)
  260. :
  261. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  262. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
  263. "v20", "v21", "v22", "v23", "v26", "v27", "x1",
  264. "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
  265. }
  266. /**
  267. * Overview of register layout:
  268. *
  269. * A 8x4x8 cell of Rhs is stored in 8bit in q16-q17
  270. * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7
  271. * A 8x4 block of accumulators is stored in 32bit in q8-q15
  272. *
  273. * +--------+
  274. * |v16[0-4]|
  275. * Rhs +--------+
  276. * |v17[0-4]|
  277. * Lhs +--------+
  278. *
  279. * +--------+ - - - - +--------+
  280. * |v0[0-8]| | v8[0-4]|
  281. * |v1[0-8]| | v9[0-4]|
  282. * |v2[0-8]| |v10[0-4]|
  283. * |v3[0-8]| |v11[0-4]|
  284. * |v4[0-8]| |v12[0-4]|
  285. * |v5[0-8]| |v13[0-4]|
  286. * |v6[0-8]| |v14[0-4]|
  287. * |v7[0-8]| |v15[0-4]|
  288. * +--------+ - - - - +--------+
  289. *
  290. * Accumulator
  291. */
  292. static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
  293. int32_t* output, int LDC, bool is_first_k,
  294. size_t n_remain) {
  295. K /= 8;
  296. const int8_t* a_ptr = packA;
  297. const int8_t* b_ptr = packB;
  298. LDC = LDC * sizeof(int32_t);
  299. int32_t* outptr0 = output;
  300. int32_t* outptr1;
  301. int32_t* outptr2;
  302. int32_t* outptr3;
  303. int32_t* outptr4;
  304. int32_t* outptr5;
  305. int32_t* outptr6;
  306. int32_t* outptr7;
  307. size_t x0 = 0;
  308. // clang-format off
  309. #define LOAD_LINE(reg_index, n) \
  310. "mov %[x0], %[outptr" n "]\n" \
  311. "cmp %w[n_remain], #4\n" \
  312. "blt 100" n "f\n" \
  313. "ldr q" reg_index ", [%[x0]] \n" \
  314. "b 101" n "f\n" \
  315. "100" n ":\n" \
  316. "cmp %w[n_remain], #0\n" \
  317. "beq 101" n "f\n" \
  318. "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  319. "cmp %w[n_remain], #1\n" \
  320. "beq 101" n "f\n" \
  321. "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  322. "cmp %w[n_remain], #2\n" \
  323. "beq 101" n "f\n" \
  324. "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  325. "101" n ":\n"
  326. #define LOAD_C \
  327. LOAD_LINE("8", "0") \
  328. LOAD_LINE("9", "1") \
  329. LOAD_LINE("10", "2") \
  330. LOAD_LINE("11", "3") \
  331. LOAD_LINE("12", "4") \
  332. LOAD_LINE("13", "5") \
  333. LOAD_LINE("14", "6") \
  334. LOAD_LINE("15", "7")
  335. #define STORE_LINE(reg_index, n) \
  336. "mov %[x0], %[outptr" n "]\n" \
  337. "cmp %w[n_remain], #4\n" \
  338. "blt 102" n "f\n" \
  339. "str q" reg_index ", [%[x0]]\n" \
  340. "b 103" n "f\n" \
  341. "102" n ":\n" \
  342. "cmp %w[n_remain], #0\n" \
  343. "beq 103" n "f\n" \
  344. "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  345. "cmp %w[n_remain], #1\n" \
  346. "beq 103" n "f\n" \
  347. "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  348. "cmp %w[n_remain], #2\n" \
  349. "beq 103" n "f\n" \
  350. "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  351. "103" n ":\n"
  352. #define STORE_C \
  353. STORE_LINE("8", "0") \
  354. STORE_LINE("9", "1") \
  355. STORE_LINE("10", "2") \
  356. STORE_LINE("11", "3") \
  357. STORE_LINE("12", "4") \
  358. STORE_LINE("13", "5") \
  359. STORE_LINE("14", "6") \
  360. STORE_LINE("15", "7")
  361. // clang-format on
  362. asm volatile(
  363. // load accumulator C
  364. "add %[outptr1], %[outptr0], %x[LDC]\n"
  365. "add %[outptr2], %[outptr1], %x[LDC]\n"
  366. "add %[outptr3], %[outptr2], %x[LDC]\n"
  367. "add %[outptr4], %[outptr3], %x[LDC]\n"
  368. "add %[outptr5], %[outptr4], %x[LDC]\n"
  369. "add %[outptr6], %[outptr5], %x[LDC]\n"
  370. "add %[outptr7], %[outptr6], %x[LDC]\n"
  371. "cmp %w[is_first_k], #1\n"
  372. "beq 1f\n" LOAD_C
  373. "b 2f\n"
  374. "1:\n"
  375. "eor v8.16b, v8.16b, v8.16b\n"
  376. "eor v9.16b, v9.16b, v9.16b\n"
  377. "eor v10.16b, v10.16b, v10.16b\n"
  378. "eor v11.16b, v11.16b, v11.16b\n"
  379. "eor v12.16b, v12.16b, v12.16b\n"
  380. "eor v13.16b, v13.16b, v13.16b\n"
  381. "eor v14.16b, v14.16b, v14.16b\n"
  382. "eor v15.16b, v15.16b, v15.16b\n"
  383. "2: \n"
  384. "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
  385. "ld1 {v0.8b}, [%[a_ptr]], 8\n"
  386. "ld1 {v1.8b}, [%[a_ptr]], 8\n"
  387. "ld1 {v2.8b}, [%[a_ptr]], 8\n"
  388. "ld1 {v3.8b}, [%[a_ptr]], 8\n"
  389. "ld1 {v4.8b}, [%[a_ptr]], 8\n"
  390. "ld1 {v5.8b}, [%[a_ptr]], 8\n"
  391. "ld1 {v6.8b}, [%[a_ptr]], 8\n"
  392. "ld1 {v7.8b}, [%[a_ptr]], 8\n"
  393. "sshll v16.8h, v16.8b, #0\n"
  394. "sshll v0.8h, v0.8b, #0\n"
  395. "sshll v1.8h, v1.8b, #0\n"
  396. "sshll v2.8h, v2.8b, #0\n"
  397. "sshll v3.8h, v3.8b, #0\n"
  398. "sshll v4.8h, v4.8b, #0\n"
  399. "sshll v5.8h, v5.8b, #0\n"
  400. "sshll v6.8h, v6.8b, #0\n"
  401. "sshll v7.8h, v7.8b, #0\n"
  402. "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
  403. "smlal v8.4s, v16.4h, v0.h[0]\n"
  404. "smlal v9.4s, v16.4h, v1.h[0]\n"
  405. "smlal v10.4s, v16.4h, v2.h[0]\n"
  406. "smlal v11.4s, v16.4h, v3.h[0]\n"
  407. "sshll v17.8h, v17.8b, #0\n"
  408. "smlal v12.4s, v16.4h, v4.h[0]\n"
  409. "smlal v13.4s, v16.4h, v5.h[0]\n"
  410. "smlal v14.4s, v16.4h, v6.h[0]\n"
  411. "smlal v15.4s, v16.4h, v7.h[0]\n"
  412. "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
  413. "smlal v8.4s, v17.4h, v0.h[1]\n"
  414. "smlal v9.4s, v17.4h, v1.h[1]\n"
  415. "smlal v10.4s, v17.4h, v2.h[1]\n"
  416. "smlal v11.4s, v17.4h, v3.h[1]\n"
  417. "sshll v16.8h, v16.8b, #0\n"
  418. "smlal v12.4s, v17.4h, v4.h[1]\n"
  419. "smlal v13.4s, v17.4h, v5.h[1]\n"
  420. "smlal v14.4s, v17.4h, v6.h[1]\n"
  421. "smlal v15.4s, v17.4h, v7.h[1]\n"
  422. "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
  423. "smlal v8.4s, v16.4h, v0.h[2]\n"
  424. "smlal v9.4s, v16.4h, v1.h[2]\n"
  425. "smlal v10.4s, v16.4h, v2.h[2]\n"
  426. "smlal v11.4s, v16.4h, v3.h[2]\n"
  427. "sshll v17.8h, v17.8b, #0\n"
  428. "smlal v12.4s, v16.4h, v4.h[2]\n"
  429. "smlal v13.4s, v16.4h, v5.h[2]\n"
  430. "smlal v14.4s, v16.4h, v6.h[2]\n"
  431. "smlal v15.4s, v16.4h, v7.h[2]\n"
  432. "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
  433. "smlal v8.4s, v17.4h, v0.h[3]\n"
  434. "smlal v9.4s, v17.4h, v1.h[3]\n"
  435. "smlal v10.4s, v17.4h, v2.h[3]\n"
  436. "smlal v11.4s, v17.4h, v3.h[3]\n"
  437. "sshll v16.8h, v16.8b, #0\n"
  438. "smlal v12.4s, v17.4h, v4.h[3]\n"
  439. "smlal v13.4s, v17.4h, v5.h[3]\n"
  440. "smlal v14.4s, v17.4h, v6.h[3]\n"
  441. "smlal v15.4s, v17.4h, v7.h[3]\n"
  442. "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
  443. "smlal v8.4s, v16.4h, v0.h[4]\n"
  444. "smlal v9.4s, v16.4h, v1.h[4]\n"
  445. "smlal v10.4s, v16.4h, v2.h[4]\n"
  446. "smlal v11.4s, v16.4h, v3.h[4]\n"
  447. "sshll v17.8h, v17.8b, #0\n"
  448. "smlal v12.4s, v16.4h, v4.h[4]\n"
  449. "smlal v13.4s, v16.4h, v5.h[4]\n"
  450. "smlal v14.4s, v16.4h, v6.h[4]\n"
  451. "smlal v15.4s, v16.4h, v7.h[4]\n"
  452. "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
  453. "smlal v8.4s, v17.4h, v0.h[5]\n"
  454. "smlal v9.4s, v17.4h, v1.h[5]\n"
  455. "smlal v10.4s, v17.4h, v2.h[5]\n"
  456. "smlal v11.4s, v17.4h, v3.h[5]\n"
  457. "sshll v16.8h, v16.8b, #0\n"
  458. "smlal v12.4s, v17.4h, v4.h[5]\n"
  459. "smlal v13.4s, v17.4h, v5.h[5]\n"
  460. "smlal v14.4s, v17.4h, v6.h[5]\n"
  461. "smlal v15.4s, v17.4h, v7.h[5]\n"
  462. "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
  463. "smlal v8.4s, v16.4h, v0.h[6]\n"
  464. "smlal v9.4s, v16.4h, v1.h[6]\n"
  465. "smlal v10.4s, v16.4h, v2.h[6]\n"
  466. "smlal v11.4s, v16.4h, v3.h[6]\n"
  467. "sshll v17.8h, v17.8b, #0\n"
  468. "smlal v12.4s, v16.4h, v4.h[6]\n"
  469. "smlal v13.4s, v16.4h, v5.h[6]\n"
  470. "smlal v14.4s, v16.4h, v6.h[6]\n"
  471. "smlal v15.4s, v16.4h, v7.h[6]\n"
  472. "smlal v8.4s, v17.4h, v0.h[7]\n"
  473. "smlal v9.4s, v17.4h, v1.h[7]\n"
  474. "smlal v10.4s, v17.4h, v2.h[7]\n"
  475. "smlal v11.4s, v17.4h, v3.h[7]\n"
  476. "smlal v12.4s, v17.4h, v4.h[7]\n"
  477. "smlal v13.4s, v17.4h, v5.h[7]\n"
  478. "smlal v14.4s, v17.4h, v6.h[7]\n"
  479. "smlal v15.4s, v17.4h, v7.h[7]\n"
  480. "subs %w[K], %w[K], #1\n"
  481. "cbnz %w[K], 2b\n"
  482. "3:\n" STORE_C
  483. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
  484. [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
  485. [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
  486. [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
  487. [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
  488. [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
  489. [n_remain] "+r"(n_remain)
  490. :
  491. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  492. "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
  493. #undef LOAD_LINE
  494. #undef LOAD_C
  495. #undef STORE_LINE
  496. #undef STORE_C
  497. }
  498. /**
  499. * Overview of register layout:
  500. *
  501. * A 8x8x8 cell of Rhs is stored in 8bit in q12-q13
  502. * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3
  503. * A 4x8 block of accumulators is stored in 32bit in q4-q11
  504. *
  505. * +--------+--------+
  506. * |v12[0-8]|v13[0-8]|
  507. * Rhs +--------+--------+
  508. * Lhs | | |
  509. *
  510. * +--------+ - - - - +-----------------+
  511. * |v0[0-8]| | v4[0-4]| v5[0-4]|
  512. * |v1[0-8]| | v6[0-4]| v7[0-4]|
  513. * |v2[0-8]| | v8[0-4]| v9[0-4]|
  514. * |v3[0-8]| |v10[0-4]|v11[0-4]|
  515. * +--------+ - - - - +-----------------+
  516. *
  517. * Accumulator
  518. */
  519. static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
  520. int32_t* output, int LDC, bool is_first_k,
  521. size_t m_remain) {
  522. K /= 8;
  523. const int8_t* a_ptr = packA;
  524. const int8_t* b_ptr = packB;
  525. LDC = LDC * sizeof(int32_t);
  526. int32_t* outptr0 = output;
  527. int32_t* outptr1;
  528. int32_t* outptr2;
  529. int32_t* outptr3;
  530. size_t x0 = 0;
  531. // clang-format off
  532. #define LOAD_LINE(v1, v2, m) \
  533. "cbz %[x0], 100f\n" \
  534. "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \
  535. "subs %[x0], %[x0], #1\n"
  536. #define LOAD_C \
  537. "mov %[x0], %x[m_remain]\n" \
  538. LOAD_LINE("q4", "q5", "0") \
  539. LOAD_LINE("q6", "q7", "1") \
  540. LOAD_LINE("q8", "q9", "2") \
  541. LOAD_LINE("q10", "q11", "3") \
  542. "100:\n"
  543. #define STORE_LINE(v1, v2, m) \
  544. "cbz %[x0], 101f\n" \
  545. "stp " v1 "," v2", [%[outptr" m "]]\n" \
  546. "subs %[x0], %[x0], #1\n"
  547. #define STORE_C \
  548. "mov %[x0], %x[m_remain]\n" \
  549. STORE_LINE("q4", "q5", "0") \
  550. STORE_LINE("q6", "q7", "1") \
  551. STORE_LINE("q8", "q9", "2") \
  552. STORE_LINE("q10", "q11", "3") \
  553. "101:\n"
  554. // clang-format on
  555. asm volatile(
  556. // load accumulator C
  557. "add %[outptr1], %[outptr0], %x[LDC]\n"
  558. "add %[outptr2], %[outptr1], %x[LDC]\n"
  559. "add %[outptr3], %[outptr2], %x[LDC]\n"
  560. "cmp %w[is_first_k], #1\n"
  561. "beq 1f\n" LOAD_C
  562. "b 2f\n"
  563. "1:\n"
  564. "eor v4.16b, v4.16b, v4.16b\n"
  565. "eor v5.16b, v5.16b, v5.16b\n"
  566. "eor v6.16b, v6.16b, v6.16b\n"
  567. "eor v7.16b, v7.16b, v7.16b\n"
  568. "eor v8.16b, v8.16b, v8.16b\n"
  569. "eor v9.16b, v9.16b, v9.16b\n"
  570. "eor v10.16b, v10.16b, v10.16b\n"
  571. "eor v11.16b, v11.16b, v11.16b\n"
  572. "2: \n"
  573. "ld1 {v12.8b}, [%[b_ptr]], 8\n"
  574. "ld1 {v0.8b}, [%[a_ptr]], 8\n"
  575. "ld1 {v1.8b}, [%[a_ptr]], 8\n"
  576. "ld1 {v2.8b}, [%[a_ptr]], 8\n"
  577. "ld1 {v3.8b}, [%[a_ptr]], 8\n"
  578. "sshll v12.8h, v12.8b, #0\n"
  579. "sshll v0.8h, v0.8b, #0\n"
  580. "sshll v1.8h, v1.8b, #0\n"
  581. "sshll v2.8h, v2.8b, #0\n"
  582. "sshll v3.8h, v3.8b, #0\n"
  583. "ld1 {v13.8b}, [%[b_ptr]], 8\n"
  584. "smlal v4.4s, v12.4h, v0.h[0]\n"
  585. "smlal v6.4s, v12.4h, v1.h[0]\n"
  586. "smlal v8.4s, v12.4h, v2.h[0]\n"
  587. "smlal v10.4s, v12.4h, v3.h[0]\n"
  588. "sshll v13.8h, v13.8b, #0\n"
  589. "smlal2 v5.4s, v12.8h, v0.h[0]\n"
  590. "smlal2 v7.4s, v12.8h, v1.h[0]\n"
  591. "smlal2 v9.4s, v12.8h, v2.h[0]\n"
  592. "smlal2 v11.4s, v12.8h, v3.h[0]\n"
  593. "ld1 {v12.8b}, [%[b_ptr]], 8\n"
  594. "smlal v4.4s, v13.4h, v0.h[1]\n"
  595. "smlal v6.4s, v13.4h, v1.h[1]\n"
  596. "smlal v8.4s, v13.4h, v2.h[1]\n"
  597. "smlal v10.4s, v13.4h, v3.h[1]\n"
  598. "sshll v12.8h, v12.8b, #0\n"
  599. "smlal2 v5.4s, v13.8h, v0.h[1]\n"
  600. "smlal2 v7.4s, v13.8h, v1.h[1]\n"
  601. "smlal2 v9.4s, v13.8h, v2.h[1]\n"
  602. "smlal2 v11.4s, v13.8h, v3.h[1]\n"
  603. "ld1 {v13.8b}, [%[b_ptr]], 8\n"
  604. "smlal v4.4s, v12.4h, v0.h[2]\n"
  605. "smlal v6.4s, v12.4h, v1.h[2]\n"
  606. "smlal v8.4s, v12.4h, v2.h[2]\n"
  607. "smlal v10.4s, v12.4h, v3.h[2]\n"
  608. "sshll v13.8h, v13.8b, #0\n"
  609. "smlal2 v5.4s, v12.8h, v0.h[2]\n"
  610. "smlal2 v7.4s, v12.8h, v1.h[2]\n"
  611. "smlal2 v9.4s, v12.8h, v2.h[2]\n"
  612. "smlal2 v11.4s, v12.8h, v3.h[2]\n"
  613. "ld1 {v12.8b}, [%[b_ptr]], 8\n"
  614. "smlal v4.4s, v13.4h, v0.h[3]\n"
  615. "smlal v6.4s, v13.4h, v1.h[3]\n"
  616. "smlal v8.4s, v13.4h, v2.h[3]\n"
  617. "smlal v10.4s, v13.4h, v3.h[3]\n"
  618. "sshll v12.8h, v12.8b, #0\n"
  619. "smlal2 v5.4s, v13.8h, v0.h[3]\n"
  620. "smlal2 v7.4s, v13.8h, v1.h[3]\n"
  621. "smlal2 v9.4s, v13.8h, v2.h[3]\n"
  622. "smlal2 v11.4s, v13.8h, v3.h[3]\n"
  623. "ld1 {v13.8b}, [%[b_ptr]], 8\n"
  624. "smlal v4.4s, v12.4h, v0.h[4]\n"
  625. "smlal v6.4s, v12.4h, v1.h[4]\n"
  626. "smlal v8.4s, v12.4h, v2.h[4]\n"
  627. "smlal v10.4s, v12.4h, v3.h[4]\n"
  628. "sshll v13.8h, v13.8b, #0\n"
  629. "smlal2 v5.4s, v12.8h, v0.h[4]\n"
  630. "smlal2 v7.4s, v12.8h, v1.h[4]\n"
  631. "smlal2 v9.4s, v12.8h, v2.h[4]\n"
  632. "smlal2 v11.4s, v12.8h, v3.h[4]\n"
  633. "ld1 {v12.8b}, [%[b_ptr]], 8\n"
  634. "smlal v4.4s, v13.4h, v0.h[5]\n"
  635. "smlal v6.4s, v13.4h, v1.h[5]\n"
  636. "smlal v8.4s, v13.4h, v2.h[5]\n"
  637. "smlal v10.4s, v13.4h, v3.h[5]\n"
  638. "sshll v12.8h, v12.8b, #0\n"
  639. "smlal2 v5.4s, v13.8h, v0.h[5]\n"
  640. "smlal2 v7.4s, v13.8h, v1.h[5]\n"
  641. "smlal2 v9.4s, v13.8h, v2.h[5]\n"
  642. "smlal2 v11.4s, v13.8h, v3.h[5]\n"
  643. "ld1 {v13.8b}, [%[b_ptr]], 8\n"
  644. "smlal v4.4s, v12.4h, v0.h[6]\n"
  645. "smlal v6.4s, v12.4h, v1.h[6]\n"
  646. "smlal v8.4s, v12.4h, v2.h[6]\n"
  647. "smlal v10.4s, v12.4h, v3.h[6]\n"
  648. "sshll v13.8h, v13.8b, #0\n"
  649. "smlal2 v5.4s, v12.8h, v0.h[6]\n"
  650. "smlal2 v7.4s, v12.8h, v1.h[6]\n"
  651. "smlal2 v9.4s, v12.8h, v2.h[6]\n"
  652. "smlal2 v11.4s, v12.8h, v3.h[6]\n"
  653. "smlal v4.4s, v13.4h, v0.h[7]\n"
  654. "smlal v6.4s, v13.4h, v1.h[7]\n"
  655. "smlal v8.4s, v13.4h, v2.h[7]\n"
  656. "smlal v10.4s, v13.4h, v3.h[7]\n"
  657. "smlal2 v5.4s, v13.8h, v0.h[7]\n"
  658. "smlal2 v7.4s, v13.8h, v1.h[7]\n"
  659. "smlal2 v9.4s, v13.8h, v2.h[7]\n"
  660. "smlal2 v11.4s, v13.8h, v3.h[7]\n"
  661. "subs %w[K], %w[K], #1\n"
  662. "cbnz %w[K], 2b\n"
  663. "3:\n" STORE_C
  664. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
  665. [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
  666. [outptr0] "+r"(outptr0),
  667. [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
  668. [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain)
  669. :
  670. : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
  671. "v11", "v12", "v13", "cc", "memory");
  672. #undef LOAD_LINE
  673. #undef LOAD_C
  674. #undef STORE_LINE
  675. #undef STORE_C
  676. }
  677. /**
  678. * Overview of register layout:
  679. *
  680. * A 8x4x8 cell of Rhs is stored in 8bit in q8-q9
  681. * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3
  682. * A 4x4 block of accumulators is stored in 32bit in q4-q7
  683. *
  684. * +--------+
  685. * | v8[0-4]|
  686. * Rhs +--------+
  687. * | v9[0-4]|
  688. * Lhs +--------+
  689. *
  690. * +--------+ - - - - +--------+
  691. * |v0[0-8]| | v4[0-4]|
  692. * |v1[0-8]| | v5[0-4]|
  693. * |v2[0-8]| | v6[0-4]|
  694. * |v3[0-8]| | v7[0-4]|
  695. * +--------+ - - - - +--------+
  696. *
  697. * Accumulator
  698. */
  static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k, size_t m_remain,
                       size_t n_remain) {
      //! Accumulates an (up to) 4x4 int32 tile of C from packed int8 panels.
      //!
      //! \param packA      4 rows of A; each k-step reads 4 x 8 int8 (32 bytes).
      //! \param packB      B values; each k-step reads 8 x 4 int8 (32 bytes).
      //! \param K          depth. It is divided by 8 to get the loop trip count.
      //!                   The loop body executes at least once (the count test
      //!                   is at the bottom), so K must be >= 8.
      //!                   NOTE(review): presumably callers pass K already rounded
      //!                   up to a multiple of 8 — confirm.
      //! \param output     top-left element of the C tile.
      //! \param LDC        row stride of C in int32 elements.
      //! \param is_first_k true: start the accumulators at zero without reading C.
      //!                   false: load the existing C values first.
      //! \param m_remain   number of C rows to load/store; at most 4 are touched.
      //! \param n_remain   number of C columns to load/store. Values >= 4 use one
      //!                   128-bit access; 0..3 use per-lane 32-bit accesses.
      K /= 8;
      const int8_t* a_ptr = packA;
      const int8_t* b_ptr = packB;
      //! Row stride in bytes. It is held in a 64-bit variable because the asm
      //! reads it through an x register (%x[LDC]). With an `int` operand, the
      //! upper 32 bits of that register would be unspecified.
      size_t ldc_bytes = static_cast<size_t>(LDC) * sizeof(int32_t);
      int32_t* outptr0 = output;
      int32_t* outptr1;
      int32_t* outptr2;
      int32_t* outptr3;
      size_t x0 = 0;
      size_t x1 = 0;
  // clang-format off
  //! Loads one C row into v<reg_index>. x1 counts the rows still to load;
  //! once it reaches zero, the remaining rows are skipped.
  #define LOAD_LINE(reg_index, n)                \
      "cbz %[x1], 102f\n"                        \
      "mov %[x0], %[outptr" n "]\n"              \
      "cmp %w[n_remain], #4\n"                   \
      "blt 100" n "f\n"                          \
      "ldr q" reg_index ", [%[x0]]\n"            \
      "b 101" n "f\n"                            \
      "100" n ":\n"                              \
      "cmp %w[n_remain], #0\n"                   \
      "beq 101" n "f\n"                          \
      "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
      "cmp %w[n_remain], #1\n"                   \
      "beq 101" n "f\n"                          \
      "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
      "cmp %w[n_remain], #2\n"                   \
      "beq 101" n "f\n"                          \
      "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
      "101" n ":\n"                              \
      "subs %[x1], %[x1], #1\n"
  #define LOAD_C                   \
      "mov %[x1], %x[m_remain]\n"  \
      LOAD_LINE("4", "0")          \
      LOAD_LINE("5", "1")          \
      LOAD_LINE("6", "2")          \
      LOAD_LINE("7", "3")          \
      "102:\n"
  //! Mirror of LOAD_LINE: stores v<reg_index> back to one C row.
  #define STORE_LINE(reg_index, n)               \
      "cbz %[x1], 105f\n"                        \
      "mov %[x0], %[outptr" n "]\n"              \
      "cmp %w[n_remain], #4\n"                   \
      "blt 103" n "f\n"                          \
      "str q" reg_index ", [%[x0]]\n"            \
      "b 104" n "f\n"                            \
      "103" n ":\n"                              \
      "cmp %w[n_remain], #0\n"                   \
      "beq 104" n "f\n"                          \
      "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
      "cmp %w[n_remain], #1\n"                   \
      "beq 104" n "f\n"                          \
      "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
      "cmp %w[n_remain], #2\n"                   \
      "beq 104" n "f\n"                          \
      "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
      "104" n ":\n"                              \
      "subs %[x1], %[x1], #1\n"
  #define STORE_C                  \
      "mov %[x1], %x[m_remain]\n"  \
      STORE_LINE("4", "0")         \
      STORE_LINE("5", "1")         \
      STORE_LINE("6", "2")         \
      STORE_LINE("7", "3")         \
      "105:\n"
  // clang-format on
      asm volatile(
              // load accumulator C
              "add %[outptr1], %[outptr0], %x[LDC]\n"
              "add %[outptr2], %[outptr1], %x[LDC]\n"
              "add %[outptr3], %[outptr2], %x[LDC]\n"
              "cmp %w[is_first_k], #1\n"
              "beq 1f\n" LOAD_C
              "b 2f\n"
              "1:\n"
              "eor v4.16b, v4.16b, v4.16b\n"
              "eor v5.16b, v5.16b, v5.16b\n"
              "eor v6.16b, v6.16b, v6.16b\n"
              "eor v7.16b, v7.16b, v7.16b\n"
              // Main loop: one 8-deep k-step per iteration. A rows are
              // sign-extended into v0-v3, and B alternates between v8 and v9.
              "2: \n"
              "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
              "ld1 {v0.8b}, [%[a_ptr]], 8\n"
              "ld1 {v1.8b}, [%[a_ptr]], 8\n"
              "ld1 {v2.8b}, [%[a_ptr]], 8\n"
              "ld1 {v3.8b}, [%[a_ptr]], 8\n"
              "sshll v8.8h, v8.8b, #0\n"
              "sshll v0.8h, v0.8b, #0\n"
              "sshll v1.8h, v1.8b, #0\n"
              "sshll v2.8h, v2.8b, #0\n"
              "sshll v3.8h, v3.8b, #0\n"
              "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v8.4h, v0.h[0]\n"
              "smlal v5.4s, v8.4h, v1.h[0]\n"
              "sshll v9.8h, v9.8b, #0\n"
              "smlal v6.4s, v8.4h, v2.h[0]\n"
              "smlal v7.4s, v8.4h, v3.h[0]\n"
              "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v9.4h, v0.h[1]\n"
              "smlal v5.4s, v9.4h, v1.h[1]\n"
              "sshll v8.8h, v8.8b, #0\n"
              "smlal v6.4s, v9.4h, v2.h[1]\n"
              "smlal v7.4s, v9.4h, v3.h[1]\n"
              "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v8.4h, v0.h[2]\n"
              "smlal v5.4s, v8.4h, v1.h[2]\n"
              "sshll v9.8h, v9.8b, #0\n"
              "smlal v6.4s, v8.4h, v2.h[2]\n"
              "smlal v7.4s, v8.4h, v3.h[2]\n"
              "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v9.4h, v0.h[3]\n"
              "smlal v5.4s, v9.4h, v1.h[3]\n"
              "sshll v8.8h, v8.8b, #0\n"
              "smlal v6.4s, v9.4h, v2.h[3]\n"
              "smlal v7.4s, v9.4h, v3.h[3]\n"
              "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v8.4h, v0.h[4]\n"
              "smlal v5.4s, v8.4h, v1.h[4]\n"
              "sshll v9.8h, v9.8b, #0\n"
              "smlal v6.4s, v8.4h, v2.h[4]\n"
              "smlal v7.4s, v8.4h, v3.h[4]\n"
              "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v9.4h, v0.h[5]\n"
              "smlal v5.4s, v9.4h, v1.h[5]\n"
              "sshll v8.8h, v8.8b, #0\n"
              "smlal v6.4s, v9.4h, v2.h[5]\n"
              "smlal v7.4s, v9.4h, v3.h[5]\n"
              "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
              "smlal v4.4s, v8.4h, v0.h[6]\n"
              "smlal v5.4s, v8.4h, v1.h[6]\n"
              "sshll v9.8h, v9.8b, #0\n"
              "smlal v6.4s, v8.4h, v2.h[6]\n"
              "smlal v7.4s, v8.4h, v3.h[6]\n"
              "smlal v4.4s, v9.4h, v0.h[7]\n"
              "smlal v5.4s, v9.4h, v1.h[7]\n"
              "smlal v6.4s, v9.4h, v2.h[7]\n"
              "smlal v7.4s, v9.4h, v3.h[7]\n"
              "subs %w[K], %w[K], #1\n"
              "cbnz %w[K], 2b\n"
              "3:\n" STORE_C
              : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
                [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(ldc_bytes),
                [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
                [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
                [x1] "+r"(x1), [m_remain] "+r"(m_remain),
                [n_remain] "+r"(n_remain)
              :
              // Every vector register written above must be listed. v9 is
              // used as the second B buffer (previously v11 was listed by
              // mistake, leaving v9 silently clobbered).
              : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc",
                "memory");
  #undef LOAD_LINE
  #undef LOAD_C
  #undef STORE_LINE
  #undef STORE_C
  }
  853. static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
  854. int y0, int ymax, int k0, int kmax) {
  855. int8_t zerobuff[16];
  856. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  857. int y = y0;
  858. for (; y + 7 < ymax; y += 8) {
  859. const int8_t* inptr0 = inptr + y * ldin + k0;
  860. const int8_t* inptr1 = inptr0 + ldin;
  861. const int8_t* inptr2 = inptr1 + ldin;
  862. const int8_t* inptr3 = inptr2 + ldin;
  863. const int8_t* inptr4 = inptr3 + ldin;
  864. const int8_t* inptr5 = inptr4 + ldin;
  865. const int8_t* inptr6 = inptr5 + ldin;
  866. const int8_t* inptr7 = inptr6 + ldin;
  867. prefetch_2x(inptr0);
  868. prefetch_2x(inptr1);
  869. prefetch_2x(inptr2);
  870. prefetch_2x(inptr3);
  871. prefetch_2x(inptr4);
  872. prefetch_2x(inptr5);
  873. prefetch_2x(inptr6);
  874. prefetch_2x(inptr7);
  875. int K = kmax - k0;
  876. for (; K > 15; K -= 16) {
  877. interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
  878. inptr6, inptr7, outptr);
  879. }
  880. if (K > 0) {
  881. interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  882. inptr7, outptr, 8, K);
  883. }
  884. }
  885. for (; y < ymax; y += 4) {
  886. const int8_t* inptr0 = inptr + y * ldin + k0;
  887. const int8_t* inptr1 = inptr0 + ldin;
  888. const int8_t* inptr2 = inptr1 + ldin;
  889. const int8_t* inptr3 = inptr2 + ldin;
  890. prefetch_2x(inptr0);
  891. prefetch_2x(inptr1);
  892. prefetch_2x(inptr2);
  893. prefetch_2x(inptr3);
  894. int K = kmax - k0;
  895. for (; K > 15; K -= 16) {
  896. if (y + 3 >= ymax) {
  897. switch (y + 3 - ymax) {
  898. case 2:
  899. inptr1 = zerobuff; MEGDNN_FALLTHRU
  900. case 1:
  901. inptr2 = zerobuff; MEGDNN_FALLTHRU
  902. case 0:
  903. inptr3 = zerobuff;
  904. break;
  905. default:
  906. megdnn_assert(0);
  907. }
  908. }
  909. interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr);
  910. }
  911. if (K > 0) {
  912. if (y + 3 >= ymax) {
  913. switch (y + 3 - ymax) {
  914. case 2:
  915. inptr1 = zerobuff; MEGDNN_FALLTHRU
  916. case 1:
  917. inptr2 = zerobuff; MEGDNN_FALLTHRU
  918. case 0:
  919. inptr3 = zerobuff;
  920. break;
  921. default:
  922. megdnn_assert(0);
  923. }
  924. }
  925. interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K);
  926. }
  927. }
  928. }
  929. static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
  930. int ldin, int x0, int xmax, int k0,
  931. int kmax) {
  932. int8_t zerobuff[16];
  933. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  934. const int ksize = kmax - k0;
  935. const int ksize4 = round_up(ksize, 8) * 4;
  936. const int ksize8 = ksize4 * 2;
  937. int8_t* outptr = out;
  938. int8_t* outptr_base = out;
  939. //! 4x4 block output start pos
  940. int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8;
  941. int k = k0;
  942. for (; k < kmax; k += 8) {
  943. const int8_t* inptr0 = in + k * ldin + x0;
  944. const int8_t* inptr1 = inptr0 + ldin;
  945. const int8_t* inptr2 = inptr1 + ldin;
  946. const int8_t* inptr3 = inptr2 + ldin;
  947. const int8_t* inptr4 = inptr3 + ldin;
  948. const int8_t* inptr5 = inptr4 + ldin;
  949. const int8_t* inptr6 = inptr5 + ldin;
  950. const int8_t* inptr7 = inptr6 + ldin;
  951. prefetch_2x(inptr0);
  952. prefetch_2x(inptr1);
  953. prefetch_2x(inptr2);
  954. prefetch_2x(inptr3);
  955. prefetch_2x(inptr4);
  956. prefetch_2x(inptr5);
  957. prefetch_2x(inptr6);
  958. prefetch_2x(inptr7);
  959. int x = x0;
  960. outptr = outptr_base;
  961. for (; x + 7 < xmax; x += 8) {
  962. if (k + 7 >= kmax) {
  963. switch (k + 7 - kmax) {
  964. case 6:
  965. inptr1 = zerobuff; MEGDNN_FALLTHRU
  966. case 5:
  967. inptr2 = zerobuff; MEGDNN_FALLTHRU
  968. case 4:
  969. inptr3 = zerobuff; MEGDNN_FALLTHRU
  970. case 3:
  971. inptr4 = zerobuff; MEGDNN_FALLTHRU
  972. case 2:
  973. inptr5 = zerobuff; MEGDNN_FALLTHRU
  974. case 1:
  975. inptr6 = zerobuff; MEGDNN_FALLTHRU
  976. case 0:
  977. inptr7 = zerobuff;
  978. break;
  979. default:
  980. megdnn_assert(0);
  981. }
  982. }
  983. transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
  984. inptr6, inptr7, outptr);
  985. outptr += ksize8;
  986. }
  987. outptr = outptr_base4;
  988. for (; x + 3 < xmax; x += 4) {
  989. if (k + 7 >= kmax) {
  990. switch (k + 7 - kmax) {
  991. case 6:
  992. inptr1 = zerobuff; MEGDNN_FALLTHRU
  993. case 5:
  994. inptr2 = zerobuff; MEGDNN_FALLTHRU
  995. case 4:
  996. inptr3 = zerobuff; MEGDNN_FALLTHRU
  997. case 3:
  998. inptr4 = zerobuff; MEGDNN_FALLTHRU
  999. case 2:
  1000. inptr5 = zerobuff; MEGDNN_FALLTHRU
  1001. case 1:
  1002. inptr6 = zerobuff; MEGDNN_FALLTHRU
  1003. case 0:
  1004. inptr7 = zerobuff;
  1005. break;
  1006. default:
  1007. megdnn_assert(0);
  1008. }
  1009. }
  1010. transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  1011. inptr7, outptr, 4, 4);
  1012. outptr += ksize4;
  1013. }
  1014. if (x < xmax) {
  1015. if (k + 7 >= kmax) {
  1016. switch (k + 7 - kmax) {
  1017. case 6:
  1018. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1019. case 5:
  1020. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1021. case 4:
  1022. inptr3 = zerobuff; MEGDNN_FALLTHRU
  1023. case 3:
  1024. inptr4 = zerobuff; MEGDNN_FALLTHRU
  1025. case 2:
  1026. inptr5 = zerobuff; MEGDNN_FALLTHRU
  1027. case 1:
  1028. inptr6 = zerobuff; MEGDNN_FALLTHRU
  1029. case 0:
  1030. inptr7 = zerobuff;
  1031. break;
  1032. default:
  1033. megdnn_assert(0);
  1034. }
  1035. }
  1036. transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  1037. inptr7, outptr, 4, xmax - x);
  1038. }
  1039. outptr_base += 8 * 8;
  1040. outptr_base4 += 4 * 8;
  1041. }
  1042. }
  1043. static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
  1044. int x0, int xmax, int k0, int kmax) {
  1045. int8_t zerobuff[16];
  1046. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1047. const int ksize = kmax - k0;
  1048. const int ksize4 = round_up(ksize, 8) * 4;
  1049. const int ksize8 = ksize4 * 2;
  1050. int8_t* outptr = out;
  1051. int8_t* outptr_base = out;
  1052. int8_t* outptr_interleave = nullptr;
  1053. //! 4x4 block output start pos
  1054. int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8;
  1055. int k = k0;
  1056. for (; k < kmax; k += 8) {
  1057. const int8_t* inptr0 = in + k * ldin + x0;
  1058. const int8_t* inptr1 = inptr0 + ldin;
  1059. const int8_t* inptr2 = inptr1 + ldin;
  1060. const int8_t* inptr3 = inptr2 + ldin;
  1061. const int8_t* inptr4 = inptr3 + ldin;
  1062. const int8_t* inptr5 = inptr4 + ldin;
  1063. const int8_t* inptr6 = inptr5 + ldin;
  1064. const int8_t* inptr7 = inptr6 + ldin;
  1065. prefetch_2x(inptr0);
  1066. prefetch_2x(inptr1);
  1067. prefetch_2x(inptr2);
  1068. prefetch_2x(inptr3);
  1069. prefetch_2x(inptr4);
  1070. prefetch_2x(inptr5);
  1071. prefetch_2x(inptr6);
  1072. prefetch_2x(inptr7);
  1073. int x = x0;
  1074. outptr = outptr_base;
  1075. for (; x + 7 < xmax; x += 8) {
  1076. if (k + 7 >= kmax) {
  1077. switch (k + 7 - kmax) {
  1078. case 6:
  1079. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1080. case 5:
  1081. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1082. case 4:
  1083. inptr3 = zerobuff; MEGDNN_FALLTHRU
  1084. case 3:
  1085. inptr4 = zerobuff; MEGDNN_FALLTHRU
  1086. case 2:
  1087. inptr5 = zerobuff; MEGDNN_FALLTHRU
  1088. case 1:
  1089. inptr6 = zerobuff; MEGDNN_FALLTHRU
  1090. case 0:
  1091. inptr7 = zerobuff;
  1092. break;
  1093. default:
  1094. megdnn_assert(0);
  1095. }
  1096. }
  1097. outptr_interleave = outptr;
  1098. interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
  1099. inptr6, inptr7, outptr_interleave);
  1100. outptr += ksize8;
  1101. }
  1102. outptr = outptr_base4;
  1103. for (; x + 3 < xmax; x += 4) {
  1104. if (k + 7 >= kmax) {
  1105. switch (k + 7 - kmax) {
  1106. case 6:
  1107. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1108. case 5:
  1109. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1110. case 4:
  1111. inptr3 = zerobuff; MEGDNN_FALLTHRU
  1112. case 3:
  1113. inptr4 = zerobuff; MEGDNN_FALLTHRU
  1114. case 2:
  1115. inptr5 = zerobuff; MEGDNN_FALLTHRU
  1116. case 1:
  1117. inptr6 = zerobuff; MEGDNN_FALLTHRU
  1118. case 0:
  1119. inptr7 = zerobuff;
  1120. break;
  1121. default:
  1122. megdnn_assert(0);
  1123. }
  1124. }
  1125. outptr_interleave = outptr;
  1126. interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  1127. inptr7, outptr_interleave, 4, 4);
  1128. outptr += ksize4;
  1129. }
  1130. if (x < xmax) {
  1131. if (k + 7 >= kmax) {
  1132. switch (k + 7 - kmax) {
  1133. case 6:
  1134. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1135. case 5:
  1136. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1137. case 4:
  1138. inptr3 = zerobuff; MEGDNN_FALLTHRU
  1139. case 3:
  1140. inptr4 = zerobuff; MEGDNN_FALLTHRU
  1141. case 2:
  1142. inptr5 = zerobuff; MEGDNN_FALLTHRU
  1143. case 1:
  1144. inptr6 = zerobuff; MEGDNN_FALLTHRU
  1145. case 0:
  1146. inptr7 = zerobuff;
  1147. break;
  1148. default:
  1149. megdnn_assert(0);
  1150. }
  1151. }
  1152. outptr_interleave = outptr;
  1153. interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  1154. inptr7, outptr_interleave, 4, xmax - x);
  1155. }
  1156. outptr_base += 8 * 8;
  1157. outptr_base4 += 4 * 8;
  1158. }
  1159. }
  1160. static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
  1161. int ldin, int y0, int ymax, int k0,
  1162. int kmax) {
  1163. int8_t zerobuff[16];
  1164. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1165. constexpr int interleave4 = 32;
  1166. constexpr int interleave8 = 64;
  1167. int y = y0;
  1168. for (; y + 7 < ymax; y += 8) {
  1169. const int8_t* inptr0 = inptr + y * ldin + k0;
  1170. const int8_t* inptr1 = inptr0 + ldin;
  1171. const int8_t* inptr2 = inptr1 + ldin;
  1172. const int8_t* inptr3 = inptr2 + ldin;
  1173. const int8_t* inptr4 = inptr3 + ldin;
  1174. const int8_t* inptr5 = inptr4 + ldin;
  1175. const int8_t* inptr6 = inptr5 + ldin;
  1176. const int8_t* inptr7 = inptr6 + ldin;
  1177. prefetch_2x(inptr0);
  1178. prefetch_2x(inptr1);
  1179. prefetch_2x(inptr2);
  1180. prefetch_2x(inptr3);
  1181. prefetch_2x(inptr4);
  1182. prefetch_2x(inptr5);
  1183. prefetch_2x(inptr6);
  1184. prefetch_2x(inptr7);
  1185. int K = kmax - k0;
  1186. for (; K > 7; K -= 8) {
  1187. transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
  1188. inptr6, inptr7, outptr);
  1189. outptr += interleave8;
  1190. }
  1191. if (K > 0) {
  1192. transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
  1193. inptr7, outptr, 8, K);
  1194. outptr += interleave8;
  1195. }
  1196. }
  1197. for (; y < ymax; y += 4) {
  1198. const int8_t* inptr0 = inptr + y * ldin + k0;
  1199. const int8_t* inptr1 = inptr0 + ldin;
  1200. const int8_t* inptr2 = inptr1 + ldin;
  1201. const int8_t* inptr3 = inptr2 + ldin;
  1202. prefetch_2x(inptr0);
  1203. prefetch_2x(inptr1);
  1204. prefetch_2x(inptr2);
  1205. prefetch_2x(inptr3);
  1206. int K = kmax - k0;
  1207. for (; K > 7; K -= 8) {
  1208. if (y + 3 >= ymax) {
  1209. switch (y + 3 - ymax) {
  1210. case 2:
  1211. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1212. case 1:
  1213. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1214. case 0:
  1215. inptr3 = zerobuff;
  1216. break;
  1217. default:
  1218. megdnn_assert(0);
  1219. }
  1220. }
  1221. transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
  1222. outptr += interleave4;
  1223. }
  1224. if (K > 0) {
  1225. if (y + 3 >= ymax) {
  1226. switch (y + 3 - ymax) {
  1227. case 2:
  1228. inptr1 = zerobuff; MEGDNN_FALLTHRU
  1229. case 1:
  1230. inptr2 = zerobuff; MEGDNN_FALLTHRU
  1231. case 0:
  1232. inptr3 = zerobuff;
  1233. break;
  1234. default:
  1235. megdnn_assert(0);
  1236. }
  1237. }
  1238. transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K);
  1239. outptr += interleave4;
  1240. }
  1241. }
  1242. }
  1243. } // namespace matmul_8x8x8
  1244. } // namespace aarch64
  1245. } // namespace megdnn
  1246. // vim: syntax=cpp.doxygen
  1247. #endif

MegEngine 安装包中已集成使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版与 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。