You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_8x12x4.h 58 kB


  1. /**
  2. * \file dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #if MGB_ENABLE_DOT
  12. #include "src/aarch64/matrix_mul/asm/common.h"
  13. #include "src/arm_common/simd_macro/marm_neon.h"
  14. namespace megdnn {
  15. namespace aarch64 {
  16. namespace matmul_8x12x4 {
  17. // Overview of register layout:
  18. //
  19. // A 12x4 cell of Rhs is stored in 8bit in q2-q4.
  20. // A 8x4x2 cell of Lhs is stored in 8bit in q0-q1,q5-q6
  21. // A 8x12 block of accumulators is stored in 32bit in q8--q31.
  22. //
  23. // +--------+--------+--------+
  24. // |v2[0-16]|v3[0-16]|v4[0-16]|
  25. // Rhs +--------+--------+--------+
  26. //
  27. // | | | |
  28. //
  29. // Lhs | | | |
  30. //
  31. // +-------+-------+ - - - - +--------+--------+--------+
  32. // |v0[0-4]|v5[0-4]| | v8[0-4]|v16[0-4]|v24[0-4]|
  33. // |v0[0-4]|v5[0-4]| | v9[0-4]|v17[0-4]|v25[0-4]|
  34. // |v0[0-4]|v5[0-4]| |v10[0-4]|v18[0-4]|v26[0-4]|
  35. // |v0[0-4]|v5[0-4]| |v11[0-4]|v19[0-4]|v27[0-4]|
  36. // |v1[0-4]|v6[0-4]| |v12[0-4]|v20[0-4]|v28[0-4]|
  37. // |v1[0-4]|v6[0-4]| |v13[0-4]|v21[0-4]|v29[0-4]|
  38. // |v1[0-4]|v6[0-4]| |v14[0-4]|v22[0-4]|v30[0-4]|
  39. // |v1[0-4]|v6[0-4]| |v15[0-4]|v23[0-4]|v31[0-4]|
  40. // +-------+-------+ - - - - +--------+--------+--------+
  41. //
  42. // Accumulator
  43. /**
  44. * \note The performance of reordering instructions and of using prefetch is
  45. * almost the same; I tested on Kirin 980 with both the small and the big cores,
  46. * so both implementations are kept here.
  47. */
  48. #if 1
  49. MEGDNN_ATTRIBUTE_TARGET("dotprod")
  50. static void kern_8x12(
  51. const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
  52. bool is_first_k) {
  53. K /= 4;
  54. const int8_t* a_ptr = packA;
  55. const int8_t* b_ptr = packB;
  56. // Fix up for odd lengths - set a flag if K is odd, but make
  57. // sure we round up the iteration count.
  58. int oddk = (K & 1);
  59. int k = ((K + 1) / 2) - 1;
  60. int32x4_t a0;
  61. int32x4_t a1;
  62. int32x4_t b0;
  63. int32x4_t b1;
  64. int32x4_t b2;
  65. int32x4_t a0a;
  66. int32x4_t a1a;
  67. LDC = LDC * sizeof(int32_t);
  68. int32_t* outptr0 = output;
  69. int32_t* outptr1;
  70. int32_t* outptr2;
  71. int32_t* outptr3;
  72. int32_t* outptr4;
  73. int32_t* outptr5;
  74. int32_t* outptr6;
  75. int32_t* outptr7;
  76. asm volatile (
  77. // load accumulator C
  78. "add %[outptr1], %[outptr0], %x[LDC]\n"
  79. "add %[outptr2], %[outptr1], %x[LDC]\n"
  80. "add %[outptr3], %[outptr2], %x[LDC]\n"
  81. "add %[outptr4], %[outptr3], %x[LDC]\n"
  82. "add %[outptr5], %[outptr4], %x[LDC]\n"
  83. "add %[outptr6], %[outptr5], %x[LDC]\n"
  84. "add %[outptr7], %[outptr6], %x[LDC]\n"
  85. "cmp %w[is_first_k], #1\n"
  86. "beq 5f\n"
  87. // we can not use ld1, as it can not encode {v8, v16, v24}
  88. "ldp q8, q16, [%[outptr0]]\n"
  89. "ldr q24, [%[outptr0], #32]\n"
  90. "ldp q9, q17, [%[outptr1]]\n"
  91. "ldr q25, [%[outptr1], #32]\n"
  92. "ldp q10, q18, [%[outptr2]]\n"
  93. "ldr q26, [%[outptr2], #32]\n"
  94. "ldp q11, q19, [%[outptr3]]\n"
  95. "ldr q27, [%[outptr3], #32]\n"
  96. "ldp q12, q20, [%[outptr4]]\n"
  97. "ldr q28, [%[outptr4], #32]\n"
  98. "ldp q13, q21, [%[outptr5]]\n"
  99. "ldr q29, [%[outptr5], #32]\n"
  100. "ldp q14, q22, [%[outptr6]]\n"
  101. "ldr q30, [%[outptr6], #32]\n"
  102. "ldp q15, q23, [%[outptr7]]\n"
  103. "ldr q31, [%[outptr7], #32]\n"
  104. "b 6f\n"
  105. "5:\n"
  106. "eor v8.16b, v8.16b, v8.16b\n"
  107. "eor v9.16b, v9.16b, v9.16b\n"
  108. "eor v10.16b, v10.16b, v10.16b\n"
  109. "eor v11.16b, v11.16b, v11.16b\n"
  110. "eor v12.16b, v12.16b, v12.16b\n"
  111. "eor v13.16b, v13.16b, v13.16b\n"
  112. "eor v14.16b, v14.16b, v14.16b\n"
  113. "eor v15.16b, v15.16b, v15.16b\n"
  114. "eor v16.16b, v16.16b, v16.16b\n"
  115. "eor v17.16b, v17.16b, v17.16b\n"
  116. "eor v18.16b, v18.16b, v18.16b\n"
  117. "eor v19.16b, v19.16b, v19.16b\n"
  118. "eor v20.16b, v20.16b, v20.16b\n"
  119. "eor v21.16b, v21.16b, v21.16b\n"
  120. "eor v22.16b, v22.16b, v22.16b\n"
  121. "eor v23.16b, v23.16b, v23.16b\n"
  122. "eor v24.16b, v24.16b, v24.16b\n"
  123. "eor v25.16b, v25.16b, v25.16b\n"
  124. "eor v26.16b, v26.16b, v26.16b\n"
  125. "eor v27.16b, v27.16b, v27.16b\n"
  126. "eor v28.16b, v28.16b, v28.16b\n"
  127. "eor v29.16b, v29.16b, v29.16b\n"
  128. "eor v30.16b, v30.16b, v30.16b\n"
  129. "eor v31.16b, v31.16b, v31.16b\n"
  130. "6: \n"
  131. // Initialize result registers, load initial operands, prime prefetches.
  132. "ldr %q[a0], [%[a_ptr]]\n"
  133. "ldr %q[b0], [%[b_ptr]]\n"
  134. "ldr %q[a1], [%[a_ptr], #16]\n"
  135. "ldr %q[b1], [%[b_ptr], #16]\n"
  136. ASM_PREFETCH("[%[b_ptr], #64]")
  137. ASM_PREFETCH("[%[a_ptr], #64]")
  138. ASM_PREFETCH("[%[b_ptr], #128]")
  139. ASM_PREFETCH("[%[a_ptr], #128]")
  140. ASM_PREFETCH("[%[b_ptr], #192]")
  141. ASM_PREFETCH("[%[b_ptr], #256]")
  142. ASM_PREFETCH("[%[a_ptr], #192]")
  143. ASM_PREFETCH("[%[b_ptr], #320]")
  144. ASM_PREFETCH("[%[a_ptr], #256]")
  145. ASM_PREFETCH("[%[b_ptr], #384]")
  146. // Skip loop if we are doing zero iterations of it.
  147. "cbz %w[k], 4f\n"
  148. // Loop proper
  149. "1:\n"
  150. "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
  151. "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
  152. "ldr %q[b2], [%[b_ptr], #32]\n"
  153. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  154. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  155. "ldr %q[a0a], [%[a_ptr], #32]\n"
  156. "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
  157. "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
  158. "ldr %q[a1a], [%[a_ptr], #48]\n"
  159. "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
  160. "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
  161. "ldr %q[b0], [%[b_ptr], #48]\n"
  162. "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
  163. "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
  164. ASM_PREFETCH("[%[a_ptr], #320]")
  165. "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
  166. "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
  167. "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
  168. "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
  169. "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
  170. "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
  171. "ldr %q[b1], [%[b_ptr], #64]\n"
  172. "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
  173. "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
  174. ASM_PREFETCH("[%[b_ptr], #448]")
  175. "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
  176. "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
  177. "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
  178. "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
  179. "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
  180. "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
  181. "ldr %q[b2], [%[b_ptr], #80]\n"
  182. "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
  183. "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
  184. "ldr %q[a0], [%[a_ptr], #64]\n"
  185. "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
  186. "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
  187. "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
  188. "ldr %q[a1], [%[a_ptr], #80]\n"
  189. "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
  190. "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
  191. "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
  192. "ldr %q[b0], [%[b_ptr], #96]\n"
  193. "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
  194. "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
  195. ASM_PREFETCH("[%[b_ptr], #512]")
  196. "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
  197. "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
  198. "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
  199. "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
  200. "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
  201. "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
  202. "ldr %q[b1], [%[b_ptr], #112]\n"
  203. "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
  204. "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
  205. "add %[a_ptr], %[a_ptr], #64\n"
  206. "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
  207. "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
  208. "add %[b_ptr], %[b_ptr], #96\n"
  209. "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
  210. "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
  211. "subs %w[k], %w[k], #1\n"
  212. "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
  213. "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
  214. "bne 1b\n"
  215. // Target to use when K is 1 or 2 (i.e. zero iterations of main loop)
  216. "4:\n"
  217. // Branch to alternative tail for odd K
  218. "cbnz %w[oddk], 2f\n"
  219. // Detached final iteration (even K)
  220. "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
  221. "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
  222. "ldr %q[b2], [%[b_ptr], #32]\n"
  223. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  224. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  225. "ldr %q[a0a], [%[a_ptr], #32]\n"
  226. "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
  227. "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
  228. "ldr %q[a1a], [%[a_ptr], #48]\n"
  229. "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
  230. "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
  231. "ldr %q[b0], [%[b_ptr], #48]\n"
  232. "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
  233. "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
  234. "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
  235. "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
  236. "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
  237. "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
  238. "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
  239. "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
  240. "ldr %q[b1], [%[b_ptr], #64]\n"
  241. "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
  242. "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
  243. "add %[a_ptr], %[a_ptr], #64\n"
  244. "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
  245. "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
  246. "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
  247. "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
  248. "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
  249. "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
  250. "ldr %q[b2], [%[b_ptr], #80]\n"
  251. "sdot v8.4s , %[b0].16b, %[a0a].4b[0]\n"
  252. "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
  253. "add %[b_ptr], %[b_ptr], #96\n"
  254. "sdot v9.4s , %[b0].16b, %[a0a].4b[1]\n"
  255. "str q8, [%[outptr0], #0]\n"
  256. "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
  257. "str q16, [%[outptr0], #16]\n"
  258. "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
  259. "str q24, [%[outptr0], #32]\n"
  260. "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
  261. "str q9, [%[outptr1], #0]\n"
  262. "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
  263. "str q17, [%[outptr1], #16]\n"
  264. "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
  265. "str q25, [%[outptr1], #32]\n"
  266. "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
  267. "str q10, [%[outptr2], #0]\n"
  268. "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
  269. "str q18, [%[outptr2], #16]\n"
  270. "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
  271. "str q26, [%[outptr2], #32]\n"
  272. "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
  273. "str q11, [%[outptr3], #0]\n"
  274. "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
  275. "str q19, [%[outptr3], #16]\n"
  276. "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
  277. "str q27, [%[outptr3], #32]\n"
  278. "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
  279. "str q12, [%[outptr4], #0]\n"
  280. "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
  281. "str q20, [%[outptr4], #16]\n"
  282. "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
  283. "str q28, [%[outptr4], #32]\n"
  284. "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
  285. "str q13, [%[outptr5], #0]\n"
  286. "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
  287. "str q21, [%[outptr5], #16]\n"
  288. "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
  289. "str q29, [%[outptr5], #32]\n"
  290. "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
  291. "str q14, [%[outptr6], #0]\n"
  292. "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
  293. "str q22, [%[outptr6], #16]\n"
  294. "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
  295. "str q30, [%[outptr6], #32]\n"
  296. "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
  297. "str q15, [%[outptr7], #0]\n"
  298. "b 3f\n"
  299. // Detached final iteration (odd K)
  300. "2:\n"
  301. "sdot v8.4s , %[b0].16b, %[a0].4b[0]\n"
  302. "ldr %q[b2], [%[b_ptr], #32]\n"
  303. "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
  304. "sdot v9.4s , %[b0].16b, %[a0].4b[1]\n"
  305. "str q8, [%[outptr0], #0]\n"
  306. "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
  307. "str q16, [%[outptr0], #16]\n"
  308. "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
  309. "add %[b_ptr], %[b_ptr], #48\n"
  310. "add %[a_ptr], %[a_ptr], #32\n"
  311. "str q24, [%[outptr0], #32]\n"
  312. "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
  313. "str q9, [%[outptr1], #0]\n"
  314. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  315. "str q17, [%[outptr1], #16]\n"
  316. "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
  317. "str q25, [%[outptr1], #32]\n"
  318. "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
  319. "str q10, [%[outptr2], #0]\n"
  320. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  321. "str q18, [%[outptr2], #16]\n"
  322. "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
  323. "str q26, [%[outptr2], #32]\n"
  324. "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
  325. "str q11, [%[outptr3], #0]\n"
  326. "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
  327. "str q19, [%[outptr3], #16]\n"
  328. "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
  329. "str q27, [%[outptr3], #32]\n"
  330. "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
  331. "str q12, [%[outptr4], #0]\n"
  332. "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
  333. "str q20, [%[outptr4], #16]\n"
  334. "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
  335. "str q28, [%[outptr4], #32]\n"
  336. "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
  337. "str q13, [%[outptr5], #0]\n"
  338. "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
  339. "str q21, [%[outptr5], #16]\n"
  340. "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
  341. "str q29, [%[outptr5], #32]\n"
  342. "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
  343. "str q14, [%[outptr6], #0]\n"
  344. "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
  345. "str q22, [%[outptr6], #16]\n"
  346. "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
  347. "str q30, [%[outptr6], #32]\n"
  348. "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
  349. "str q15, [%[outptr7], #0]\n"
  350. // Common tail
  351. "3:\n"
  352. "str q23, [%[outptr7], #16]\n"
  353. "str q31, [%[outptr7], #32]\n"
  354. :
  355. [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr),[oddk] "+r" (oddk),
  356. [is_first_k] "+r" (is_first_k), [k] "+r" (k), [LDC] "+r" (LDC),
  357. [a0] "=w" (a0), [a1] "=w" (a1), [a0a] "=w" (a0a), [a1a] "=w" (a1a),
  358. [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2),
  359. [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
  360. [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
  361. [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
  362. [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7)
  363. :
  364. : "v8", "v9", "v10", "v11", "v12", "v13", "v14",
  365. "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23",
  366. "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc",
  367. "memory"
  368. );
  369. }
  370. #else
  371. MEGDNN_ATTRIBUTE_TARGET("dotprod")
  372. static void kern_8x12(
  373. const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
  374. bool is_first_k) {
  375. K /= 4;
  376. const int8_t* a_ptr = packA;
  377. const int8_t* b_ptr = packB;
  378. // Fix up for odd lengths - set a flag if K is odd, but make
  379. // sure we round up the iteration count.
  380. int oddk = (K & 1);
  381. int k = K / 2;
  382. int32x4_t a0;
  383. int32x4_t a1;
  384. int32x4_t b0;
  385. int32x4_t b1;
  386. int32x4_t b2;
  387. int32x4_t a0a;
  388. int32x4_t a1a;
  389. LDC = LDC * sizeof(int32_t);
  390. int32_t* outptr0 = output;
  391. int32_t* outptr1;
  392. int32_t* outptr2;
  393. int32_t* outptr3;
  394. int32_t* outptr4;
  395. int32_t* outptr5;
  396. int32_t* outptr6;
  397. int32_t* outptr7;
  398. asm volatile(
  399. // load accumulator C
  400. "add %[outptr1], %[outptr0], %x[LDC]\n"
  401. "add %[outptr2], %[outptr1], %x[LDC]\n"
  402. "add %[outptr3], %[outptr2], %x[LDC]\n"
  403. "add %[outptr4], %[outptr3], %x[LDC]\n"
  404. "add %[outptr5], %[outptr4], %x[LDC]\n"
  405. "add %[outptr6], %[outptr5], %x[LDC]\n"
  406. "add %[outptr7], %[outptr6], %x[LDC]\n"
  407. "cmp %w[is_first_k], #1\n"
  408. "beq 1f\n"
  409. // we can not use ld1, as it can not encode {v8, v16, v24}
  410. "ldp q8, q16, [%[outptr0]]\n"
  411. "ldr q24, [%[outptr0], #32]\n"
  412. "ldp q9, q17, [%[outptr1]]\n"
  413. "ldr q25, [%[outptr1], #32]\n"
  414. "ldp q10, q18, [%[outptr2]]\n"
  415. "ldr q26, [%[outptr2], #32]\n"
  416. "ldp q11, q19, [%[outptr3]]\n"
  417. "ldr q27, [%[outptr3], #32]\n"
  418. "ldp q12, q20, [%[outptr4]]\n"
  419. "ldr q28, [%[outptr4], #32]\n"
  420. "ldp q13, q21, [%[outptr5]]\n"
  421. "ldr q29, [%[outptr5], #32]\n"
  422. "ldp q14, q22, [%[outptr6]]\n"
  423. "ldr q30, [%[outptr6], #32]\n"
  424. "ldp q15, q23, [%[outptr7]]\n"
  425. "ldr q31, [%[outptr7], #32]\n"
  426. "b 2f\n"
  427. "1:\n"
  428. "eor v8.16b, v8.16b, v8.16b\n"
  429. "eor v9.16b, v9.16b, v9.16b\n"
  430. "eor v10.16b, v10.16b, v10.16b\n"
  431. "eor v11.16b, v11.16b, v11.16b\n"
  432. "eor v12.16b, v12.16b, v12.16b\n"
  433. "eor v13.16b, v13.16b, v13.16b\n"
  434. "eor v14.16b, v14.16b, v14.16b\n"
  435. "eor v15.16b, v15.16b, v15.16b\n"
  436. "eor v16.16b, v16.16b, v16.16b\n"
  437. "eor v17.16b, v17.16b, v17.16b\n"
  438. "eor v18.16b, v18.16b, v18.16b\n"
  439. "eor v19.16b, v19.16b, v19.16b\n"
  440. "eor v20.16b, v20.16b, v20.16b\n"
  441. "eor v21.16b, v21.16b, v21.16b\n"
  442. "eor v22.16b, v22.16b, v22.16b\n"
  443. "eor v23.16b, v23.16b, v23.16b\n"
  444. "eor v24.16b, v24.16b, v24.16b\n"
  445. "eor v25.16b, v25.16b, v25.16b\n"
  446. "eor v26.16b, v26.16b, v26.16b\n"
  447. "eor v27.16b, v27.16b, v27.16b\n"
  448. "eor v28.16b, v28.16b, v28.16b\n"
  449. "eor v29.16b, v29.16b, v29.16b\n"
  450. "eor v30.16b, v30.16b, v30.16b\n"
  451. "eor v31.16b, v31.16b, v31.16b\n"
  452. "2: \n"
  453. "cbz %w[oddk], 3f\n"
  454. // parse the oddk
  455. "ldr %q[a0], [%[a_ptr]], #16\n"
  456. "ldr %q[a1], [%[a_ptr]], #16\n"
  457. "ldr %q[b0], [%[b_ptr]], #16\n"
  458. "ldr %q[b1], [%[b_ptr]], #16\n"
  459. "ldr %q[b2], [%[b_ptr]], #16\n"
  460. "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n"
  461. "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n"
  462. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  463. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  464. "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
  465. "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
  466. "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
  467. "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
  468. "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
  469. "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
  470. "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
  471. "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
  472. "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
  473. "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
  474. "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
  475. "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
  476. "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
  477. "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
  478. "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
  479. "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
  480. "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
  481. "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
  482. "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
  483. "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
  484. "cbz %w[k], 4f\n"
  485. // Loop proper
  486. "3:\n"
  487. "ldr %q[a0], [%[a_ptr]], #16\n"
  488. "ldr %q[a1], [%[a_ptr]], #16\n"
  489. "ldr %q[a0a], [%[a_ptr]], #16\n"
  490. "ldr %q[a1a], [%[a_ptr]], #16\n"
  491. "ldr %q[b0], [%[b_ptr]], #16\n"
  492. "ldr %q[b1], [%[b_ptr]], #16\n"
  493. "ldr %q[b2], [%[b_ptr]], #16\n"
  494. "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n"
  495. "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n"
  496. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  497. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  498. "sdot v12.4s, %[b0].16b, %[a1].4b[0]\n"
  499. "sdot v13.4s, %[b0].16b, %[a1].4b[1]\n"
  500. "sdot v14.4s, %[b0].16b, %[a1].4b[2]\n"
  501. "sdot v15.4s, %[b0].16b, %[a1].4b[3]\n"
  502. "sdot v16.4s, %[b1].16b, %[a0].4b[0]\n"
  503. "sdot v17.4s, %[b1].16b, %[a0].4b[1]\n"
  504. "sdot v18.4s, %[b1].16b, %[a0].4b[2]\n"
  505. "sdot v19.4s, %[b1].16b, %[a0].4b[3]\n"
  506. "sdot v20.4s, %[b1].16b, %[a1].4b[0]\n"
  507. "sdot v21.4s, %[b1].16b, %[a1].4b[1]\n"
  508. "sdot v22.4s, %[b1].16b, %[a1].4b[2]\n"
  509. "sdot v23.4s, %[b1].16b, %[a1].4b[3]\n"
  510. "sdot v24.4s, %[b2].16b, %[a0].4b[0]\n"
  511. "sdot v25.4s, %[b2].16b, %[a0].4b[1]\n"
  512. "sdot v26.4s, %[b2].16b, %[a0].4b[2]\n"
  513. "sdot v27.4s, %[b2].16b, %[a0].4b[3]\n"
  514. "sdot v28.4s, %[b2].16b, %[a1].4b[0]\n"
  515. "sdot v29.4s, %[b2].16b, %[a1].4b[1]\n"
  516. "sdot v30.4s, %[b2].16b, %[a1].4b[2]\n"
  517. "sdot v31.4s, %[b2].16b, %[a1].4b[3]\n"
  518. "ldr %q[b0], [%[b_ptr]], #16\n"
  519. "ldr %q[b1], [%[b_ptr]], #16\n"
  520. "ldr %q[b2], [%[b_ptr]], #16\n"
  521. "sdot v8.4s, %[b0].16b, %[a0a].4b[0]\n"
  522. "sdot v9.4s, %[b0].16b, %[a0a].4b[1]\n"
  523. "sdot v10.4s, %[b0].16b, %[a0a].4b[2]\n"
  524. "sdot v11.4s, %[b0].16b, %[a0a].4b[3]\n"
  525. "sdot v12.4s, %[b0].16b, %[a1a].4b[0]\n"
  526. "sdot v13.4s, %[b0].16b, %[a1a].4b[1]\n"
  527. "sdot v14.4s, %[b0].16b, %[a1a].4b[2]\n"
  528. "sdot v15.4s, %[b0].16b, %[a1a].4b[3]\n"
  529. "sdot v16.4s, %[b1].16b, %[a0a].4b[0]\n"
  530. "sdot v17.4s, %[b1].16b, %[a0a].4b[1]\n"
  531. "sdot v18.4s, %[b1].16b, %[a0a].4b[2]\n"
  532. "sdot v19.4s, %[b1].16b, %[a0a].4b[3]\n"
  533. "sdot v20.4s, %[b1].16b, %[a1a].4b[0]\n"
  534. "sdot v21.4s, %[b1].16b, %[a1a].4b[1]\n"
  535. "sdot v22.4s, %[b1].16b, %[a1a].4b[2]\n"
  536. "sdot v23.4s, %[b1].16b, %[a1a].4b[3]\n"
  537. "sdot v24.4s, %[b2].16b, %[a0a].4b[0]\n"
  538. "sdot v25.4s, %[b2].16b, %[a0a].4b[1]\n"
  539. "sdot v26.4s, %[b2].16b, %[a0a].4b[2]\n"
  540. "sdot v27.4s, %[b2].16b, %[a0a].4b[3]\n"
  541. "sdot v28.4s, %[b2].16b, %[a1a].4b[0]\n"
  542. "sdot v29.4s, %[b2].16b, %[a1a].4b[1]\n"
  543. "sdot v30.4s, %[b2].16b, %[a1a].4b[2]\n"
  544. "sdot v31.4s, %[b2].16b, %[a1a].4b[3]\n"
  545. "subs %w[k], %w[k], #1\n"
  546. "bne 3b\n"
  547. "4:\n"
  548. "stp q8, q16, [%[outptr0]]\n"
  549. "str q24, [%[outptr0], #32]\n"
  550. "stp q9, q17, [%[outptr1]]\n"
  551. "str q25, [%[outptr1], #32]\n"
  552. "stp q10, q18, [%[outptr2]]\n"
  553. "str q26, [%[outptr2], #32]\n"
  554. "stp q11, q19, [%[outptr3]]\n"
  555. "str q27, [%[outptr3], #32]\n"
  556. "stp q12, q20, [%[outptr4]]\n"
  557. "str q28, [%[outptr4], #32]\n"
  558. "stp q13, q21, [%[outptr5]]\n"
  559. "str q29, [%[outptr5], #32]\n"
  560. "stp q14, q22, [%[outptr6]]\n"
  561. "str q30, [%[outptr6], #32]\n"
  562. "stp q15, q23, [%[outptr7]]\n"
  563. "str q31, [%[outptr7], #32]\n"
  564. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1),
  565. [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1),
  566. [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk),
  567. [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0),
  568. [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
  569. [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6),
  570. [outptr7] "=r"(outptr7)
  571. :
  572. : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
  573. "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
  574. "v29", "v30", "v31", "cc", "memory");
  575. }
  576. #endif
  577. // Overview of register layout:
  578. //
  579. // A (12x4)x2 cell of Rhs is stored in 8bit in q1-q3, q5-q7.
  580. // A 4x4x2 cell of Lhs is stored in 8bit in q0, q4.
  581. // A 4x12 block of accumulators is stored in 32bit in q8--q19.
  582. //
  583. // +--------+--------+--------+
  584. // |v1[0-16]|v2[0-16]|v3[0-16]|
  585. // Rhs +--------+--------+--------+
  586. // |v5[0-16]|v6[0-16]|v7[0-16]|
  587. // +--------+--------+--------+
  588. //
  589. // | | | |
  590. //
  591. // Lhs | | | |
  592. //
  593. // +-------+-------+ - - - - +--------+--------+--------+
  594. // |v0[0-4]|v4[0-4]| | v8[0-4]|v12[0-4]|v16[0-4]|
  595. // |v0[0-4]|v4[0-4]| | v9[0-4]|v13[0-4]|v17[0-4]|
  596. // |v0[0-4]|v4[0-4]| |v10[0-4]|v14[0-4]|v18[0-4]|
  597. // |v0[0-4]|v4[0-4]| |v11[0-4]|v15[0-4]|v19[0-4]|
  598. // +-------+-------+ - - - - +--------+--------+--------+
  599. //
  600. // Accumulator
  601. MEGDNN_ATTRIBUTE_TARGET("dotprod")
  602. static void kern_4x12(
  603. const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
  604. bool is_first_k, int m_remain) {
  605. K /= 4;
  606. const int8_t* a_ptr = packA;
  607. const int8_t* b_ptr = packB;
  608. // Fix up for odd lengths - set a flag if K is odd, but make
  609. // sure we round up the iteration count.
  610. int oddk = (K & 1);
  611. int k = K / 2;
  612. int32x4_t a0;
  613. int32x4_t b0;
  614. int32x4_t b1;
  615. int32x4_t b2;
  616. int32x4_t a0a;
  617. int32x4_t b0a;
  618. int32x4_t b1a;
  619. int32x4_t b2a;
  620. LDC = LDC * sizeof(int32_t);
  621. int32_t* outptr0 = output;
  622. int32_t* outptr1;
  623. int32_t* outptr2;
  624. int32_t* outptr3;
  625. size_t x0;
  626. // clang-format off
  627. #define LOAD_LINE(v1, v2, v3, m) \
  628. "cbz %[x0], 100f\n" \
  629. "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \
  630. "ldr " v3 ", [%[outptr" m "], #32]\n" \
  631. "subs %[x0], %[x0], #1\n"
  632. #define LOAD_C \
  633. "mov %[x0], %x[m_remain]\n" \
  634. LOAD_LINE("q8", "q12", "q16", "0") \
  635. LOAD_LINE("q9", "q13", "q17", "1") \
  636. LOAD_LINE("q10", "q14", "q18", "2") \
  637. LOAD_LINE("q11", "q15", "q19", "3") \
  638. "100:\n"
  639. #define STORE_LINE(v1, v2, v3, m) \
  640. "cbz %[x0], 101f\n" \
  641. "stp " v1 "," v2", [%[outptr" m "]]\n" \
  642. "str " v3 ", [%[outptr" m "], #32]\n" \
  643. "subs %[x0], %[x0], #1\n"
  644. #define STORE_C \
  645. "mov %[x0], %x[m_remain]\n" \
  646. STORE_LINE("q8", "q12", "q16", "0") \
  647. STORE_LINE("q9", "q13", "q17", "1") \
  648. STORE_LINE("q10", "q14", "q18", "2") \
  649. STORE_LINE("q11", "q15", "q19", "3") \
  650. "101:\n"
  651. // clang-format on
  652. asm volatile(
  653. // load accumulator C
  654. "add %[outptr1], %[outptr0], %x[LDC]\n"
  655. "add %[outptr2], %[outptr1], %x[LDC]\n"
  656. "add %[outptr3], %[outptr2], %x[LDC]\n"
  657. "cmp %w[is_first_k], #1\n"
  658. "beq 1f\n" LOAD_C
  659. "b 2f\n"
  660. "1:\n"
  661. "eor v8.16b, v8.16b, v8.16b\n"
  662. "eor v9.16b, v9.16b, v9.16b\n"
  663. "eor v10.16b, v10.16b, v10.16b\n"
  664. "eor v11.16b, v11.16b, v11.16b\n"
  665. "eor v12.16b, v12.16b, v12.16b\n"
  666. "eor v13.16b, v13.16b, v13.16b\n"
  667. "eor v14.16b, v14.16b, v14.16b\n"
  668. "eor v15.16b, v15.16b, v15.16b\n"
  669. "eor v16.16b, v16.16b, v16.16b\n"
  670. "eor v17.16b, v17.16b, v17.16b\n"
  671. "eor v18.16b, v18.16b, v18.16b\n"
  672. "eor v19.16b, v19.16b, v19.16b\n"
  673. "2: \n"
  674. "cbz %w[oddk], 3f\n"
  675. // parse the oddk
  676. "ldr %q[a0], [%[a_ptr]], #16\n"
  677. "ldr %q[b0], [%[b_ptr]], #16\n"
  678. "ldr %q[b1], [%[b_ptr]], #16\n"
  679. "ldr %q[b2], [%[b_ptr]], #16\n"
  680. "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n"
  681. "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n"
  682. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  683. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  684. "sdot v12.4s, %[b1].16b, %[a0].4b[0]\n"
  685. "sdot v13.4s, %[b1].16b, %[a0].4b[1]\n"
  686. "sdot v14.4s, %[b1].16b, %[a0].4b[2]\n"
  687. "sdot v15.4s, %[b1].16b, %[a0].4b[3]\n"
  688. "sdot v16.4s, %[b2].16b, %[a0].4b[0]\n"
  689. "sdot v17.4s, %[b2].16b, %[a0].4b[1]\n"
  690. "sdot v18.4s, %[b2].16b, %[a0].4b[2]\n"
  691. "sdot v19.4s, %[b2].16b, %[a0].4b[3]\n"
  692. "cbz %w[k], 4f\n"
  693. // Loop proper
  694. "3:\n"
  695. "ldr %q[a0], [%[a_ptr]], #16\n"
  696. "ldr %q[b0], [%[b_ptr]], #16\n"
  697. "ldr %q[b1], [%[b_ptr]], #16\n"
  698. "ldr %q[b2], [%[b_ptr]], #16\n"
  699. "ldr %q[a0a], [%[a_ptr]], #16\n"
  700. "ldr %q[b0a], [%[b_ptr]], #16\n"
  701. "ldr %q[b1a], [%[b_ptr]], #16\n"
  702. "ldr %q[b2a], [%[b_ptr]], #16\n"
  703. "sdot v8.4s, %[b0].16b, %[a0].4b[0]\n"
  704. "sdot v9.4s, %[b0].16b, %[a0].4b[1]\n"
  705. "sdot v10.4s, %[b0].16b, %[a0].4b[2]\n"
  706. "sdot v11.4s, %[b0].16b, %[a0].4b[3]\n"
  707. "sdot v12.4s, %[b1].16b, %[a0].4b[0]\n"
  708. "sdot v13.4s, %[b1].16b, %[a0].4b[1]\n"
  709. "sdot v14.4s, %[b1].16b, %[a0].4b[2]\n"
  710. "sdot v15.4s, %[b1].16b, %[a0].4b[3]\n"
  711. "sdot v16.4s, %[b2].16b, %[a0].4b[0]\n"
  712. "sdot v17.4s, %[b2].16b, %[a0].4b[1]\n"
  713. "sdot v18.4s, %[b2].16b, %[a0].4b[2]\n"
  714. "sdot v19.4s, %[b2].16b, %[a0].4b[3]\n"
  715. "sdot v8.4s , %[b0a].16b, %[a0a].4b[0]\n"
  716. "sdot v9.4s , %[b0a].16b, %[a0a].4b[1]\n"
  717. "sdot v10.4s, %[b0a].16b, %[a0a].4b[2]\n"
  718. "sdot v11.4s, %[b0a].16b, %[a0a].4b[3]\n"
  719. "sdot v12.4s, %[b1a].16b, %[a0a].4b[0]\n"
  720. "sdot v13.4s, %[b1a].16b, %[a0a].4b[1]\n"
  721. "sdot v14.4s, %[b1a].16b, %[a0a].4b[2]\n"
  722. "sdot v15.4s, %[b1a].16b, %[a0a].4b[3]\n"
  723. "sdot v16.4s, %[b2a].16b, %[a0a].4b[0]\n"
  724. "sdot v17.4s, %[b2a].16b, %[a0a].4b[1]\n"
  725. "sdot v18.4s, %[b2a].16b, %[a0a].4b[2]\n"
  726. "sdot v19.4s, %[b2a].16b, %[a0a].4b[3]\n"
  727. "subs %w[k], %w[k], #1\n"
  728. "bne 3b\n"
  729. "4:\n" STORE_C
  730. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k),
  731. [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
  732. [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0),
  733. [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2),
  734. [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a),
  735. [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
  736. [x0] "=r"(x0)
  737. :
  738. : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
  739. "v19", "memory", "cc");
  740. #undef LOAD_LINE
  741. #undef LOAD_C
  742. #undef STORE_LINE
  743. #undef STORE_C
  744. }
  745. // Overview of register layout:
  746. //
  747. // A (4x4)x2 cell of Rhs is stored in 8bit in q2-q3.
  748. // A 8x4x2 cell of Lhs is stored in 8bit in q0-q1, q4-q5
  749. // A 8x4 block of accumulators is stored in 32bit in q6--q13.
  750. //
  751. // +--------+
  752. // |v2[0-16]|
  753. // Rhs +--------+
  754. // |v3[0-16]|
  755. // +--------+
  756. // | |
  757. //
  758. // Lhs | |
  759. //
  760. // +-------+-------+ - - - - +--------+
  761. // |v0[0-4]|v4[0-4]| | v6[0-4]|
  762. // |v0[0-4]|v4[0-4]| | v7[0-4]|
  763. // |v0[0-4]|v4[0-4]| | v8[0-4]|
  764. // |v0[0-4]|v4[0-4]| | v9[0-4]|
  765. // |v1[0-4]|v5[0-4]| |v10[0-4]|
  766. // |v1[0-4]|v5[0-4]| |v11[0-4]|
  767. // |v1[0-4]|v5[0-4]| |v12[0-4]|
  768. // |v1[0-4]|v5[0-4]| |v13[0-4]|
  769. // +-------+-------+ - - - - +---------+
  770. //
  771. // Accumulator
  772. MEGDNN_ATTRIBUTE_TARGET("dotprod")
  773. static void kern_8x4(
  774. const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
  775. bool is_first_k, int n_remain) {
  776. K /= 4;
  777. const int8_t* a_ptr = packA;
  778. const int8_t* b_ptr = packB;
  779. // Fix up for odd lengths - set a flag if K is odd, but make
  780. // sure we round up the iteration count.
  781. int oddk = (K & 1);
  782. int k = K / 2;
  783. int32x4_t a0;
  784. int32x4_t a1;
  785. int32x4_t b0;
  786. int32x4_t b0a;
  787. int32x4_t a0a;
  788. int32x4_t a1a;
  789. LDC = LDC * sizeof(int32_t);
  790. int32_t* outptr0 = output;
  791. int32_t* outptr1;
  792. int32_t* outptr2;
  793. int32_t* outptr3;
  794. int32_t* outptr4;
  795. int32_t* outptr5;
  796. int32_t* outptr6;
  797. int32_t* outptr7;
  798. size_t x0;
  799. // clang-format off
  800. #define LOAD_LINE(reg_index, n) \
  801. "mov %[x0], %[outptr" n "]\n" \
  802. "cmp %w[n_remain], #4\n" \
  803. "blt 100" n "f\n" \
  804. "ldr q" reg_index ", [%[x0]] \n" \
  805. "b 101" n "f\n" \
  806. "100" n ":\n" \
  807. "cmp %w[n_remain], #0\n" \
  808. "beq 101" n "f\n" \
  809. "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  810. "cmp %w[n_remain], #1\n" \
  811. "beq 101" n "f\n" \
  812. "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  813. "cmp %w[n_remain], #2\n" \
  814. "beq 101" n "f\n" \
  815. "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  816. "101" n ":\n"
  817. #define LOAD_C \
  818. LOAD_LINE("6", "0") \
  819. LOAD_LINE("7", "1") \
  820. LOAD_LINE("8", "2") \
  821. LOAD_LINE("9", "3") \
  822. LOAD_LINE("10", "4") \
  823. LOAD_LINE("11", "5") \
  824. LOAD_LINE("12", "6") \
  825. LOAD_LINE("13", "7")
  826. #define STORE_LINE(reg_index, n) \
  827. "mov %[x0], %[outptr" n "]\n" \
  828. "cmp %w[n_remain], #4\n" \
  829. "blt 102" n "f\n" \
  830. "str q" reg_index ", [%[x0]]\n" \
  831. "b 103" n "f\n" \
  832. "102" n ":\n" \
  833. "cmp %w[n_remain], #0\n" \
  834. "beq 103" n "f\n" \
  835. "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  836. "cmp %w[n_remain], #1\n" \
  837. "beq 103" n "f\n" \
  838. "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  839. "cmp %w[n_remain], #2\n" \
  840. "beq 103" n "f\n" \
  841. "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  842. "103" n ":\n"
  843. #define STORE_C \
  844. STORE_LINE("6", "0") \
  845. STORE_LINE("7", "1") \
  846. STORE_LINE("8", "2") \
  847. STORE_LINE("9", "3") \
  848. STORE_LINE("10", "4") \
  849. STORE_LINE("11", "5") \
  850. STORE_LINE("12", "6") \
  851. STORE_LINE("13", "7")
  852. // clang-format on
  853. asm volatile(
  854. // load accumulator C
  855. "add %[outptr1], %[outptr0], %x[LDC]\n"
  856. "add %[outptr2], %[outptr1], %x[LDC]\n"
  857. "add %[outptr3], %[outptr2], %x[LDC]\n"
  858. "add %[outptr4], %[outptr3], %x[LDC]\n"
  859. "add %[outptr5], %[outptr4], %x[LDC]\n"
  860. "add %[outptr6], %[outptr5], %x[LDC]\n"
  861. "add %[outptr7], %[outptr6], %x[LDC]\n"
  862. "cmp %w[is_first_k], #1\n"
  863. "beq 1f\n" LOAD_C
  864. "b 2f\n"
  865. "1:\n"
  866. "eor v6.16b, v6.16b, v6.16b\n"
  867. "eor v7.16b, v7.16b, v7.16b\n"
  868. "eor v8.16b, v8.16b, v8.16b\n"
  869. "eor v9.16b, v9.16b, v9.16b\n"
  870. "eor v10.16b, v10.16b, v10.16b\n"
  871. "eor v11.16b, v11.16b, v11.16b\n"
  872. "eor v12.16b, v12.16b, v12.16b\n"
  873. "eor v13.16b, v13.16b, v13.16b\n"
  874. "2: \n"
  875. "cbz %w[oddk], 3f\n"
  876. // parse the oddk
  877. "ldr %q[a0], [%[a_ptr]], #16\n"
  878. "ldr %q[b0], [%[b_ptr]], #16\n"
  879. "ldr %q[a1], [%[a_ptr]], #16\n"
  880. "sdot v6.4s , %[b0].16b, %[a0].4b[0]\n"
  881. "sdot v7.4s , %[b0].16b, %[a0].4b[1]\n"
  882. "sdot v8.4s, %[b0].16b, %[a0].4b[2]\n"
  883. "sdot v9.4s, %[b0].16b, %[a0].4b[3]\n"
  884. "sdot v10.4s, %[b0].16b, %[a1].4b[0]\n"
  885. "sdot v11.4s, %[b0].16b, %[a1].4b[1]\n"
  886. "sdot v12.4s, %[b0].16b, %[a1].4b[2]\n"
  887. "sdot v13.4s, %[b0].16b, %[a1].4b[3]\n"
  888. "cbz %w[k], 4f\n"
  889. // Loop proper
  890. "3:\n"
  891. "ldr %q[a0], [%[a_ptr]], #16\n"
  892. "ldr %q[b0], [%[b_ptr]], #16\n"
  893. "ldr %q[a1], [%[a_ptr]], #16\n"
  894. "ldr %q[a0a], [%[a_ptr]], #16\n"
  895. "ldr %q[a1a], [%[a_ptr]], #16\n"
  896. "ldr %q[b0a], [%[b_ptr]], #16\n"
  897. "sdot v6.4s , %[b0].16b, %[a0].4b[0]\n"
  898. "sdot v7.4s , %[b0].16b, %[a0].4b[1]\n"
  899. "sdot v8.4s, %[b0].16b, %[a0].4b[2]\n"
  900. "sdot v9.4s, %[b0].16b, %[a0].4b[3]\n"
  901. "sdot v10.4s, %[b0].16b, %[a1].4b[0]\n"
  902. "sdot v11.4s, %[b0].16b, %[a1].4b[1]\n"
  903. "sdot v12.4s, %[b0].16b, %[a1].4b[2]\n"
  904. "sdot v13.4s, %[b0].16b, %[a1].4b[3]\n"
  905. "sdot v6.4s , %[b0a].16b, %[a0a].4b[0]\n"
  906. "sdot v7.4s , %[b0a].16b, %[a0a].4b[1]\n"
  907. "sdot v8.4s, %[b0a].16b, %[a0a].4b[2]\n"
  908. "sdot v9.4s, %[b0a].16b, %[a0a].4b[3]\n"
  909. "sdot v10.4s, %[b0a].16b, %[a1a].4b[0]\n"
  910. "sdot v11.4s, %[b0a].16b, %[a1a].4b[1]\n"
  911. "sdot v12.4s, %[b0a].16b, %[a1a].4b[2]\n"
  912. "sdot v13.4s, %[b0a].16b, %[a1a].4b[3]\n"
  913. "subs %w[k], %w[k], #1\n"
  914. "bne 3b\n"
  915. "4:\n" STORE_C
  916. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
  917. [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k),
  918. [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0),
  919. [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a),
  920. [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1),
  921. [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4),
  922. [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7),
  923. [x0] "=r"(x0)
  924. :
  925. : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc");
  926. #undef LOAD_LINE
  927. #undef LOAD_C
  928. #undef STORE_LINE
  929. #undef STORE_C
  930. }
  931. // Overview of register layout:
  932. //
  933. // A 4x4x2 cell of Rhs is stored in 8bit in q2-q3.
  934. // A 4x4x2 cell of Lhs is stored in 8bit in q0-q1
  935. // A 4x4x2 block of accumulators is stored in 8bit in q4--q7.
  936. //
  937. // +--------+
  938. // | v2[0-7]|
  939. // Rhs +--------+
  940. // | v3[0-7]|
  941. // +--------+
  942. // | |
  943. //
  944. // Lhs | |
  945. //
  946. // +-------+-------+ - - - - +--------+
  947. // |v0[0-4]|v1[0-4]| | v4[0-7]|
  948. // |v0[0-4]|v1[0-4]| | v5[0-7]|
  949. // |v0[0-4]|v1[0-4]| | v6[0-7]|
  950. // |v0[0-4]|v1[0-4]| | v7[0-7]|
  951. // +-------+-------+ - - - - +--------+
  952. //
  953. // Accumulator
  954. MEGDNN_ATTRIBUTE_TARGET("dotprod")
  955. static void kern_4x4(
  956. const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC,
  957. bool is_first_k, int m_remain, int n_remain) {
  958. K /= 4;
  959. const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA);
  960. const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB);
  961. // Fix up for odd lengths - set a flag if K is odd, but make
  962. // sure we round up the iteration count.
  963. int oddk = (K & 1);
  964. int k = K / 2;
  965. int32x4_t a0;
  966. int32x4_t a0a;
  967. int32x4_t b0;
  968. int32x4_t b0a;
  969. LDC = LDC * sizeof(int32_t);
  970. int32_t* outptr0 = output;
  971. int32_t* outptr1;
  972. int32_t* outptr2;
  973. int32_t* outptr3;
  974. size_t x0, x1;
  975. // clang-format off
  976. #define LOAD_LINE(reg_index, n) \
  977. "cbz %[x1], 102f\n" \
  978. "mov %[x0], %[outptr" n "]\n" \
  979. "cmp %w[n_remain], #4\n" \
  980. "blt 100" n "f\n" \
  981. "ldr q" reg_index ", [%[x0]]\n" \
  982. "b 101" n "f\n" \
  983. "100" n ":\n" \
  984. "cmp %w[n_remain], #0\n" \
  985. "beq 101" n "f\n" \
  986. "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  987. "cmp %w[n_remain], #1\n" \
  988. "beq 101" n "f\n" \
  989. "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  990. "cmp %w[n_remain], #2\n" \
  991. "beq 101" n "f\n" \
  992. "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  993. "101" n ":\n" \
  994. "subs %[x1], %[x1], #1\n"
  995. #define LOAD_C \
  996. "mov %[x1], %x[m_remain]\n" \
  997. LOAD_LINE("4", "0") \
  998. LOAD_LINE("5", "1") \
  999. LOAD_LINE("6", "2") \
  1000. LOAD_LINE("7", "3") \
  1001. "102:\n"
  1002. #define STORE_LINE(reg_index, n) \
  1003. "cbz %[x1], 105f\n" \
  1004. "mov %[x0], %[outptr" n "]\n" \
  1005. "cmp %w[n_remain], #4\n" \
  1006. "blt 103" n "f\n" \
  1007. "str q" reg_index ", [%[x0]]\n" \
  1008. "b 104" n "f\n" \
  1009. "103" n ":\n" \
  1010. "cmp %w[n_remain], #0\n" \
  1011. "beq 104" n "f\n" \
  1012. "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
  1013. "cmp %w[n_remain], #1\n" \
  1014. "beq 104" n "f\n" \
  1015. "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
  1016. "cmp %w[n_remain], #2\n" \
  1017. "beq 104" n "f\n" \
  1018. "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
  1019. "104" n ":\n" \
  1020. "subs %[x1], %[x1], #1\n"
  1021. #define STORE_C \
  1022. "mov %[x1], %x[m_remain]\n" \
  1023. STORE_LINE("4", "0") \
  1024. STORE_LINE("5", "1") \
  1025. STORE_LINE("6", "2") \
  1026. STORE_LINE("7", "3") \
  1027. "105:\n"
  1028. // clang-format on
  1029. asm volatile(
  1030. // load accumulator C
  1031. "add %[outptr1], %[outptr0], %x[LDC]\n"
  1032. "add %[outptr2], %[outptr1], %x[LDC]\n"
  1033. "add %[outptr3], %[outptr2], %x[LDC]\n"
  1034. "cmp %w[is_first_k], #1\n"
  1035. "beq 1f\n" //
  1036. LOAD_C //
  1037. "b 2f\n"
  1038. "1:\n"
  1039. "eor v4.16b, v4.16b, v4.16b\n"
  1040. "eor v5.16b, v5.16b, v5.16b\n"
  1041. "eor v6.16b, v6.16b, v6.16b\n"
  1042. "eor v7.16b, v7.16b, v7.16b\n"
  1043. "2: \n"
  1044. "cbz %w[oddk], 3f\n"
  1045. // parse the oddk
  1046. "ldr %q[a0], [%[a_ptr]], #16\n"
  1047. "ldr %q[b0], [%[b_ptr]], #16\n"
  1048. "sdot v4.4s , %[b0].16b, %[a0].4b[0]\n"
  1049. "sdot v5.4s , %[b0].16b, %[a0].4b[1]\n"
  1050. "sdot v6.4s, %[b0].16b, %[a0].4b[2]\n"
  1051. "sdot v7.4s, %[b0].16b, %[a0].4b[3]\n"
  1052. "cbz %w[k], 4f\n"
  1053. // Loop proper
  1054. "3:\n"
  1055. "ldr %q[a0], [%[a_ptr]], #16\n"
  1056. "ldr %q[b0], [%[b_ptr]], #16\n"
  1057. "ldr %q[a0a], [%[a_ptr]], #16\n"
  1058. "ldr %q[b0a], [%[b_ptr]], #16\n"
  1059. "sdot v4.4s , %[b0].16b, %[a0].4b[0]\n"
  1060. "sdot v5.4s , %[b0].16b, %[a0].4b[1]\n"
  1061. "sdot v6.4s, %[b0].16b, %[a0].4b[2]\n"
  1062. "sdot v7.4s, %[b0].16b, %[a0].4b[3]\n"
  1063. "sdot v4.4s , %[b0a].16b, %[a0a].4b[0]\n"
  1064. "sdot v5.4s , %[b0a].16b, %[a0a].4b[1]\n"
  1065. "sdot v6.4s, %[b0a].16b, %[a0a].4b[2]\n"
  1066. "sdot v7.4s, %[b0a].16b, %[a0a].4b[3]\n"
  1067. "subs %w[k], %w[k], #1\n"
  1068. "bne 3b\n"
  1069. "4:\n" STORE_C
  1070. : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk),
  1071. [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain),
  1072. [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0),
  1073. [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0),
  1074. [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
  1075. [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1)
  1076. :
  1077. : "v4", "v5", "v6", "v7", "memory", "cc");
  1078. #undef LOAD_LINE
  1079. #undef LOAD_C
  1080. #undef STORE_LINE
  1081. #undef STORE_C
  1082. }
  1083. static void gemm_s8_8x12_pack_A_n(
  1084. dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
  1085. int kmax) {
  1086. int8_t zerobuff[16];
  1087. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1088. int y = y0;
  1089. for (; y + 7 < ymax; y += 8) {
  1090. const int8_t* inptr0 = inptr + y * ldin + k0;
  1091. const int8_t* inptr1 = inptr0 + ldin;
  1092. const int8_t* inptr2 = inptr1 + ldin;
  1093. const int8_t* inptr3 = inptr2 + ldin;
  1094. const int8_t* inptr4 = inptr3 + ldin;
  1095. const int8_t* inptr5 = inptr4 + ldin;
  1096. const int8_t* inptr6 = inptr5 + ldin;
  1097. const int8_t* inptr7 = inptr6 + ldin;
  1098. prefetch_2x(inptr0);
  1099. prefetch_2x(inptr1);
  1100. prefetch_2x(inptr2);
  1101. prefetch_2x(inptr3);
  1102. prefetch_2x(inptr4);
  1103. prefetch_2x(inptr5);
  1104. prefetch_2x(inptr6);
  1105. prefetch_2x(inptr7);
  1106. int K = kmax - k0;
  1107. //! read 8 * 4 in each row
  1108. for (; K > 15; K -= 16) {
  1109. interleave_8x4_4_b(
  1110. inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
  1111. outptr);
  1112. }
  1113. if (K > 0) {
  1114. interleave_8(
  1115. inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
  1116. outptr, 4, K);
  1117. }
  1118. }
  1119. for (; y < ymax; y += 4) {
  1120. const int8_t* inptr0 = inptr + y * ldin + k0;
  1121. const int8_t* inptr1 = inptr0 + ldin;
  1122. const int8_t* inptr2 = inptr1 + ldin;
  1123. const int8_t* inptr3 = inptr2 + ldin;
  1124. prefetch_2x(inptr0);
  1125. prefetch_2x(inptr1);
  1126. prefetch_2x(inptr2);
  1127. prefetch_2x(inptr3);
  1128. int K = kmax - k0;
  1129. //! read 4 * 4 in each row
  1130. for (; K > 15; K -= 16) {
  1131. if (y + 3 >= ymax) {
  1132. switch (y + 3 - ymax) {
  1133. case 2:
  1134. inptr1 = zerobuff;
  1135. case 1:
  1136. inptr2 = zerobuff;
  1137. case 0:
  1138. inptr3 = zerobuff;
  1139. break;
  1140. default:
  1141. megdnn_assert(0);
  1142. }
  1143. }
  1144. interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr);
  1145. }
  1146. if (K > 0) {
  1147. if (y + 3 >= ymax) {
  1148. switch (y + 3 - ymax) {
  1149. case 2:
  1150. inptr1 = zerobuff;
  1151. case 1:
  1152. inptr2 = zerobuff;
  1153. case 0:
  1154. inptr3 = zerobuff;
  1155. break;
  1156. default:
  1157. megdnn_assert(0);
  1158. }
  1159. }
  1160. interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K);
  1161. }
  1162. }
  1163. }
  1164. static void gemm_s8_8x12_pack_A_t(
  1165. dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
  1166. int8_t zerobuff[16];
  1167. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1168. const int ksize = kmax - k0;
  1169. const int ksize8 = round_up<int>(ksize, 4) * 8;
  1170. const int ksize4 = round_up(ksize, 4) * 4;
  1171. int8_t* outptr = out;
  1172. int8_t* outptr_base = out;
  1173. //! 4x4 block output start pos
  1174. int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8;
  1175. int k = k0;
  1176. for (; k < kmax; k += 4) {
  1177. const int8_t* inptr0 = in + k * ldin + x0;
  1178. const int8_t* inptr1 = inptr0 + ldin;
  1179. const int8_t* inptr2 = inptr1 + ldin;
  1180. const int8_t* inptr3 = inptr2 + ldin;
  1181. prefetch_2x(inptr0);
  1182. prefetch_2x(inptr1);
  1183. prefetch_2x(inptr2);
  1184. prefetch_2x(inptr3);
  1185. int x = x0;
  1186. outptr = outptr_base;
  1187. for (; x + 7 < xmax; x += 8) {
  1188. if (k + 3 >= kmax) {
  1189. switch (k + 3 - kmax) {
  1190. case 2:
  1191. inptr1 = zerobuff;
  1192. case 1:
  1193. inptr2 = zerobuff;
  1194. case 0:
  1195. inptr3 = zerobuff;
  1196. break;
  1197. default:
  1198. megdnn_assert(0);
  1199. }
  1200. }
  1201. transpose_8x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
  1202. outptr += ksize8;
  1203. }
  1204. outptr = outptr_base4;
  1205. for (; x + 3 < xmax; x += 4) {
  1206. if (k + 3 >= kmax) {
  1207. switch (k + 3 - kmax) {
  1208. case 2:
  1209. inptr1 = zerobuff;
  1210. case 1:
  1211. inptr2 = zerobuff;
  1212. case 0:
  1213. inptr3 = zerobuff;
  1214. break;
  1215. default:
  1216. megdnn_assert(0);
  1217. }
  1218. }
  1219. transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4);
  1220. outptr += ksize4;
  1221. }
  1222. if (x < xmax) {
  1223. if (k + 3 >= kmax) {
  1224. switch (k + 3 - kmax) {
  1225. case 2:
  1226. inptr1 = zerobuff;
  1227. case 1:
  1228. inptr2 = zerobuff;
  1229. case 0:
  1230. inptr3 = zerobuff;
  1231. break;
  1232. default:
  1233. megdnn_assert(0);
  1234. }
  1235. }
  1236. transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
  1237. }
  1238. outptr_base += 8 * 4;
  1239. outptr_base4 += 4 * 4;
  1240. }
  1241. }
  1242. static void gemm_s8_8x12_pack_B_n(
  1243. dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) {
  1244. int8_t zerobuff[16];
  1245. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1246. const int ksize = kmax - k0;
  1247. const int ksize12 = round_up<int>(ksize, 4) * 12;
  1248. const int ksize4 = round_up(ksize, 4) * 4;
  1249. int8_t* outptr = out;
  1250. int8_t* outptr_base = out;
  1251. //! 4x4 block output start pos
  1252. int8_t* outptr_base4 = out + ((xmax - x0) / 12) * ksize12;
  1253. int k = k0;
  1254. for (; k < kmax; k += 4) {
  1255. const int8_t* inptr0 = in + k * ldin + x0;
  1256. const int8_t* inptr1 = inptr0 + ldin;
  1257. const int8_t* inptr2 = inptr1 + ldin;
  1258. const int8_t* inptr3 = inptr2 + ldin;
  1259. prefetch_2x(inptr0);
  1260. prefetch_2x(inptr1);
  1261. prefetch_2x(inptr2);
  1262. prefetch_2x(inptr3);
  1263. int x = x0;
  1264. outptr = outptr_base;
  1265. for (; x + 11 < xmax; x += 12) {
  1266. if (k + 3 >= kmax) {
  1267. switch (k + 3 - kmax) {
  1268. case 2:
  1269. inptr1 = zerobuff;
  1270. case 1:
  1271. inptr2 = zerobuff;
  1272. case 0:
  1273. inptr3 = zerobuff;
  1274. break;
  1275. default:
  1276. megdnn_assert(0);
  1277. }
  1278. }
  1279. transpose_12x4_1_b(inptr0, inptr1, inptr2, inptr3, outptr);
  1280. outptr += ksize12;
  1281. }
  1282. outptr = outptr_base4;
  1283. for (; x + 3 < xmax; x += 4) {
  1284. if (k + 3 >= kmax) {
  1285. switch (k + 3 - kmax) {
  1286. case 2:
  1287. inptr1 = zerobuff;
  1288. case 1:
  1289. inptr2 = zerobuff;
  1290. case 0:
  1291. inptr3 = zerobuff;
  1292. break;
  1293. default:
  1294. megdnn_assert(0);
  1295. }
  1296. }
  1297. transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, 4);
  1298. outptr += ksize4;
  1299. }
  1300. if (x < xmax) {
  1301. if (k + 3 >= kmax) {
  1302. switch (k + 3 - kmax) {
  1303. case 2:
  1304. inptr1 = zerobuff;
  1305. case 1:
  1306. inptr2 = zerobuff;
  1307. case 0:
  1308. inptr3 = zerobuff;
  1309. break;
  1310. default:
  1311. megdnn_assert(0);
  1312. }
  1313. }
  1314. transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
  1315. }
  1316. outptr_base += 12 * 4;
  1317. outptr_base4 += 4 * 4;
  1318. }
  1319. }
  1320. static void gemm_s8_8x12_pack_B_t(
  1321. dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0,
  1322. int kmax) {
  1323. int8_t zerobuff[16];
  1324. std::memset(zerobuff, 0, sizeof(int8_t) * 16);
  1325. int y = y0;
  1326. for (; y + 11 < ymax; y += 12) {
  1327. const int8_t* inptr0 = inptr + y * ldin + k0;
  1328. const int8_t* inptr1 = inptr0 + ldin;
  1329. const int8_t* inptr2 = inptr1 + ldin;
  1330. const int8_t* inptr3 = inptr2 + ldin;
  1331. const int8_t* inptr4 = inptr3 + ldin;
  1332. const int8_t* inptr5 = inptr4 + ldin;
  1333. const int8_t* inptr6 = inptr5 + ldin;
  1334. const int8_t* inptr7 = inptr6 + ldin;
  1335. const int8_t* inptr8 = inptr7 + ldin;
  1336. const int8_t* inptr9 = inptr8 + ldin;
  1337. const int8_t* inptr10 = inptr9 + ldin;
  1338. const int8_t* inptr11 = inptr10 + ldin;
  1339. prefetch_2x(inptr0);
  1340. prefetch_2x(inptr1);
  1341. prefetch_2x(inptr2);
  1342. prefetch_2x(inptr3);
  1343. prefetch_2x(inptr4);
  1344. prefetch_2x(inptr5);
  1345. prefetch_2x(inptr6);
  1346. prefetch_2x(inptr7);
  1347. prefetch_2x(inptr8);
  1348. prefetch_2x(inptr9);
  1349. prefetch_2x(inptr10);
  1350. prefetch_2x(inptr11);
  1351. int K = kmax - k0;
  1352. //! read 12 * 4 in each row
  1353. for (; K > 15; K -= 16) {
  1354. interleave_12x4_4_b(
  1355. inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
  1356. inptr8, inptr9, inptr10, inptr11, outptr);
  1357. }
  1358. if (K > 0) {
  1359. interleave_12(
  1360. inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7,
  1361. inptr8, inptr9, inptr10, inptr11, outptr, 4, K);
  1362. }
  1363. }
  1364. for (; y < ymax; y += 4) {
  1365. const int8_t* inptr0 = inptr + y * ldin + k0;
  1366. const int8_t* inptr1 = inptr0 + ldin;
  1367. const int8_t* inptr2 = inptr1 + ldin;
  1368. const int8_t* inptr3 = inptr2 + ldin;
  1369. prefetch_2x(inptr0);
  1370. prefetch_2x(inptr1);
  1371. prefetch_2x(inptr2);
  1372. prefetch_2x(inptr3);
  1373. int K = kmax - k0;
  1374. //! read 4 * 4 in each row
  1375. for (; K > 15; K -= 16) {
  1376. if (y + 3 >= ymax) {
  1377. switch (y + 3 - ymax) {
  1378. case 2:
  1379. inptr1 = zerobuff;
  1380. case 1:
  1381. inptr2 = zerobuff;
  1382. case 0:
  1383. inptr3 = zerobuff;
  1384. break;
  1385. default:
  1386. megdnn_assert(0);
  1387. }
  1388. }
  1389. interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, outptr);
  1390. }
  1391. if (K > 0) {
  1392. if (y + 3 >= ymax) {
  1393. switch (y + 3 - ymax) {
  1394. case 2:
  1395. inptr1 = zerobuff;
  1396. case 1:
  1397. inptr2 = zerobuff;
  1398. case 0:
  1399. inptr3 = zerobuff;
  1400. break;
  1401. default:
  1402. megdnn_assert(0);
  1403. }
  1404. }
  1405. interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, K);
  1406. }
  1407. }
  1408. }
  1409. } // namespace matmul_8x12x4
  1410. } // namespace aarch64
  1411. } // namespace megdnn
  1412. #endif
  1413. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台