From 92b12685dbf693d275fcdbcc769539ffdcfa0372 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 24 Sep 2020 18:37:49 +0800 Subject: [PATCH] feat(dnn/aarch64): add aarch64 int8X8X16_mk4_k8x8x8 matmul, performance is better GitOrigin-RevId: b6af21e8e314b4edd62f0fddcf8578d2eaa0fc2a --- dnn/src/aarch64/matrix_mul/algos.cpp | 70 + dnn/src/aarch64/matrix_mul/algos.h | 16 + dnn/src/aarch64/matrix_mul/asm/common.h | 56 + .../matrix_mul/int8x8x16/kernel_mk4_8x8x8.h | 1451 ++++++++++++++++++++ dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp | 78 ++ dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h | 2 + dnn/src/aarch64/matrix_mul/opr_impl.cpp | 2 + dnn/src/aarch64/matrix_mul/opr_impl.h | 1 + dnn/test/aarch64/matrix_mul.cpp | 79 ++ 9 files changed, 1755 insertions(+) create mode 100644 dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 3de6ef3e..89aea82c 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -1310,4 +1310,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, int32_t); #endif +/* ===================== Int8x8x16 K8x8x8 algo ===================== */ +namespace { +void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), + Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, + B_type, C_type); + megdnn::matmul::GemmInterleaved< + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, + strategy) + .execute(Aptr, LDA, 
Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable( + const KernSizeParam& kern_size_param) const { + return can_be_treated_as_int8x8x16(kern_size_param) && + kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; +} + +bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred( + const KernSizeParam&) const { + return true; +} + +size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, + B_type, C_type); + return megdnn::matmul::GemmInterleaved< + matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, + strategy) + .get_workspace_size(); + } + MIDOUT_END(); + return 0; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( + const KernSizeParam&) const { + return int8x8x16_mk4_8x8x8_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, + megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, + int16_t); // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 54b6734d..46b7df25 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -202,6 +202,22 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; +class 
MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { + return "AARCH64_INT8X8X16_MK4_K8X8X8"; + } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + PackMode packmode() const override { return PackMode::DEFAULT; } + + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index 771129a9..369810fb 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -2101,6 +2101,62 @@ static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { vreinterpretq_s32_s8(input2), 3); } + +template +static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1, + T*& outptr) { + + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x4_1_b only support uint8_t and int8_t"); + asm volatile( + "ld1 {v0.4s}, [%[inptr0]], #16\n" + "ld1 {v1.4s}, [%[inptr1]], #16\n" + "ld1 {v2.4s}, [%[inptr0]], #16\n" + "ld1 {v3.4s}, [%[inptr1]], #16\n" + + "zip1 v4.4s, v0.4s, v1.4s \n" + "zip2 v5.4s, v0.4s, v1.4s \n" + + "zip1 v6.4s, v2.4s, v3.4s\n" + "zip2 v7.4s, v2.4s, v3.4s\n" + + "st1 {v4.4s},[%[outptr]],#16\n" + "st1 {v5.4s},[%[outptr]],#16\n" + "st1 {v6.4s},[%[outptr]],#16\n" + "st1 {v7.4s},[%[outptr]],#16\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); +} + +template +static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, + T* 
outptr) { + + static_assert( + std::is_same::value || std::is_same::value, + "transpose_8x4_1_b only support uint8_t and int8_t"); + asm volatile( + "ld4 {v0.8b-v3.8b}, [%[inptr0]], #32\n" + "ld4 {v4.8b-v7.8b}, [%[inptr1]], #32\n" + "st1 {v0.2s},[%[outptr]],#8\n" + "st1 {v1.2s},[%[outptr]],#8\n" + "st1 {v2.2s},[%[outptr]],#8\n" + "st1 {v3.2s},[%[outptr]],#8\n" + "st1 {v4.2s},[%[outptr]],#8\n" + "st1 {v5.2s},[%[outptr]],#8\n" + "st1 {v6.2s},[%[outptr]],#8\n" + "st1 {v7.2s},[%[outptr]],#8\n" + + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); +} + } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h new file mode 100644 index 00000000..80db73b7 --- /dev/null +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h @@ -0,0 +1,1451 @@ +/** + * \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include +#include "src/aarch64/matrix_mul/asm/common.h" +#include "src/arm_common/simd_macro/marm_neon.h" + +namespace megdnn { +namespace aarch64 { +namespace matmul_mk4_8x8x8 { + + +/** + * Overview of register layout: + * + * A 8x8 cell of Lhs is stored in 8bit in v16-v17 + * B 8x8 cell of Rhs is stored in 8bit in v0-v15, v20-v23 + * C 8x8 block of accumulators is stored in 16bit in v24-v31 + * + * +---------------------------------+ + * | v0 ------------------------ v7 | + * | v8 ------------------------ v15| + * Rhs +---------------------------------+ + * Lhs | | + * +--------+ - - - - +---------------------------------+ + * | v16 | | v24 | + * | v17 | | v25 | + * | v16 | | v26 | + * | v17 | | v27 | + * | v16 | | v28 | + * | v17 | | v29 | + * | v16 | | v30 | + * | v17 | | v31 | + * +--------+ - - - - +---------------------------------+ + * + * Accumulator + */ +static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packB;//packA; + const int8_t* b_ptr = packA;//packB; +// clang-format off +#define LOAD_C_8 \ + "ld1 {v0.8h}, [x0], #16\n" \ + "ld1 {v1.8h}, [x0], #16\n" \ + "ld1 {v2.8h}, [x0], #16\n" \ + "ld1 {v3.8h}, [x0], #16\n" \ + "ld1 {v4.8h}, [x1], #16\n" \ + "ld1 {v5.8h}, [x1], #16\n" \ + "ld1 {v6.8h}, [x1], #16\n" \ + "ld1 {v7.8h}, [x1], #16\n" \ + + +#define STORE_C_8 \ + "st1 {v0.8h}, [x0], #16\n" \ + "st1 {v1.8h}, [x0], #16\n" \ + "st1 {v2.8h}, [x0], #16\n" \ + "st1 {v3.8h}, [x0], #16\n" \ + "st1 {v4.8h}, [x1], #16\n" \ + "st1 {v5.8h}, [x1], #16\n" \ + "st1 {v6.8h}, [x1], #16\n" \ + "st1 {v7.8h}, [x1], #16\n" \ + + register int16_t* outptr asm("x0") = output; + asm volatile( + "add x1, x0, %x[LDC]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" + "eor v25.16b, v25.16b, v25.16b\n" + "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" + "eor v26.16b, v26.16b, v26.16b\n" + 
"ld1 {v20.16b}, [%[a_ptr]],#16\n" + "eor v27.16b, v27.16b, v27.16b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. + "1:\n" + "dup v0.8b,v20.b[0]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "dup v1.8b,v20.b[1]\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "dup v2.8b,v20.b[2]\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v3.8b,v20.b[3]\n" + "dup v4.8b,v20.b[4]\n" + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v21.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal 
v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v23.b[12]\n" 
+ "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v23.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" LOAD_C_8 + "b 3f \n" + "2: \n" + "eor v0.16b, v0.16b, v0.16b\n" + "eor v1.16b, v1.16b, v1.16b\n" + "eor v2.16b, v2.16b, v2.16b\n" + "eor v3.16b, v3.16b, v3.16b\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "3:\n" + "zip1 v8.2d, v24.2d, v25.2d\n" + "zip2 v9.2d, v24.2d, v25.2d\n" + "zip1 v10.2d, v26.2d, v27.2d\n" + "zip2 v11.2d, v26.2d, v27.2d\n" + "zip1 v12.2d, v28.2d, v29.2d\n" + "zip2 v13.2d, v28.2d, v29.2d\n" + "zip1 v14.2d, v30.2d, v31.2d\n" + "zip2 v15.2d, v30.2d, v31.2d\n" + "add v0.8h, v0.8h, v8.8h\n" + "add v1.8h, v1.8h, v10.8h\n" + "add v2.8h, v2.8h, v12.8h\n" + "add v3.8h, v3.8h, v14.8h\n" + "add v4.8h, v4.8h, v9.8h\n" + "add v5.8h, v5.8h, v11.8h\n" + "add v6.8h, v6.8h, v13.8h\n" + "add v7.8h, v7.8h, v15.8h\n" + + // Store back into memory + STORE_C_8 + + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", 
"v28", + "v29", "v30", "v31"); +// clang-format on +} + +static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packB; + const int8_t* b_ptr = packA; +// clang-format off + register int16_t* outptr asm("x0") = output; + asm volatile( + "add x1, x0, %x[LDC]\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" + "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" + "1:\n" + "dup v0.8b,v20.b[0]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "dup v1.8b,v20.b[1]\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "dup v2.8b,v20.b[2]\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v3.8b,v20.b[3]\n" + "dup v4.8b,v20.b[4]\n" + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + 
"dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v21.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, 
v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v23.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v23.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" + "cmp %x[m_remain], #8 \n" + "beq 8f \n" + "cmp %x[m_remain], #4 \n" + "beq 9f \n" + "8: \n" + "cmp %x[n_remain], #8\n" + "beq 200f \n" + "cmp %x[n_remain], #7\n" + "beq 201f \n" + "cmp %x[n_remain], #6\n" + "beq 202f \n" + "cmp %x[n_remain], #5\n" + "beq 203f \n" + "cmp %x[n_remain], #4\n" + "beq 204f \n" + "cmp %x[n_remain], #3\n" + "beq 205f \n" + "cmp %x[n_remain], #2\n" + "beq 206f \n" + "cmp %x[n_remain], #1\n" + "beq 207f \n" + "200: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "ld1 {v3.8h}, [x0], #16\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.8h}, [x1], #16\n" + "ld1 {v6.8h}, [x1], #16\n" + "ld1 {v7.8h}, [x1], #16\n" + "b 3f \n" + "201: \n" + "ld1 {v0.8h}, 
[x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "ld1 {v3.d}[0], [x0], #8\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.8h}, [x1], #16\n" + "ld1 {v6.8h}, [x1], #16\n" + "ld1 {v7.d}[0], [x1], #8\n" + "b 3f \n" + "202: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.8h}, [x1], #16\n" + "ld1 {v6.8h}, [x1], #16\n" + "b 3f \n" + "203: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.d}[0], [x0], #8\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.8h}, [x1], #16\n" + "ld1 {v6.d}[0], [x1], #8\n" + "b 3f \n" + "204: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.8h}, [x1], #16\n" + "b 3f \n" + "205: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.d}[0], [x0], #8\n" + "ld1 {v4.8h}, [x1], #16\n" + "ld1 {v5.d}[0], [x1], #8\n" + "b 3f \n" + "206: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v4.8h}, [x1], #16\n" + "b 3f \n" + "207: \n" + "ld1 {v0.d}[0], [x0], #8\n" + "ld1 {v4.d}[0], [x1], #8\n" + "b 3f \n" + "9: \n" + "cmp %x[n_remain], #8\n" + "beq 300f \n" + "cmp %x[n_remain], #7\n" + "beq 301f \n" + "cmp %x[n_remain], #6\n" + "beq 302f \n" + "cmp %x[n_remain], #5\n" + "beq 303f \n" + "cmp %x[n_remain], #4\n" + "beq 304f \n" + "cmp %x[n_remain], #3\n" + "beq 305f \n" + "cmp %x[n_remain], #2\n" + "beq 306f \n" + "cmp %x[n_remain], #1\n" + "beq 307f \n" + "300: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "ld1 {v3.8h}, [x0], #16\n" + "b 3f \n" + "301: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "ld1 {v3.d}[0], [x0], #8\n" + "b 3f \n" + "302: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.8h}, [x0], #16\n" + "b 3f \n" + "303: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], #16\n" + "ld1 {v2.d}[0], [x0], #8\n" + "b 3f \n" + "304: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.8h}, [x0], 
#16\n" + "b 3f \n" + "305: \n" + "ld1 {v0.8h}, [x0], #16\n" + "ld1 {v1.d}[0], [x0], #8\n" + "b 3f \n" + "306: \n" + "ld1 {v0.8h}, [x0], #16\n" + "b 3f \n" + "307: \n" + "ld1 {v0.d}[0], [x0], #8\n" + "b 3f \n" + "2: \n" + "eor v0.16b, v0.16b, v0.16b\n" + "eor v1.16b, v1.16b, v1.16b\n" + "eor v2.16b, v2.16b, v2.16b\n" + "eor v3.16b, v3.16b, v3.16b\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "3:\n" + "zip1 v8.2d, v24.2d, v25.2d\n" + "zip1 v10.2d, v26.2d, v27.2d\n" + "add v0.8h, v0.8h, v8.8h \n" + "zip1 v12.2d, v28.2d, v29.2d\n" + "add v1.8h, v1.8h, v10.8h \n" + "zip1 v14.2d, v30.2d, v31.2d\n" + "add v2.8h, v2.8h, v12.8h \n" + "add v3.8h, v3.8h, v14.8h \n" + "zip2 v9.2d, v24.2d, v25.2d\n" + "zip2 v11.2d, v26.2d, v27.2d \n" + "add v4.8h, v4.8h, v9.8h \n" + "zip2 v13.2d, v28.2d, v29.2d \n" + "add v5.8h, v5.8h, v11.8h \n" + "zip2 v15.2d, v30.2d, v31.2d \n" + "add v6.8h, v6.8h, v13.8h \n" + "add v7.8h, v7.8h, v15.8h \n" +//save to memory + "cmp %x[m_remain], #8 \n" + "beq 4f \n" + "cmp %x[m_remain], #4 \n" + "beq 5f \n" + "4: \n" + "cmp %x[n_remain], #8\n" + "beq 100f \n" + "cmp %x[n_remain], #7\n" + "beq 101f \n" + "cmp %x[n_remain], #6\n" + "beq 102f \n" + "cmp %x[n_remain], #5\n" + "beq 103f \n" + "cmp %x[n_remain], #4\n" + "beq 104f \n" + "cmp %x[n_remain], #3\n" + "beq 105f \n" + "cmp %x[n_remain], #2\n" + "beq 106f \n" + "cmp %x[n_remain], #1\n" + "beq 107f \n" + "100: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "st1 {v3.8h}, [x0], #16\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.8h}, [x1], #16\n" + "st1 {v6.8h}, [x1], #16\n" + "st1 {v7.8h}, [x1], #16\n" + "b 1000f \n" + "101: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "st1 {v3.d}[0], [x0], #8\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.8h}, [x1], #16\n" + "st1 {v6.8h}, [x1], #16\n" + "st1 {v7.d}[0], [x1], #8\n" + "b 1000f \n" + 
"102: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.8h}, [x1], #16\n" + "st1 {v6.8h}, [x1], #16\n" + "b 1000f \n" + "103: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.d}[0], [x0], #8\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.8h}, [x1], #16\n" + "st1 {v6.d}[0], [x1], #8\n" + "b 1000f \n" + "104: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.8h}, [x1], #16\n" + "b 1000f \n" + "105: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.d}[0], [x0], #8\n" + "st1 {v4.8h}, [x1], #16\n" + "st1 {v5.d}[0], [x1], #8\n" + "b 1000f \n" + "106: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v4.8h}, [x1], #16\n" + "b 1000f \n" + "107: \n" + "st1 {v0.d}[0], [x0], #8\n" + "st1 {v4.d}[0], [x1], #8\n" + "b 1000f \n" + "5: \n" + "cmp %x[n_remain], #8\n" + "beq 200f \n" + "cmp %x[n_remain], #7\n" + "beq 201f \n" + "cmp %x[n_remain], #6\n" + "beq 202f \n" + "cmp %x[n_remain], #5\n" + "beq 203f \n" + "cmp %x[n_remain], #4\n" + "beq 204f \n" + "cmp %x[n_remain], #3\n" + "beq 205f \n" + "cmp %x[n_remain], #2\n" + "beq 206f \n" + "cmp %x[n_remain], #1\n" + "beq 207f \n" + "200: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "st1 {v3.8h}, [x0], #16\n" + "b 1000f \n" + "201: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "st1 {v3.d}[0], [x0], #8\n" + "b 1000f \n" + "202: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.8h}, [x0], #16\n" + "b 1000f \n" + "203: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "st1 {v2.d}[0], [x0], #8\n" + "b 1000f \n" + "204: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.8h}, [x0], #16\n" + "b 1000f \n" + "205: \n" + "st1 {v0.8h}, [x0], #16\n" + "st1 {v1.d}[0], [x0], #8\n" + "b 1000f \n" + "206: \n" + "st1 {v0.8h}, [x0], #16\n" + "b 1000f \n" + "207: \n" + "st1 {v0.d}[0], [x0], 
#8\n" + "b 1000f \n" + + "1000: \n" + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); +// clang-format on + +#undef LOAD_C_8 +#undef STORE_C_8 +} + + +static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packB;//packA; + const int8_t* b_ptr = packA;//packB; +// clang-format off +#define LOAD_C_4 \ + "ld1 {v0.8h}, [x0], #16\n" \ + "ld1 {v1.8h}, [x0], #16\n" \ + "ld1 {v2.8h}, [x0], #16\n" \ + "ld1 {v3.8h}, [x0], #16\n" \ + + +#define STORE_C_4 \ + "st1 {v0.8h}, [x0], #16\n" \ + "st1 {v1.8h}, [x0], #16\n" \ + "st1 {v2.8h}, [x0], #16\n" \ + "st1 {v3.8h}, [x0], #16\n" \ + + register int16_t* outptr asm("x0") = output; + asm volatile( + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. 
+ "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" + "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" + "1:\n" + "dup v0.8b,v20.b[0]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "dup v1.8b,v20.b[1]\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "dup v2.8b,v20.b[2]\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v3.8b,v20.b[3]\n" + "dup v4.8b,v20.b[4]\n" + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v21.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal 
v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v23.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v23.b[14]\n" 
+ "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "subs %w[K], %w[K], #1\n" + "cbnz %w[K], 1b\n" + + "cmp %w[is_first_k], #1\n" + "beq 2f\n" LOAD_C_4 + "b 3f \n" + "2: \n" + "eor v0.16b, v0.16b, v0.16b\n" + "eor v1.16b, v1.16b, v1.16b\n" + "eor v2.16b, v2.16b, v2.16b\n" + "eor v3.16b, v3.16b, v3.16b\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "3:\n" + "zip1 v8.2d, v24.2d, v25.2d\n" + "zip1 v10.2d, v26.2d, v27.2d\n" + "add v0.8h, v0.8h, v8.8h\n" + "zip1 v12.2d, v28.2d, v29.2d\n" + "add v1.8h, v1.8h, v10.8h\n" + "zip1 v14.2d, v30.2d, v31.2d\n" + "add v2.8h, v2.8h, v12.8h\n" + "add v3.8h, v3.8h, v14.8h\n" + + // Store back into memory + STORE_C_4 + + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); +// clang-format on +#undef LOAD_C_4 +#undef STORE_C_4 +} +static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int m_remain, + int n_remain) { + K /= 8; + LDC = LDC * sizeof(int16_t); + const int8_t* a_ptr = packB;//packA; + const int8_t* b_ptr = packA;//packB; +// clang-format 
off + register int16_t* outptr asm("x0") = output; + asm volatile( + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + // General loop. + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" + "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" + "1:\n" + "dup v0.8b,v20.b[0]\n" + "ld1 {v22.16b}, [%[a_ptr]],#16\n" + "dup v1.8b,v20.b[1]\n" + "ld1 {v23.16b}, [%[a_ptr]],#16\n" + "dup v2.8b,v20.b[2]\n" + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v3.8b,v20.b[3]\n" + "dup v4.8b,v20.b[4]\n" + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup v5.8b,v20.b[5]\n" + "dup v6.8b,v20.b[6]\n" + "dup v7.8b,v20.b[7]\n" + + "dup v8.8b,v20.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v20.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v20.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v20.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v20.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v20.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v20.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v20.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + + "dup v0.8b,v21.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v21.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v21.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v21.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v21.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v21.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v21.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v21.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v21.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v21.b[9]\n" 
+ "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v21.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v21.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v21.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v21.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v21.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v21.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v22.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v22.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v22.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v22.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v22.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v22.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v22.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v22.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + + "dup v8.8b,v22.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v22.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v22.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v22.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v22.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v22.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v22.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v22.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v16.8b}, [%[b_ptr]], 8\n" + "dup v0.8b,v23.b[0]\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "dup v1.8b,v23.b[1]\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "dup v2.8b,v23.b[2]\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "dup v3.8b,v23.b[3]\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "dup v4.8b,v23.b[4]\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "dup v5.8b,v23.b[5]\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "dup v6.8b,v23.b[6]\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "dup v7.8b,v23.b[7]\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "ld1 {v17.8b}, [%[b_ptr]], 8\n" + "dup 
v8.8b,v23.b[8]\n" + "smlal v24.8h, v0.8b, v16.8b\n" + "dup v9.8b,v23.b[9]\n" + "smlal v25.8h, v1.8b, v16.8b\n" + "dup v10.8b,v23.b[10]\n" + "smlal v26.8h, v2.8b, v16.8b\n" + "dup v11.8b,v23.b[11]\n" + "smlal v27.8h, v3.8b, v16.8b\n" + "dup v12.8b,v23.b[12]\n" + "smlal v28.8h, v4.8b, v16.8b\n" + "dup v13.8b,v23.b[13]\n" + "smlal v29.8h, v5.8b, v16.8b\n" + "dup v14.8b,v23.b[14]\n" + "smlal v30.8h, v6.8b, v16.8b\n" + "dup v15.8b,v23.b[15]\n" + "smlal v31.8h, v7.8b, v16.8b\n" + + "ld1 {v20.16b}, [%[a_ptr]],#16\n" + "smlal v24.8h, v8.8b, v17.8b\n" + "smlal v25.8h, v9.8b, v17.8b\n" + "smlal v26.8h, v10.8b, v17.8b\n" + "smlal v27.8h, v11.8b, v17.8b\n" + "ld1 {v21.16b}, [%[a_ptr]],#16\n" + "smlal v28.8h, v12.8b, v17.8b\n" + "smlal v29.8h, v13.8b, v17.8b\n" + "smlal v30.8h, v14.8b, v17.8b\n" + "smlal v31.8h, v15.8b, v17.8b\n" + + "subs %w[K], %w[K], #1 \n" + "cbnz %w[K], 1b \n" + "cmp %w[is_first_k], #1 \n" + "beq 2f \n" + "cmp %w[n_remain],#7 \n" + "beq 200f \n" + "cmp %w[n_remain],#6 \n" + "beq 201f \n" + "cmp %w[n_remain],#5 \n" + "beq 202f \n" + "cmp %w[n_remain],#4 \n" + "beq 203f \n" + "cmp %w[n_remain],#3 \n" + "beq 204f \n" + "cmp %w[n_remain],#2 \n" + "beq 205f \n" + "cmp %w[n_remain],#1 \n" + "beq 206f \n" + "200: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "ld1 {v1.8h}, [x0],#16 \n" + "ld1 {v2.8h}, [x0],#16 \n" + "ld1 {v3.d}[0], [x0],#8 \n" + "b 3f \n" + "201: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "ld1 {v1.8h}, [x0],#16 \n" + "ld1 {v2.8h}, [x0],#16 \n" + "b 3f \n" + "202: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "ld1 {v1.8h}, [x0],#16 \n" + "ld1 {v2.d}[0], [x0],#8 \n" + "b 3f \n" + "203: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "ld1 {v1.8h}, [x0],#16 \n" + "b 3f \n" + "204: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "ld1 {v1.d}[0], [x0],#8 \n" + "b 3f \n" + "205: \n" + "ld1 {v0.8h}, [x0],#16 \n" + "b 3f \n" + "206: \n" + "ld1 {v0.d}[0], [x0],#8 \n" + "b 3f \n" + "2: \n" + "eor v0.16b, v0.16b, v0.16b\n" + "eor v1.16b, v1.16b, v1.16b\n" + "eor v2.16b, v2.16b, v2.16b\n" + "eor v3.16b, v3.16b, 
v3.16b\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + "3: \n" + "zip1 v8.2d, v24.2d, v25.2d\n" + "zip1 v10.2d, v26.2d, v27.2d\n" + "add v0.8h, v0.8h, v8.8h \n" + "zip1 v12.2d, v28.2d, v29.2d\n" + "add v1.8h, v1.8h, v10.8h\n" + "zip1 v14.2d, v30.2d, v31.2d\n" + "add v2.8h, v2.8h, v12.8h\n" + "add v3.8h, v3.8h, v14.8h\n" + + // Store back into memory + "cmp %w[n_remain],#7 \n" + "beq 100f \n" + "cmp %w[n_remain],#6 \n" + "beq 101f \n" + "cmp %w[n_remain],#5 \n" + "beq 102f \n" + "cmp %w[n_remain],#4 \n" + "beq 103f \n" + "cmp %w[n_remain],#3 \n" + "beq 104f \n" + "cmp %w[n_remain],#2 \n" + "beq 105f \n" + "cmp %w[n_remain],#1 \n" + "beq 106f \n" + "100: \n" + "st1 {v0.8h}, [x0],#16 \n" + "st1 {v1.8h}, [x0],#16 \n" + "st1 {v2.8h}, [x0],#16 \n" + "st1 {v3.d}[0], [x0],#8 \n" + "b 1000f \n" + "101: \n" + "st1 {v0.8h}, [x0],#16 \n" + "st1 {v1.8h}, [x0],#16 \n" + "st1 {v2.8h}, [x0],#16 \n" + "b 1000f \n" + "102: \n" + "st1 {v0.8h}, [x0],#16 \n" + "st1 {v1.8h}, [x0],#16 \n" + "st1 {v2.d}[0], [x0],#8 \n" + "b 1000f \n" + "103: \n" + "st1 {v0.8h}, [x0],#16 \n" + "st1 {v1.8h}, [x0],#16 \n" + "b 1000f \n" + "104: \n" + "st1 {v0.8h}, [x0],#16 \n" + "st1 {v1.d}[0], [x0],#8 \n" + "b 1000f \n" + "105: \n" + "st1 {v0.8h}, [x0],#16 \n" + "b 1000f \n" + "106: \n" + "st1 {v0.d}[0], [x0],#8 \n" + "b 1000f \n" + "1000: \n" + : + [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), + [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), + [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), + [ n_remain ] "+r"(n_remain) + : + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); +// clang-format on +#undef LOAD_C_4 +#undef STORE_C_4 +} + + +//! pack to icxoc +//! 
//! (M/4,K/4,4(K),4(M)) is packed to (M/8,K/8,8(K: ic0~3,ic4~7),8(M: oc0~3,oc4~7)):
//! two consecutive MK4 row-blocks (8 output channels) are interleaved so the
//! 8x8x8 kernels can consume them directly.
//! If M or K is not a multiple of 8, the missing rows/columns are packed as 0.
static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr,
                                          const dt_int8* inptr, int ldin,
                                          int m0, int mmax, int k0, int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 8;     // output-channel block handled per iteration
    constexpr int pack_k = 8;     // K elements consumed per interleave step
    constexpr int pack_size = 4;  // MK4 inner block size
    //! staging buffers for the K remainder; zero-initialized once so the bytes
    //! beyond the copied remainder stay 0 (the required zero padding). The
    //! remainder size is the same on every outer iteration, so the copied
    //! region is fully overwritten each time and the zero tail is preserved.
    int8_t tmpbuff0[pack_m * pack_size] = {0};
    int8_t tmpbuff1[pack_m * pack_size] = {0};
    int8_t zerobuff[pack_m * pack_size] = {0};
    const int m_size = mmax - m0;
    const int m_end = m_size / pack_m * pack_m + m0;
    //! M is asserted to be a multiple of 4, so remain_m is either 0 or 4
    int remain_m = mmax - m_end;

    for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) {
        //! inptr0/inptr1: the two adjacent MK4 row-blocks being fused;
        //! each MK4 row-block is ldin bytes apart
        const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        int k_idx = k0;
        for (; k_idx + 7 < kmax; k_idx += pack_k) {
            //! NOTE(review): assumes interleave_8x8_mk4_b advances
            //! inptr0/inptr1/outptr (pointers taken by reference) — confirm
            //! against asm/common.h
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }

        if (k_idx < kmax) {
            //! K remainder (< 8): copy the valid tail into the zero-padded
            //! staging buffers and interleave one last full 8x8 block
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            memcpy(tmpbuff1, inptr1, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }
    }
    int m_idx = m_end;
    if (remain_m == 4) {
        //! only 4 output channels left: the second row-block is all zeros
        const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_2x(inptr0);
        prefetch_2x(inptr1);
        int k_idx = k0;
        for (; k_idx + 7 < kmax; k_idx += pack_k) {
            //! reset every iteration because the helper advances the pointer
            inptr1 = zerobuff;
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }

        if (k_idx < kmax) {
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * (kmax - k_idx) * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = zerobuff;
            interleave_8x8_mk4_b(inptr0, inptr1, outptr);
        }
    }
}
//! pack B to n x ic layout:
//! (K/4,N,4) is packed to (K/8,N,8(ic0~7)); if K is not a multiple of 8,
//! the missing K elements are packed as 0.
static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");

    constexpr int pack_n = 8;     // columns transposed per step
    constexpr int pack_k = 8;     // K elements consumed per step
    constexpr int pack_size = 4;  // MK4 inner block size
    //! staging buffers for the N remainder; zero-initialized once so bytes
    //! beyond the copied remainder stay 0 (remain_n is loop-invariant, so the
    //! zero tail is never clobbered by a later, larger copy)
    int8_t tmpbuff0[pack_n * pack_size] = {0};
    int8_t tmpbuff1[pack_n * pack_size] = {0};
    int8_t zerobuff[pack_n * pack_size] = {0};
    //! K is padded up to a multiple of 8 in the packed layout
    const int ksize = round_up((kmax - k0), 8);
    const int nsize = nmax - n0;
    const int n_end = nsize / pack_n * pack_n + n0;
    const int remain_n = nsize % pack_n;
    //! distance between consecutive 8-column panels in the packed output
    int output_stride = ksize * pack_n;
    int8_t* outptr_base = out;
    int k_idx = k0;
    for (; k_idx + 7 < kmax; k_idx += pack_k) {
        //! inptr0/inptr1: two adjacent K/4 source rows (8 K values total);
        //! each (K/4) row is ldin bytes apart
        const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr1 = inptr0 + ldin;
        prefetch_3x(inptr0);
        prefetch_3x(inptr1);

        auto outptr = outptr_base;
        for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
            //! NOTE(review): assumes transpose_8x8_mk4_b advances
            //! inptr0/inptr1 (pointers taken by reference) — confirm against
            //! asm/common.h
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        if (remain_n > 0) {
            //! N remainder: copy the valid columns into the zero-padded
            //! staging buffers and transpose one last full panel
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
            memcpy(tmpbuff1, inptr1, sizeof(int8_t) * remain_n * pack_size);
            inptr0 = tmpbuff0;
            inptr1 = tmpbuff1;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        //! advance to the next K-panel slot inside each column panel
        outptr_base += pack_n * pack_k;
    }

    if (k_idx < kmax) {
        //! K remainder (exactly 4 since K % 4 == 0): the second source row
        //! does not exist, substitute zeros
        const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr1 = nullptr;
        prefetch_3x(inptr0);
        auto outptr = outptr_base;
        for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) {
            //! reset every iteration because the helper advances the pointer
            inptr1 = zerobuff;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        if (remain_n > 0) {
            memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size);
            inptr1 = zerobuff;
            inptr0 = tmpbuff0;
            transpose_8x8_mk4_b(inptr0, inptr1, outptr);
            outptr += output_stride;
        }
        //! last iteration: outptr_base is not read again after this advance
        outptr_base += pack_n * pack_size;
    }
}

} // namespace matmul_mk4_8x8x8
} //
namespace aarch64 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp index d96ce964..cd2f2e13 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp @@ -13,6 +13,7 @@ #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" @@ -357,4 +358,81 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, } } +// ===========================gemm_s8x8x16_mk4_8x8x8================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); + +void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, + int ldin, int y0, int ymax, int k0, + int kmax, bool) const { + matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, + ymax, k0, kmax); +} + +void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, + int ldin, int x0, int xmax, int k0, + int kmax, bool) const { + matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, + xmax, k0, kmax); +} + +void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); + megdnn_assert(is_first_k == true, "only impl is_first_k"); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + megdnn_assert(M % 4 == 0 && K % 4 == 
0, "M and K must be time of 4"); + + constexpr size_t pack_size = 4; + constexpr size_t pack_m = 8; + constexpr size_t pack_n = 8; + const size_t remain_n = N % pack_n; + size_t remain_m = M % pack_m; + K = round_up(K, 8); + size_t KSIZE8 = K * pack_n; + size_t m_idx = 0; + for (; m_idx + pack_m <= M; m_idx += pack_m) { + int16_t* output = C + (m_idx / pack_size * LDC); + + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k, pack_m, pack_n); + output += pack_n * pack_size; + cur_packB += KSIZE8; + } + if (remain_n > 0) { + matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, + is_first_k, pack_m, remain_n); + output += remain_n * pack_size; + cur_packB += KSIZE8; + } + packA += KSIZE8; + } + + if (remain_m == 4) { + int16_t* output = C + (m_idx / pack_size * LDC); + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, 4, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, + is_first_k, 4, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * K; + } + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h index 61c2dddf..0b320712 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h @@ -26,6 +26,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, 16, 12, 4, false, false, gemm_s8x8x16_mk4_16x12_a53); +MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false, + gemm_s8x8x16_mk4_8x8x8); } // namespace 
matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 9c19594c..2b7614eb 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -39,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; + AlgoInt8x8x16MK4_K8x8x8 int8x8x16_mk4_k8x8x8; AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; @@ -73,6 +74,7 @@ public: #endif all_algos.emplace_back(&int8x8x16_k4x4x16); all_algos.emplace_back(&int8x8x16_k8x8x8); + all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); all_algos.emplace_back(&int8x8x16_mk4_4x4x8); all_algos.emplace_back(&int8x8x16_mk4_16x12x4); diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 1a504982..906cbc85 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -57,6 +57,7 @@ private: #else class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 #endif + class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel MK4 8x8x8 class AlgoPack; }; diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index 1b09be71..bc9169bd 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -122,6 +122,20 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { std::move(args)); } +TEST_F(AARCH64, MATRIX_MUL_INT8x8x16_MK4) { + std::vector<matrix_mul::TestArg> args; + for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}) + for (size_t n : + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24}) + for (size_t k : + {2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29}) + args.emplace_back(m, n, k, 0); + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), 
"AARCH64_INT8X8X16_MK4_K8X8X8", + param::MatrixMul::Format::MK4, 1, 1e-3, + std::move(args)); +} TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "AARCH64_INT8X8X16_MK4_4X4X8", @@ -396,6 +410,71 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { run(384, 384, 384); } +TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x8_8x8x16_vs_4x4x16_8x8x16) { + constexpr size_t RUNS = 50; + param::MatrixMul param; + param.transposeA = false; + param.transposeB = false; + Benchmarker benchmarker(handle()); + Benchmarker benchmarker_mk4(handle()); + Benchmarker benchmarker_mk4_4x4x8(handle()); + benchmarker.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + benchmarker.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_K4X4X16")); + + param.format = MatrixMul::Param::Format::MK4; + benchmarker_mk4.set_before_exec_callback( + AlgoChecker( + "AARCH64_INT8X8X16_MK4_K8X8X8" + )); + benchmarker_mk4.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + + benchmarker_mk4_4x4x8.set_before_exec_callback( + AlgoChecker("AARCH64_INT8X8X16_MK4_4X4X8")); + benchmarker_mk4_4x4x8.set_times(RUNS) + .set_dtype(0, dtype::Int8{}) + .set_dtype(1, dtype::Int8{}) + .set_dtype(2, dtype::Int16{}) + .set_param(param) + .set_display(false); + + auto run = [&](size_t M, size_t N, size_t K) { + auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; + auto mk_used = benchmarker_mk4.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + auto mk4_4x4x8_used = + benchmarker_mk4_4x4x8.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + float computations = 2.f * M * K * N * 1e-6; + printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " + "%f Gflops speedup: %f, mk4_4x4x8 %f 
Gflops %f ms speedup: %f\n", + M, K, N, default_used, computations / default_used, mk_used, + computations / mk_used, default_used / mk_used, + computations / mk4_4x4x8_used, mk4_4x4x8_used , mk4_4x4x8_used/mk_used); + }; + + run(384, 384, 384); + run(512, 512, 512); + run(1024, 1024, 384); + run(256, 256, 384); + for(int m = 32; m <= 512;m*=2) + for(int n = 32; n <= 512;n*=2) + for(int k = 32; k < 512;k*=2){ + run(m,n,k); + } +} TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { constexpr size_t RUNS = 50; param::MatrixMul param;