|
- /**
- * \file dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
- * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- *
- * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- */
-
- #if !(__ARM_FEATURE_DOTPROD)
- #include "src/aarch64/matrix_mul/asm/common.h"
- #include "src/arm_common/simd_macro/marm_neon.h"
-
- namespace megdnn {
- namespace aarch64 {
- namespace matmul_8x8x8 {
-
- /**
- * Overview of register layout:
- *
- * A 8x8x8 cell of Rhs is stored in 8bit in q26-q27
- * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7
- * A 8x8 block of accumulators is stored in 32bit in q8-q23
- *
- * +--------+--------+
- * |v26[0-8]|v27[0-8]|
- * Rhs +--------+--------+
- * Lhs | | |
- *
- * +--------+ - - - - +-----------------+
- * |v0[0-8]| | v8[0-4]| v9[0-4]|
- * |v1[0-8]| |v10[0-4]|v11[0-4]|
- * |v2[0-8]| |v12[0-4]|v13[0-4]|
- * |v3[0-8]| |v14[0-4]|v15[0-4]|
- * |v4[0-8]| |v16[0-4]|v17[0-4]|
- * |v5[0-8]| |v18[0-4]|v19[0-4]|
- * |v6[0-8]| |v20[0-4]|v21[0-4]|
- * |v7[0-8]| |v22[0-4]|v23[0-4]|
- * +--------+ - - - - +-----------------+
- *
- * Accumulator
- */
-
- static void kern_8x8(const int8_t* packA, const int8_t* packB, int K,
- int32_t* output, int LDC, bool is_first_k) {
- K /= 8;
- const int8_t* a_ptr = packA;
- const int8_t* b_ptr = packB;
-
- LDC = LDC * sizeof(int32_t);
-
- asm volatile(
- // load accumulator C
- "add x1, %[output], %x[LDC]\n"
- "add x2, x1, %x[LDC]\n"
- "add x3, x2, %x[LDC]\n"
- "add x4, x3, %x[LDC]\n"
- "add x5, x4, %x[LDC]\n"
- "add x6, x5, %x[LDC]\n"
- "add x7, x6, %x[LDC]\n"
- "cmp %w[is_first_k], #1\n"
- "beq 1f\n"
-
- "ldp q8, q9, [%[output]]\n"
- "ldp q10, q11, [x1]\n"
- "ldp q12, q13, [x2]\n"
- "ldp q14, q15, [x3]\n"
- "ldp q16, q17, [x4]\n"
- "ldp q18, q19, [x5]\n"
- "ldp q20, q21, [x6]\n"
- "ldp q22, q23, [x7]\n"
- "b 2f\n"
-
- "1:\n"
- "eor v8.16b, v8.16b, v8.16b\n"
- "eor v9.16b, v9.16b, v9.16b\n"
- "eor v10.16b, v10.16b, v10.16b\n"
- "eor v11.16b, v11.16b, v11.16b\n"
- "eor v12.16b, v12.16b, v12.16b\n"
- "eor v13.16b, v13.16b, v13.16b\n"
- "eor v14.16b, v14.16b, v14.16b\n"
- "eor v15.16b, v15.16b, v15.16b\n"
- "eor v16.16b, v16.16b, v16.16b\n"
- "eor v17.16b, v17.16b, v17.16b\n"
- "eor v18.16b, v18.16b, v18.16b\n"
- "eor v19.16b, v19.16b, v19.16b\n"
- "eor v20.16b, v20.16b, v20.16b\n"
- "eor v21.16b, v21.16b, v21.16b\n"
- "eor v22.16b, v22.16b, v22.16b\n"
- "eor v23.16b, v23.16b, v23.16b\n"
-
- "2: \n"
- "ld1 {v26.8b}, [%[b_ptr]], 8\n"
- "ld1 {v0.8b}, [%[a_ptr]], 8\n"
- "ld1 {v1.8b}, [%[a_ptr]], 8\n"
- "ld1 {v2.8b}, [%[a_ptr]], 8\n"
- "ld1 {v3.8b}, [%[a_ptr]], 8\n"
- "ld1 {v4.8b}, [%[a_ptr]], 8\n"
- "ld1 {v5.8b}, [%[a_ptr]], 8\n"
- "ld1 {v6.8b}, [%[a_ptr]], 8\n"
- "ld1 {v7.8b}, [%[a_ptr]], 8\n"
- "sshll v26.8h, v26.8b, #0\n"
- "sshll v0.8h, v0.8b, #0\n"
- "sshll v1.8h, v1.8b, #0\n"
- "sshll v2.8h, v2.8b, #0\n"
- "sshll v3.8h, v3.8b, #0\n"
- "sshll v4.8h, v4.8b, #0\n"
- "sshll v5.8h, v5.8b, #0\n"
- "sshll v6.8h, v6.8b, #0\n"
- "sshll v7.8h, v7.8b, #0\n"
-
- "ld1 {v27.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v26.4h, v0.h[0]\n"
- "smlal v10.4s, v26.4h, v1.h[0]\n"
- "smlal v12.4s, v26.4h, v2.h[0]\n"
- "smlal v14.4s, v26.4h, v3.h[0]\n"
- "smlal v16.4s, v26.4h, v4.h[0]\n"
- "smlal v18.4s, v26.4h, v5.h[0]\n"
- "smlal v20.4s, v26.4h, v6.h[0]\n"
- "smlal v22.4s, v26.4h, v7.h[0]\n"
- "sshll v27.8h, v27.8b, #0\n"
- "smlal2 v9.4s, v26.8h, v0.h[0]\n"
- "smlal2 v11.4s, v26.8h, v1.h[0]\n"
- "smlal2 v13.4s, v26.8h, v2.h[0]\n"
- "smlal2 v15.4s, v26.8h, v3.h[0]\n"
- "smlal2 v17.4s, v26.8h, v4.h[0]\n"
- "smlal2 v19.4s, v26.8h, v5.h[0]\n"
- "smlal2 v21.4s, v26.8h, v6.h[0]\n"
- "smlal2 v23.4s, v26.8h, v7.h[0]\n"
-
- "ld1 {v26.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v27.4h, v0.h[1]\n"
- "smlal v10.4s, v27.4h, v1.h[1]\n"
- "smlal v12.4s, v27.4h, v2.h[1]\n"
- "smlal v14.4s, v27.4h, v3.h[1]\n"
- "smlal v16.4s, v27.4h, v4.h[1]\n"
- "smlal v18.4s, v27.4h, v5.h[1]\n"
- "smlal v20.4s, v27.4h, v6.h[1]\n"
- "smlal v22.4s, v27.4h, v7.h[1]\n"
- "sshll v26.8h, v26.8b, #0\n"
- "smlal2 v9.4s, v27.8h, v0.h[1]\n"
- "smlal2 v11.4s, v27.8h, v1.h[1]\n"
- "smlal2 v13.4s, v27.8h, v2.h[1]\n"
- "smlal2 v15.4s, v27.8h, v3.h[1]\n"
- "smlal2 v17.4s, v27.8h, v4.h[1]\n"
- "smlal2 v19.4s, v27.8h, v5.h[1]\n"
- "smlal2 v21.4s, v27.8h, v6.h[1]\n"
- "smlal2 v23.4s, v27.8h, v7.h[1]\n"
-
- "ld1 {v27.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v26.4h, v0.h[2]\n"
- "smlal v10.4s, v26.4h, v1.h[2]\n"
- "smlal v12.4s, v26.4h, v2.h[2]\n"
- "smlal v14.4s, v26.4h, v3.h[2]\n"
- "smlal v16.4s, v26.4h, v4.h[2]\n"
- "smlal v18.4s, v26.4h, v5.h[2]\n"
- "smlal v20.4s, v26.4h, v6.h[2]\n"
- "smlal v22.4s, v26.4h, v7.h[2]\n"
- "sshll v27.8h, v27.8b, #0\n"
- "smlal2 v9.4s, v26.8h, v0.h[2]\n"
- "smlal2 v11.4s, v26.8h, v1.h[2]\n"
- "smlal2 v13.4s, v26.8h, v2.h[2]\n"
- "smlal2 v15.4s, v26.8h, v3.h[2]\n"
- "smlal2 v17.4s, v26.8h, v4.h[2]\n"
- "smlal2 v19.4s, v26.8h, v5.h[2]\n"
- "smlal2 v21.4s, v26.8h, v6.h[2]\n"
- "smlal2 v23.4s, v26.8h, v7.h[2]\n"
-
- "ld1 {v26.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v27.4h, v0.h[3]\n"
- "smlal v10.4s, v27.4h, v1.h[3]\n"
- "smlal v12.4s, v27.4h, v2.h[3]\n"
- "smlal v14.4s, v27.4h, v3.h[3]\n"
- "smlal v16.4s, v27.4h, v4.h[3]\n"
- "smlal v18.4s, v27.4h, v5.h[3]\n"
- "smlal v20.4s, v27.4h, v6.h[3]\n"
- "smlal v22.4s, v27.4h, v7.h[3]\n"
- "sshll v26.8h, v26.8b, #0\n"
- "smlal2 v9.4s, v27.8h, v0.h[3]\n"
- "smlal2 v11.4s, v27.8h, v1.h[3]\n"
- "smlal2 v13.4s, v27.8h, v2.h[3]\n"
- "smlal2 v15.4s, v27.8h, v3.h[3]\n"
- "smlal2 v17.4s, v27.8h, v4.h[3]\n"
- "smlal2 v19.4s, v27.8h, v5.h[3]\n"
- "smlal2 v21.4s, v27.8h, v6.h[3]\n"
- "smlal2 v23.4s, v27.8h, v7.h[3]\n"
-
- "ld1 {v27.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v26.4h, v0.h[4]\n"
- "smlal v10.4s, v26.4h, v1.h[4]\n"
- "smlal v12.4s, v26.4h, v2.h[4]\n"
- "smlal v14.4s, v26.4h, v3.h[4]\n"
- "smlal v16.4s, v26.4h, v4.h[4]\n"
- "smlal v18.4s, v26.4h, v5.h[4]\n"
- "smlal v20.4s, v26.4h, v6.h[4]\n"
- "smlal v22.4s, v26.4h, v7.h[4]\n"
- "sshll v27.8h, v27.8b, #0\n"
- "smlal2 v9.4s, v26.8h, v0.h[4]\n"
- "smlal2 v11.4s, v26.8h, v1.h[4]\n"
- "smlal2 v13.4s, v26.8h, v2.h[4]\n"
- "smlal2 v15.4s, v26.8h, v3.h[4]\n"
- "smlal2 v17.4s, v26.8h, v4.h[4]\n"
- "smlal2 v19.4s, v26.8h, v5.h[4]\n"
- "smlal2 v21.4s, v26.8h, v6.h[4]\n"
- "smlal2 v23.4s, v26.8h, v7.h[4]\n"
-
- "ld1 {v26.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v27.4h, v0.h[5]\n"
- "smlal v10.4s, v27.4h, v1.h[5]\n"
- "smlal v12.4s, v27.4h, v2.h[5]\n"
- "smlal v14.4s, v27.4h, v3.h[5]\n"
- "smlal v16.4s, v27.4h, v4.h[5]\n"
- "smlal v18.4s, v27.4h, v5.h[5]\n"
- "smlal v20.4s, v27.4h, v6.h[5]\n"
- "smlal v22.4s, v27.4h, v7.h[5]\n"
- "sshll v26.8h, v26.8b, #0\n"
- "smlal2 v9.4s, v27.8h, v0.h[5]\n"
- "smlal2 v11.4s, v27.8h, v1.h[5]\n"
- "smlal2 v13.4s, v27.8h, v2.h[5]\n"
- "smlal2 v15.4s, v27.8h, v3.h[5]\n"
- "smlal2 v17.4s, v27.8h, v4.h[5]\n"
- "smlal2 v19.4s, v27.8h, v5.h[5]\n"
- "smlal2 v21.4s, v27.8h, v6.h[5]\n"
- "smlal2 v23.4s, v27.8h, v7.h[5]\n"
-
- "ld1 {v27.8b}, [%[b_ptr]], 8\n"
- "smlal v8.4s, v26.4h, v0.h[6]\n"
- "smlal v10.4s, v26.4h, v1.h[6]\n"
- "smlal v12.4s, v26.4h, v2.h[6]\n"
- "smlal v14.4s, v26.4h, v3.h[6]\n"
- "smlal v16.4s, v26.4h, v4.h[6]\n"
- "smlal v18.4s, v26.4h, v5.h[6]\n"
- "smlal v20.4s, v26.4h, v6.h[6]\n"
- "smlal v22.4s, v26.4h, v7.h[6]\n"
- "sshll v27.8h, v27.8b, #0\n"
- "smlal2 v9.4s, v26.8h, v0.h[6]\n"
- "smlal2 v11.4s, v26.8h, v1.h[6]\n"
- "smlal2 v13.4s, v26.8h, v2.h[6]\n"
- "smlal2 v15.4s, v26.8h, v3.h[6]\n"
- "smlal2 v17.4s, v26.8h, v4.h[6]\n"
- "smlal2 v19.4s, v26.8h, v5.h[6]\n"
- "smlal2 v21.4s, v26.8h, v6.h[6]\n"
- "smlal2 v23.4s, v26.8h, v7.h[6]\n"
-
- "smlal v8.4s, v27.4h, v0.h[7]\n"
- "smlal v10.4s, v27.4h, v1.h[7]\n"
- "smlal v12.4s, v27.4h, v2.h[7]\n"
- "smlal v14.4s, v27.4h, v3.h[7]\n"
- "smlal v16.4s, v27.4h, v4.h[7]\n"
- "smlal v18.4s, v27.4h, v5.h[7]\n"
- "smlal v20.4s, v27.4h, v6.h[7]\n"
- "smlal v22.4s, v27.4h, v7.h[7]\n"
- "smlal2 v9.4s, v27.8h, v0.h[7]\n"
- "smlal2 v11.4s, v27.8h, v1.h[7]\n"
- "smlal2 v13.4s, v27.8h, v2.h[7]\n"
- "smlal2 v15.4s, v27.8h, v3.h[7]\n"
- "smlal2 v17.4s, v27.8h, v4.h[7]\n"
- "smlal2 v19.4s, v27.8h, v5.h[7]\n"
- "smlal2 v21.4s, v27.8h, v6.h[7]\n"
- "smlal2 v23.4s, v27.8h, v7.h[7]\n"
-
- "subs %w[K], %w[K], #1\n"
- "cbnz %w[K], 2b\n"
-
- "3:\n"
- "stp q8, q9, [%[output]]\n"
- "stp q10, q11, [x1]\n"
- "stp q12, q13, [x2]\n"
- "stp q14, q15, [x3]\n"
- "stp q16, q17, [x4]\n"
- "stp q18, q19, [x5]\n"
- "stp q20, q21, [x6]\n"
- "stp q22, q23, [x7]\n"
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
- [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
- [output] "+r"(output)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
- "v20", "v21", "v22", "v23", "v26", "v27", "x1",
- "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory");
- }
-
- /**
- * Overview of register layout:
- *
- * A 8x4x8 cell of Rhs is stored in 8bit in q16-q17
- * A 8x8x8 cell of Lhs is stored in 8bit in q0-q7
- * A 8x4 block of accumulators is stored in 32bit in q8-q15
- *
- * +--------+
- * |v16[0-4]|
- * Rhs +--------+
- * |v17[0-4]|
- * Lhs +--------+
- *
- * +--------+ - - - - +--------+
- * |v0[0-8]| | v8[0-4]|
- * |v1[0-8]| | v9[0-4]|
- * |v2[0-8]| |v10[0-4]|
- * |v3[0-8]| |v11[0-4]|
- * |v4[0-8]| |v12[0-4]|
- * |v5[0-8]| |v13[0-4]|
- * |v6[0-8]| |v14[0-4]|
- * |v7[0-8]| |v15[0-4]|
- * +--------+ - - - - +--------+
- *
- * Accumulator
- */
-
- static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
- int32_t* output, int LDC, bool is_first_k,
- size_t n_remain) {
- K /= 8;
- const int8_t* a_ptr = packA;
- const int8_t* b_ptr = packB;
-
- LDC = LDC * sizeof(int32_t);
- int32_t* outptr0 = output;
- int32_t* outptr1;
- int32_t* outptr2;
- int32_t* outptr3;
- int32_t* outptr4;
- int32_t* outptr5;
- int32_t* outptr6;
- int32_t* outptr7;
- size_t x0 = 0;
-
- // clang-format off
- #define LOAD_LINE(reg_index, n) \
- "mov %[x0], %[outptr" n "]\n" \
- "cmp %w[n_remain], #4\n" \
- "blt 100" n "f\n" \
- "ldr q" reg_index ", [%[x0]] \n" \
- "b 101" n "f\n" \
- "100" n ":\n" \
- "cmp %w[n_remain], #0\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
- "cmp %w[n_remain], #1\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
- "cmp %w[n_remain], #2\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
- "101" n ":\n"
-
- #define LOAD_C \
- LOAD_LINE("8", "0") \
- LOAD_LINE("9", "1") \
- LOAD_LINE("10", "2") \
- LOAD_LINE("11", "3") \
- LOAD_LINE("12", "4") \
- LOAD_LINE("13", "5") \
- LOAD_LINE("14", "6") \
- LOAD_LINE("15", "7")
-
- #define STORE_LINE(reg_index, n) \
- "mov %[x0], %[outptr" n "]\n" \
- "cmp %w[n_remain], #4\n" \
- "blt 102" n "f\n" \
- "str q" reg_index ", [%[x0]]\n" \
- "b 103" n "f\n" \
- "102" n ":\n" \
- "cmp %w[n_remain], #0\n" \
- "beq 103" n "f\n" \
- "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
- "cmp %w[n_remain], #1\n" \
- "beq 103" n "f\n" \
- "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
- "cmp %w[n_remain], #2\n" \
- "beq 103" n "f\n" \
- "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
- "103" n ":\n"
-
- #define STORE_C \
- STORE_LINE("8", "0") \
- STORE_LINE("9", "1") \
- STORE_LINE("10", "2") \
- STORE_LINE("11", "3") \
- STORE_LINE("12", "4") \
- STORE_LINE("13", "5") \
- STORE_LINE("14", "6") \
- STORE_LINE("15", "7")
-
- // clang-format on
-
- asm volatile(
- // load accumulator C
- "add %[outptr1], %[outptr0], %x[LDC]\n"
- "add %[outptr2], %[outptr1], %x[LDC]\n"
- "add %[outptr3], %[outptr2], %x[LDC]\n"
- "add %[outptr4], %[outptr3], %x[LDC]\n"
- "add %[outptr5], %[outptr4], %x[LDC]\n"
- "add %[outptr6], %[outptr5], %x[LDC]\n"
- "add %[outptr7], %[outptr6], %x[LDC]\n"
- "cmp %w[is_first_k], #1\n"
- "beq 1f\n" LOAD_C
-
- "b 2f\n"
-
- "1:\n"
- "eor v8.16b, v8.16b, v8.16b\n"
- "eor v9.16b, v9.16b, v9.16b\n"
- "eor v10.16b, v10.16b, v10.16b\n"
- "eor v11.16b, v11.16b, v11.16b\n"
- "eor v12.16b, v12.16b, v12.16b\n"
- "eor v13.16b, v13.16b, v13.16b\n"
- "eor v14.16b, v14.16b, v14.16b\n"
- "eor v15.16b, v15.16b, v15.16b\n"
-
- "2: \n"
- "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
- "ld1 {v0.8b}, [%[a_ptr]], 8\n"
- "ld1 {v1.8b}, [%[a_ptr]], 8\n"
- "ld1 {v2.8b}, [%[a_ptr]], 8\n"
- "ld1 {v3.8b}, [%[a_ptr]], 8\n"
- "ld1 {v4.8b}, [%[a_ptr]], 8\n"
- "ld1 {v5.8b}, [%[a_ptr]], 8\n"
- "ld1 {v6.8b}, [%[a_ptr]], 8\n"
- "ld1 {v7.8b}, [%[a_ptr]], 8\n"
- "sshll v16.8h, v16.8b, #0\n"
- "sshll v0.8h, v0.8b, #0\n"
- "sshll v1.8h, v1.8b, #0\n"
- "sshll v2.8h, v2.8b, #0\n"
- "sshll v3.8h, v3.8b, #0\n"
- "sshll v4.8h, v4.8b, #0\n"
- "sshll v5.8h, v5.8b, #0\n"
- "sshll v6.8h, v6.8b, #0\n"
- "sshll v7.8h, v7.8b, #0\n"
-
- "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v16.4h, v0.h[0]\n"
- "smlal v9.4s, v16.4h, v1.h[0]\n"
- "smlal v10.4s, v16.4h, v2.h[0]\n"
- "smlal v11.4s, v16.4h, v3.h[0]\n"
- "sshll v17.8h, v17.8b, #0\n"
- "smlal v12.4s, v16.4h, v4.h[0]\n"
- "smlal v13.4s, v16.4h, v5.h[0]\n"
- "smlal v14.4s, v16.4h, v6.h[0]\n"
- "smlal v15.4s, v16.4h, v7.h[0]\n"
-
- "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v17.4h, v0.h[1]\n"
- "smlal v9.4s, v17.4h, v1.h[1]\n"
- "smlal v10.4s, v17.4h, v2.h[1]\n"
- "smlal v11.4s, v17.4h, v3.h[1]\n"
- "sshll v16.8h, v16.8b, #0\n"
- "smlal v12.4s, v17.4h, v4.h[1]\n"
- "smlal v13.4s, v17.4h, v5.h[1]\n"
- "smlal v14.4s, v17.4h, v6.h[1]\n"
- "smlal v15.4s, v17.4h, v7.h[1]\n"
-
- "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v16.4h, v0.h[2]\n"
- "smlal v9.4s, v16.4h, v1.h[2]\n"
- "smlal v10.4s, v16.4h, v2.h[2]\n"
- "smlal v11.4s, v16.4h, v3.h[2]\n"
- "sshll v17.8h, v17.8b, #0\n"
- "smlal v12.4s, v16.4h, v4.h[2]\n"
- "smlal v13.4s, v16.4h, v5.h[2]\n"
- "smlal v14.4s, v16.4h, v6.h[2]\n"
- "smlal v15.4s, v16.4h, v7.h[2]\n"
-
- "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v17.4h, v0.h[3]\n"
- "smlal v9.4s, v17.4h, v1.h[3]\n"
- "smlal v10.4s, v17.4h, v2.h[3]\n"
- "smlal v11.4s, v17.4h, v3.h[3]\n"
- "sshll v16.8h, v16.8b, #0\n"
- "smlal v12.4s, v17.4h, v4.h[3]\n"
- "smlal v13.4s, v17.4h, v5.h[3]\n"
- "smlal v14.4s, v17.4h, v6.h[3]\n"
- "smlal v15.4s, v17.4h, v7.h[3]\n"
-
- "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v16.4h, v0.h[4]\n"
- "smlal v9.4s, v16.4h, v1.h[4]\n"
- "smlal v10.4s, v16.4h, v2.h[4]\n"
- "smlal v11.4s, v16.4h, v3.h[4]\n"
- "sshll v17.8h, v17.8b, #0\n"
- "smlal v12.4s, v16.4h, v4.h[4]\n"
- "smlal v13.4s, v16.4h, v5.h[4]\n"
- "smlal v14.4s, v16.4h, v6.h[4]\n"
- "smlal v15.4s, v16.4h, v7.h[4]\n"
-
- "ld1 {v16.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v17.4h, v0.h[5]\n"
- "smlal v9.4s, v17.4h, v1.h[5]\n"
- "smlal v10.4s, v17.4h, v2.h[5]\n"
- "smlal v11.4s, v17.4h, v3.h[5]\n"
- "sshll v16.8h, v16.8b, #0\n"
- "smlal v12.4s, v17.4h, v4.h[5]\n"
- "smlal v13.4s, v17.4h, v5.h[5]\n"
- "smlal v14.4s, v17.4h, v6.h[5]\n"
- "smlal v15.4s, v17.4h, v7.h[5]\n"
-
- "ld1 {v17.s}[0], [%[b_ptr]], 4\n"
- "smlal v8.4s, v16.4h, v0.h[6]\n"
- "smlal v9.4s, v16.4h, v1.h[6]\n"
- "smlal v10.4s, v16.4h, v2.h[6]\n"
- "smlal v11.4s, v16.4h, v3.h[6]\n"
- "sshll v17.8h, v17.8b, #0\n"
- "smlal v12.4s, v16.4h, v4.h[6]\n"
- "smlal v13.4s, v16.4h, v5.h[6]\n"
- "smlal v14.4s, v16.4h, v6.h[6]\n"
- "smlal v15.4s, v16.4h, v7.h[6]\n"
-
- "smlal v8.4s, v17.4h, v0.h[7]\n"
- "smlal v9.4s, v17.4h, v1.h[7]\n"
- "smlal v10.4s, v17.4h, v2.h[7]\n"
- "smlal v11.4s, v17.4h, v3.h[7]\n"
- "smlal v12.4s, v17.4h, v4.h[7]\n"
- "smlal v13.4s, v17.4h, v5.h[7]\n"
- "smlal v14.4s, v17.4h, v6.h[7]\n"
- "smlal v15.4s, v17.4h, v7.h[7]\n"
-
- "subs %w[K], %w[K], #1\n"
- "cbnz %w[K], 2b\n"
-
- "3:\n" STORE_C
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
- [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
- [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
- [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3),
- [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5),
- [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0),
- [n_remain] "+r"(n_remain)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory");
-
- #undef LOAD_LINE
- #undef LOAD_C
- #undef STORE_LINE
- #undef STORE_C
- }
-
- /**
- * Overview of register layout:
- *
- * A 8x8x8 cell of Rhs is stored in 8bit in q12-q13
- * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3
- * A 4x8 block of accumulators is stored in 32bit in q4-q11
- *
- * +--------+--------+
- * |v12[0-8]|v13[0-8]|
- * Rhs +--------+--------+
- * Lhs | | |
- *
- * +--------+ - - - - +-----------------+
- * |v0[0-8]| | v4[0-4]| v5[0-4]|
- * |v1[0-8]| | v6[0-4]| v7[0-4]|
- * |v2[0-8]| | v8[0-4]| v9[0-4]|
- * |v3[0-8]| |v10[0-4]|v11[0-4]|
- * +--------+ - - - - +-----------------+
- *
- * Accumulator
- */
-
- static void kern_4x8(const int8_t* packA, const int8_t* packB, int K,
- int32_t* output, int LDC, bool is_first_k,
- size_t m_remain) {
- K /= 8;
- const int8_t* a_ptr = packA;
- const int8_t* b_ptr = packB;
-
- LDC = LDC * sizeof(int32_t);
- int32_t* outptr0 = output;
- int32_t* outptr1;
- int32_t* outptr2;
- int32_t* outptr3;
- size_t x0 = 0;
-
- // clang-format off
- #define LOAD_LINE(v1, v2, m) \
- "cbz %[x0], 100f\n" \
- "ldp " v1 "," v2 ", [%[outptr" m "]]\n" \
- "subs %[x0], %[x0], #1\n"
-
- #define LOAD_C \
- "mov %[x0], %x[m_remain]\n" \
- LOAD_LINE("q4", "q5", "0") \
- LOAD_LINE("q6", "q7", "1") \
- LOAD_LINE("q8", "q9", "2") \
- LOAD_LINE("q10", "q11", "3") \
- "100:\n"
-
- #define STORE_LINE(v1, v2, m) \
- "cbz %[x0], 101f\n" \
- "stp " v1 "," v2", [%[outptr" m "]]\n" \
- "subs %[x0], %[x0], #1\n"
-
- #define STORE_C \
- "mov %[x0], %x[m_remain]\n" \
- STORE_LINE("q4", "q5", "0") \
- STORE_LINE("q6", "q7", "1") \
- STORE_LINE("q8", "q9", "2") \
- STORE_LINE("q10", "q11", "3") \
- "101:\n"
-
- // clang-format on
-
- asm volatile(
- // load accumulator C
- "add %[outptr1], %[outptr0], %x[LDC]\n"
- "add %[outptr2], %[outptr1], %x[LDC]\n"
- "add %[outptr3], %[outptr2], %x[LDC]\n"
- "cmp %w[is_first_k], #1\n"
- "beq 1f\n" LOAD_C
-
- "b 2f\n"
-
- "1:\n"
- "eor v4.16b, v4.16b, v4.16b\n"
- "eor v5.16b, v5.16b, v5.16b\n"
- "eor v6.16b, v6.16b, v6.16b\n"
- "eor v7.16b, v7.16b, v7.16b\n"
- "eor v8.16b, v8.16b, v8.16b\n"
- "eor v9.16b, v9.16b, v9.16b\n"
- "eor v10.16b, v10.16b, v10.16b\n"
- "eor v11.16b, v11.16b, v11.16b\n"
-
- "2: \n"
- "ld1 {v12.8b}, [%[b_ptr]], 8\n"
- "ld1 {v0.8b}, [%[a_ptr]], 8\n"
- "ld1 {v1.8b}, [%[a_ptr]], 8\n"
- "ld1 {v2.8b}, [%[a_ptr]], 8\n"
- "ld1 {v3.8b}, [%[a_ptr]], 8\n"
- "sshll v12.8h, v12.8b, #0\n"
- "sshll v0.8h, v0.8b, #0\n"
- "sshll v1.8h, v1.8b, #0\n"
- "sshll v2.8h, v2.8b, #0\n"
- "sshll v3.8h, v3.8b, #0\n"
-
- "ld1 {v13.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v12.4h, v0.h[0]\n"
- "smlal v6.4s, v12.4h, v1.h[0]\n"
- "smlal v8.4s, v12.4h, v2.h[0]\n"
- "smlal v10.4s, v12.4h, v3.h[0]\n"
- "sshll v13.8h, v13.8b, #0\n"
- "smlal2 v5.4s, v12.8h, v0.h[0]\n"
- "smlal2 v7.4s, v12.8h, v1.h[0]\n"
- "smlal2 v9.4s, v12.8h, v2.h[0]\n"
- "smlal2 v11.4s, v12.8h, v3.h[0]\n"
-
- "ld1 {v12.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v13.4h, v0.h[1]\n"
- "smlal v6.4s, v13.4h, v1.h[1]\n"
- "smlal v8.4s, v13.4h, v2.h[1]\n"
- "smlal v10.4s, v13.4h, v3.h[1]\n"
- "sshll v12.8h, v12.8b, #0\n"
- "smlal2 v5.4s, v13.8h, v0.h[1]\n"
- "smlal2 v7.4s, v13.8h, v1.h[1]\n"
- "smlal2 v9.4s, v13.8h, v2.h[1]\n"
- "smlal2 v11.4s, v13.8h, v3.h[1]\n"
-
- "ld1 {v13.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v12.4h, v0.h[2]\n"
- "smlal v6.4s, v12.4h, v1.h[2]\n"
- "smlal v8.4s, v12.4h, v2.h[2]\n"
- "smlal v10.4s, v12.4h, v3.h[2]\n"
- "sshll v13.8h, v13.8b, #0\n"
- "smlal2 v5.4s, v12.8h, v0.h[2]\n"
- "smlal2 v7.4s, v12.8h, v1.h[2]\n"
- "smlal2 v9.4s, v12.8h, v2.h[2]\n"
- "smlal2 v11.4s, v12.8h, v3.h[2]\n"
-
- "ld1 {v12.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v13.4h, v0.h[3]\n"
- "smlal v6.4s, v13.4h, v1.h[3]\n"
- "smlal v8.4s, v13.4h, v2.h[3]\n"
- "smlal v10.4s, v13.4h, v3.h[3]\n"
- "sshll v12.8h, v12.8b, #0\n"
- "smlal2 v5.4s, v13.8h, v0.h[3]\n"
- "smlal2 v7.4s, v13.8h, v1.h[3]\n"
- "smlal2 v9.4s, v13.8h, v2.h[3]\n"
- "smlal2 v11.4s, v13.8h, v3.h[3]\n"
-
- "ld1 {v13.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v12.4h, v0.h[4]\n"
- "smlal v6.4s, v12.4h, v1.h[4]\n"
- "smlal v8.4s, v12.4h, v2.h[4]\n"
- "smlal v10.4s, v12.4h, v3.h[4]\n"
- "sshll v13.8h, v13.8b, #0\n"
- "smlal2 v5.4s, v12.8h, v0.h[4]\n"
- "smlal2 v7.4s, v12.8h, v1.h[4]\n"
- "smlal2 v9.4s, v12.8h, v2.h[4]\n"
- "smlal2 v11.4s, v12.8h, v3.h[4]\n"
-
- "ld1 {v12.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v13.4h, v0.h[5]\n"
- "smlal v6.4s, v13.4h, v1.h[5]\n"
- "smlal v8.4s, v13.4h, v2.h[5]\n"
- "smlal v10.4s, v13.4h, v3.h[5]\n"
- "sshll v12.8h, v12.8b, #0\n"
- "smlal2 v5.4s, v13.8h, v0.h[5]\n"
- "smlal2 v7.4s, v13.8h, v1.h[5]\n"
- "smlal2 v9.4s, v13.8h, v2.h[5]\n"
- "smlal2 v11.4s, v13.8h, v3.h[5]\n"
-
- "ld1 {v13.8b}, [%[b_ptr]], 8\n"
- "smlal v4.4s, v12.4h, v0.h[6]\n"
- "smlal v6.4s, v12.4h, v1.h[6]\n"
- "smlal v8.4s, v12.4h, v2.h[6]\n"
- "smlal v10.4s, v12.4h, v3.h[6]\n"
- "sshll v13.8h, v13.8b, #0\n"
- "smlal2 v5.4s, v12.8h, v0.h[6]\n"
- "smlal2 v7.4s, v12.8h, v1.h[6]\n"
- "smlal2 v9.4s, v12.8h, v2.h[6]\n"
- "smlal2 v11.4s, v12.8h, v3.h[6]\n"
-
- "smlal v4.4s, v13.4h, v0.h[7]\n"
- "smlal v6.4s, v13.4h, v1.h[7]\n"
- "smlal v8.4s, v13.4h, v2.h[7]\n"
- "smlal v10.4s, v13.4h, v3.h[7]\n"
- "smlal2 v5.4s, v13.8h, v0.h[7]\n"
- "smlal2 v7.4s, v13.8h, v1.h[7]\n"
- "smlal2 v9.4s, v13.8h, v2.h[7]\n"
- "smlal2 v11.4s, v13.8h, v3.h[7]\n"
-
- "subs %w[K], %w[K], #1\n"
- "cbnz %w[K], 2b\n"
-
- "3:\n" STORE_C
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
- [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
- [outptr0] "+r"(outptr0),
- [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2),
- [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
- "v11", "v12", "v13", "cc", "memory");
-
- #undef LOAD_LINE
- #undef LOAD_C
- #undef STORE_LINE
- #undef STORE_C
- }
-
- /**
- * Overview of register layout:
- *
- * A 8x4x8 cell of Rhs is stored in 8bit in q8-q9
- * A 8x8x4 cell of Lhs is stored in 8bit in q0-q3
- * A 4x4 block of accumulators is stored in 32bit in q4-q7
- *
- * +--------+
- * | v8[0-4]|
- * Rhs +--------+
- * | v9[0-4]|
- * Lhs +--------+
- *
- * +--------+ - - - - +--------+
- * |v0[0-8]| | v4[0-4]|
- * |v1[0-8]| | v5[0-4]|
- * |v2[0-8]| | v6[0-4]|
- * |v3[0-8]| | v7[0-4]|
- * +--------+ - - - - +--------+
- *
- * Accumulator
- */
-
- static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
- int32_t* output, int LDC, bool is_first_k, size_t m_remain,
- size_t n_remain) {
- K /= 8;
- const int8_t* a_ptr = packA;
- const int8_t* b_ptr = packB;
-
- LDC = LDC * sizeof(int32_t);
- int32_t* outptr0 = output;
- int32_t* outptr1;
- int32_t* outptr2;
- int32_t* outptr3;
- size_t x0 = 0;
- size_t x1 = 0;
-
- // clang-format off
- #define LOAD_LINE(reg_index, n) \
- "cbz %[x1], 102f\n" \
- "mov %[x0], %[outptr" n "]\n" \
- "cmp %w[n_remain], #4\n" \
- "blt 100" n "f\n" \
- "ldr q" reg_index ", [%[x0]]\n" \
- "b 101" n "f\n" \
- "100" n ":\n" \
- "cmp %w[n_remain], #0\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
- "cmp %w[n_remain], #1\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
- "cmp %w[n_remain], #2\n" \
- "beq 101" n "f\n" \
- "ld1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
- "101" n ":\n" \
- "subs %[x1], %[x1], #1\n"
-
- #define LOAD_C \
- "mov %[x1], %x[m_remain]\n" \
- LOAD_LINE("4", "0") \
- LOAD_LINE("5", "1") \
- LOAD_LINE("6", "2") \
- LOAD_LINE("7", "3") \
- "102:\n"
-
- #define STORE_LINE(reg_index, n) \
- "cbz %[x1], 105f\n" \
- "mov %[x0], %[outptr" n "]\n" \
- "cmp %w[n_remain], #4\n" \
- "blt 103" n "f\n" \
- "str q" reg_index ", [%[x0]]\n" \
- "b 104" n "f\n" \
- "103" n ":\n" \
- "cmp %w[n_remain], #0\n" \
- "beq 104" n "f\n" \
- "st1 {v" reg_index ".s}[0], [%[x0]], #4\n" \
- "cmp %w[n_remain], #1\n" \
- "beq 104" n "f\n" \
- "st1 {v" reg_index ".s}[1], [%[x0]], #4\n" \
- "cmp %w[n_remain], #2\n" \
- "beq 104" n "f\n" \
- "st1 {v" reg_index ".s}[2], [%[x0]], #4\n" \
- "104" n ":\n" \
- "subs %[x1], %[x1], #1\n"
-
- #define STORE_C \
- "mov %[x1], %x[m_remain]\n" \
- STORE_LINE("4", "0") \
- STORE_LINE("5", "1") \
- STORE_LINE("6", "2") \
- STORE_LINE("7", "3") \
- "105:\n"
-
- // clang-format on
-
- asm volatile(
- // load accumulator C
- "add %[outptr1], %[outptr0], %x[LDC]\n"
- "add %[outptr2], %[outptr1], %x[LDC]\n"
- "add %[outptr3], %[outptr2], %x[LDC]\n"
- "cmp %w[is_first_k], #1\n"
- "beq 1f\n" LOAD_C
-
- "b 2f\n"
-
- "1:\n"
- "eor v4.16b, v4.16b, v4.16b\n"
- "eor v5.16b, v5.16b, v5.16b\n"
- "eor v6.16b, v6.16b, v6.16b\n"
- "eor v7.16b, v7.16b, v7.16b\n"
-
- "2: \n"
- "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
- "ld1 {v0.8b}, [%[a_ptr]], 8\n"
- "ld1 {v1.8b}, [%[a_ptr]], 8\n"
- "ld1 {v2.8b}, [%[a_ptr]], 8\n"
- "ld1 {v3.8b}, [%[a_ptr]], 8\n"
- "sshll v8.8h, v8.8b, #0\n"
- "sshll v0.8h, v0.8b, #0\n"
- "sshll v1.8h, v1.8b, #0\n"
- "sshll v2.8h, v2.8b, #0\n"
- "sshll v3.8h, v3.8b, #0\n"
-
- "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v8.4h, v0.h[0]\n"
- "smlal v5.4s, v8.4h, v1.h[0]\n"
- "sshll v9.8h, v9.8b, #0\n"
- "smlal v6.4s, v8.4h, v2.h[0]\n"
- "smlal v7.4s, v8.4h, v3.h[0]\n"
-
- "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v9.4h, v0.h[1]\n"
- "smlal v5.4s, v9.4h, v1.h[1]\n"
- "sshll v8.8h, v8.8b, #0\n"
- "smlal v6.4s, v9.4h, v2.h[1]\n"
- "smlal v7.4s, v9.4h, v3.h[1]\n"
-
- "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v8.4h, v0.h[2]\n"
- "smlal v5.4s, v8.4h, v1.h[2]\n"
- "sshll v9.8h, v9.8b, #0\n"
- "smlal v6.4s, v8.4h, v2.h[2]\n"
- "smlal v7.4s, v8.4h, v3.h[2]\n"
-
- "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v9.4h, v0.h[3]\n"
- "smlal v5.4s, v9.4h, v1.h[3]\n"
- "sshll v8.8h, v8.8b, #0\n"
- "smlal v6.4s, v9.4h, v2.h[3]\n"
- "smlal v7.4s, v9.4h, v3.h[3]\n"
-
- "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v8.4h, v0.h[4]\n"
- "smlal v5.4s, v8.4h, v1.h[4]\n"
- "sshll v9.8h, v9.8b, #0\n"
- "smlal v6.4s, v8.4h, v2.h[4]\n"
- "smlal v7.4s, v8.4h, v3.h[4]\n"
-
- "ld1 {v8.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v9.4h, v0.h[5]\n"
- "smlal v5.4s, v9.4h, v1.h[5]\n"
- "sshll v8.8h, v8.8b, #0\n"
- "smlal v6.4s, v9.4h, v2.h[5]\n"
- "smlal v7.4s, v9.4h, v3.h[5]\n"
-
- "ld1 {v9.s}[0], [%[b_ptr]], 4\n"
- "smlal v4.4s, v8.4h, v0.h[6]\n"
- "smlal v5.4s, v8.4h, v1.h[6]\n"
- "sshll v9.8h, v9.8b, #0\n"
- "smlal v6.4s, v8.4h, v2.h[6]\n"
- "smlal v7.4s, v8.4h, v3.h[6]\n"
-
- "smlal v4.4s, v9.4h, v0.h[7]\n"
- "smlal v5.4s, v9.4h, v1.h[7]\n"
- "smlal v6.4s, v9.4h, v2.h[7]\n"
- "smlal v7.4s, v9.4h, v3.h[7]\n"
-
- "subs %w[K], %w[K], #1\n"
- "cbnz %w[K], 2b\n"
-
- "3:\n" STORE_C
- : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr),
- [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC),
- [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1),
- [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0),
- [x1] "+r"(x1), [m_remain] "+r"(m_remain),
- [n_remain] "+r"(n_remain)
- :
- : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc",
- "memory");
-
- #undef LOAD_LINE
- #undef LOAD_C
- #undef STORE_LINE
- #undef STORE_C
- }
-
- static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin,
- int y0, int ymax, int k0, int kmax) {
- int8_t zerobuff[16];
- std::memset(zerobuff, 0, sizeof(int8_t) * 16);
-
- int y = y0;
- for (; y + 7 < ymax; y += 8) {
- const int8_t* inptr0 = inptr + y * ldin + k0;
- const int8_t* inptr1 = inptr0 + ldin;
- const int8_t* inptr2 = inptr1 + ldin;
- const int8_t* inptr3 = inptr2 + ldin;
- const int8_t* inptr4 = inptr3 + ldin;
- const int8_t* inptr5 = inptr4 + ldin;
- const int8_t* inptr6 = inptr5 + ldin;
- const int8_t* inptr7 = inptr6 + ldin;
-
- prefetch_2x(inptr0);
- prefetch_2x(inptr1);
- prefetch_2x(inptr2);
- prefetch_2x(inptr3);
- prefetch_2x(inptr4);
- prefetch_2x(inptr5);
- prefetch_2x(inptr6);
- prefetch_2x(inptr7);
-
- int K = kmax - k0;
- for (; K > 15; K -= 16) {
- interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
- inptr6, inptr7, outptr);
- }
-
- if (K > 0) {
- interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
- inptr7, outptr, 8, K);
- }
- }
-
- for (; y < ymax; y += 4) {
- const int8_t* inptr0 = inptr + y * ldin + k0;
- const int8_t* inptr1 = inptr0 + ldin;
- const int8_t* inptr2 = inptr1 + ldin;
- const int8_t* inptr3 = inptr2 + ldin;
-
- prefetch_2x(inptr0);
- prefetch_2x(inptr1);
- prefetch_2x(inptr2);
- prefetch_2x(inptr3);
-
- int K = kmax - k0;
- for (; K > 15; K -= 16) {
- if (y + 3 >= ymax) {
- switch (y + 3 - ymax) {
- case 2:
- inptr1 = zerobuff; MEGDNN_FALLTHRU
- case 1:
- inptr2 = zerobuff; MEGDNN_FALLTHRU
- case 0:
- inptr3 = zerobuff;
- break;
- default:
- megdnn_assert(0);
- }
- }
-
- interleave_4x8_2_b(inptr0, inptr1, inptr2, inptr3, outptr);
- }
-
- if (K > 0) {
- if (y + 3 >= ymax) {
- switch (y + 3 - ymax) {
- case 2:
- inptr1 = zerobuff; MEGDNN_FALLTHRU
- case 1:
- inptr2 = zerobuff; MEGDNN_FALLTHRU
- case 0:
- inptr3 = zerobuff;
- break;
- default:
- megdnn_assert(0);
- }
- }
- interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K);
- }
- }
- }
-
- static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in,
- int ldin, int x0, int xmax, int k0,
- int kmax) {
- int8_t zerobuff[16];
- std::memset(zerobuff, 0, sizeof(int8_t) * 16);
- const int ksize = kmax - k0;
- const int ksize4 = round_up(ksize, 8) * 4;
- const int ksize8 = ksize4 * 2;
- int8_t* outptr = out;
- int8_t* outptr_base = out;
- //! 4x4 block output start pos
- int8_t* outptr_base4 = out + ((xmax - x0) / 8) * ksize8;
-
- int k = k0;
- for (; k < kmax; k += 8) {
- const int8_t* inptr0 = in + k * ldin + x0;
- const int8_t* inptr1 = inptr0 + ldin;
- const int8_t* inptr2 = inptr1 + ldin;
- const int8_t* inptr3 = inptr2 + ldin;
- const int8_t* inptr4 = inptr3 + ldin;
- const int8_t* inptr5 = inptr4 + ldin;
- const int8_t* inptr6 = inptr5 + ldin;
- const int8_t* inptr7 = inptr6 + ldin;
- prefetch_2x(inptr0);
- prefetch_2x(inptr1);
- prefetch_2x(inptr2);
- prefetch_2x(inptr3);
- prefetch_2x(inptr4);
- prefetch_2x(inptr5);
- prefetch_2x(inptr6);
- prefetch_2x(inptr7);
-
- int x = x0;
- outptr = outptr_base;
-
- for (; x + 7 < xmax; x += 8) {
- if (k + 7 >= kmax) {
- switch (k + 7 - kmax) {
- case 6:
- inptr1 = zerobuff; MEGDNN_FALLTHRU
- case 5:
- inptr2 = zerobuff; MEGDNN_FALLTHRU
- case 4:
- inptr3 = zerobuff; MEGDNN_FALLTHRU
- case 3:
- inptr4 = zerobuff; MEGDNN_FALLTHRU
- case 2:
- inptr5 = zerobuff; MEGDNN_FALLTHRU
- case 1:
- inptr6 = zerobuff; MEGDNN_FALLTHRU
- case 0:
- inptr7 = zerobuff;
- break;
- default:
- megdnn_assert(0);
- }
- }
- transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5,
- inptr6, inptr7, outptr);
- outptr += ksize8;
- }
-
- outptr = outptr_base4;
- for (; x + 3 < xmax; x += 4) {
- if (k + 7 >= kmax) {
- switch (k + 7 - kmax) {
- case 6:
- inptr1 = zerobuff; MEGDNN_FALLTHRU
- case 5:
- inptr2 = zerobuff; MEGDNN_FALLTHRU
- case 4:
- inptr3 = zerobuff; MEGDNN_FALLTHRU
- case 3:
- inptr4 = zerobuff; MEGDNN_FALLTHRU
- case 2:
- inptr5 = zerobuff; MEGDNN_FALLTHRU
- case 1:
- inptr6 = zerobuff; MEGDNN_FALLTHRU
- case 0:
- inptr7 = zerobuff;
- break;
- default:
- megdnn_assert(0);
- }
- }
-
- transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
- inptr7, outptr, 4, 4);
- outptr += ksize4;
- }
-
- if (x < xmax) {
- if (k + 7 >= kmax) {
- switch (k + 7 - kmax) {
- case 6:
- inptr1 = zerobuff; MEGDNN_FALLTHRU
- case 5:
- inptr2 = zerobuff; MEGDNN_FALLTHRU
- case 4:
- inptr3 = zerobuff; MEGDNN_FALLTHRU
- case 3:
- inptr4 = zerobuff; MEGDNN_FALLTHRU
- case 2:
- inptr5 = zerobuff; MEGDNN_FALLTHRU
- case 1:
- inptr6 = zerobuff; MEGDNN_FALLTHRU
- case 0:
- inptr7 = zerobuff;
- break;
- default:
- megdnn_assert(0);
- }
- }
-
- transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6,
- inptr7, outptr, 4, xmax - x);
- }
-
- outptr_base += 8 * 8;
- outptr_base4 += 4 * 8;
- }
- }
-
- /**
-  * Pack a block of B (k-rows k0..kmax, columns x0..xmax, row stride ldin)
-  * for the 8x8x8 int8 kernel.
-  *
-  * Output layout, as produced by the pointer arithmetic below:
-  *  - first, (xmax - x0) / 8 panels of 8 columns, each
-  *    round_up(kmax - k0, 8) * 8 bytes long;
-  *  - then the leftover columns as 4-column panels, each
-  *    round_up(kmax - k0, 8) * 4 bytes long; the last one is written by
-  *    interleave_8 with only xmax - x columns (presumably zero-filled by
-  *    that helper - confirm in interleave_8).
-  * K is consumed 8 rows at a time; rows at or past kmax are read from a
-  * zero buffer. Each 8-row slice occupies 64 bytes of an 8-column panel and
-  * 32 bytes of a 4-column panel.
-  *
-  * NOTE(review): the byte order inside a slice is defined by
-  * interleave_8x8_1_b / interleave_8, which are not visible here.
-  */
- static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin,
-                                  int x0, int xmax, int k0, int kmax) {
-     int8_t zerobuff[16];
-     std::memset(zerobuff, 0, sizeof(int8_t) * 16);
-     const int ksize = kmax - k0;
-     const int ksize4 = round_up(ksize, 8) * 4;
-     const int ksize8 = ksize4 * 2;
-     //! position of the current 8-row slice inside the first 8-column panel
-     int8_t* panel8_pos = out;
-     //! 4-column panels start right after all full 8-column panels
-     int8_t* panel4_pos = out + ((xmax - x0) / 8) * ksize8;
-
-     for (int k = k0; k < kmax; k += 8) {
-         //! pointers to k-rows k .. k+7 of B
-         const int8_t* rows[8];
-         rows[0] = in + k * ldin + x0;
-         for (int r = 1; r < 8; ++r) {
-             rows[r] = rows[r - 1] + ldin;
-         }
-         for (int r = 0; r < 8; ++r) {
-             prefetch_2x(rows[r]);
-         }
-
-         //! Rows at index >= rows_left lie past kmax. The interleave helpers
-         //! receive the row pointers as lvalues and may advance them, so the
-         //! padding rows are re-pointed at the start of zerobuff before
-         //! every call.
-         const int rows_left = kmax - k;
-         auto pad_tail_rows = [&]() {
-             for (int r = rows_left; r < 8; ++r) {
-                 rows[r] = zerobuff;
-             }
-         };
-
-         int x = x0;
-         int8_t* dst = panel8_pos;
-         for (; x + 7 < xmax; x += 8) {
-             pad_tail_rows();
-             //! scratch copy: the helper may advance the pointer it receives
-             int8_t* cursor = dst;
-             interleave_8x8_1_b(rows[0], rows[1], rows[2], rows[3], rows[4],
-                                rows[5], rows[6], rows[7], cursor);
-             dst += ksize8;
-         }
-
-         dst = panel4_pos;
-         for (; x + 3 < xmax; x += 4) {
-             pad_tail_rows();
-             int8_t* cursor = dst;
-             interleave_8(rows[0], rows[1], rows[2], rows[3], rows[4], rows[5],
-                          rows[6], rows[7], cursor, 4, 4);
-             dst += ksize4;
-         }
-
-         if (x < xmax) {
-             //! fewer than 4 columns remain
-             pad_tail_rows();
-             int8_t* cursor = dst;
-             interleave_8(rows[0], rows[1], rows[2], rows[3], rows[4], rows[5],
-                          rows[6], rows[7], cursor, 4, xmax - x);
-         }
-
-         //! the next 8-row slice follows this one inside every panel
-         panel8_pos += 8 * 8;
-         panel4_pos += 4 * 8;
-     }
- }
-
- /**
-  * Pack a block of B supplied in transposed storage (rows y0..ymax, each
-  * row read along K from k0 to kmax, row stride ldin) for the 8x8x8 int8
-  * kernel. outptr is filled contiguously.
-  *
-  * - While at least 8 rows remain, rows are taken 8 at a time. Every 8
-  *   K-values emit 64 bytes via transpose_8x8_1_b. A K tail shorter than 8
-  *   also emits 64 bytes, via transpose_8.
-  * - Remaining rows are taken 4 at a time. Every 8 K-values (or the K
-  *   tail) emit 32 bytes via transpose_8x4_1_b / transpose_4. Rows at or
-  *   past ymax in the last group are read from a zero buffer.
-  *
-  * NOTE(review): the byte order of each emitted tile is defined by the
-  * transpose_* helpers, which are not visible here.
-  */
- static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
-                                            int ldin, int y0, int ymax, int k0,
-                                            int kmax) {
-     int8_t zerobuff[16];
-     std::memset(zerobuff, 0, sizeof(int8_t) * 16);
-     constexpr int interleave4 = 32;
-     constexpr int interleave8 = 64;
-     const int ksize = kmax - k0;
-
-     int y = y0;
-     //! full groups of 8 rows: every row is real, no padding needed
-     for (; y + 7 < ymax; y += 8) {
-         const int8_t* rows[8];
-         rows[0] = inptr + y * ldin + k0;
-         for (int r = 1; r < 8; ++r) {
-             rows[r] = rows[r - 1] + ldin;
-         }
-         for (int r = 0; r < 8; ++r) {
-             prefetch_2x(rows[r]);
-         }
-
-         int k_left = ksize;
-         while (k_left >= 8) {
-             transpose_8x8_1_b(rows[0], rows[1], rows[2], rows[3], rows[4],
-                               rows[5], rows[6], rows[7], outptr);
-             outptr += interleave8;
-             k_left -= 8;
-         }
-         if (k_left > 0) {
-             //! K tail shorter than 8
-             transpose_8(rows[0], rows[1], rows[2], rows[3], rows[4], rows[5],
-                         rows[6], rows[7], outptr, 8, k_left);
-             outptr += interleave8;
-         }
-     }
-
-     //! leftover rows, 4 at a time
-     for (; y < ymax; y += 4) {
-         const int8_t* rows[4];
-         rows[0] = inptr + y * ldin + k0;
-         for (int r = 1; r < 4; ++r) {
-             rows[r] = rows[r - 1] + ldin;
-         }
-         for (int r = 0; r < 4; ++r) {
-             prefetch_2x(rows[r]);
-         }
-
-         //! Rows at index >= rows_left lie past ymax. The transpose helpers
-         //! receive the row pointers as lvalues and may advance them, so the
-         //! padding rows are re-pointed at zerobuff before every call.
-         const int rows_left = ymax - y;
-         auto pad_tail_rows = [&]() {
-             for (int r = rows_left; r < 4; ++r) {
-                 rows[r] = zerobuff;
-             }
-         };
-
-         int k_left = ksize;
-         while (k_left >= 8) {
-             pad_tail_rows();
-             transpose_8x4_1_b(rows[0], rows[1], rows[2], rows[3], outptr);
-             outptr += interleave4;
-             k_left -= 8;
-         }
-         if (k_left > 0) {
-             //! K tail shorter than 8
-             pad_tail_rows();
-             transpose_4(rows[0], rows[1], rows[2], rows[3], outptr, 8, k_left);
-             outptr += interleave4;
-         }
-     }
- }
-
- } // namespace matmul_8x8x8
- } // namespace aarch64
- } // namespace megdnn
-
- // vim: syntax=cpp.doxygen
- #endif
|