From eed54081aba819d36bc4698d3603f91b351c0353 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 23 Jul 2020 13:28:15 +0800 Subject: [PATCH] feat(dnn/arm): add armv7 mk4 i8i8i16 gemm, optimized for A7 GitOrigin-RevId: d2f8290a8d6577b99adad16e42d57a6ca55a119e --- dnn/src/aarch64/matrix_mul/algos.cpp | 11 +- dnn/src/armv7/matrix_mul/algos.cpp | 79 +++- dnn/src/armv7/matrix_mul/algos.h | 15 +- dnn/src/armv7/matrix_mul/asm/common.h | 133 +++++-- dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h | 28 +- .../armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h | 406 +++++++++++++++++++++ dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp | 83 ++++- dnn/src/armv7/matrix_mul/int8x8x16/strategy.h | 7 +- dnn/src/armv7/matrix_mul/opr_impl.cpp | 10 +- dnn/src/armv7/matrix_mul/opr_impl.h | 31 +- dnn/src/fallback/conv_bias/conv1x1/algos.cpp | 5 +- dnn/src/fallback/conv_bias/im2col/factory.h | 34 +- dnn/src/fallback/matrix_mul/opr_impl.h | 21 +- dnn/test/arm_common/conv_bias.cpp | 71 ++-- dnn/test/arm_common/conv_bias_multi_thread.cpp | 31 +- dnn/test/armv7/matrix_mul.cpp | 68 +++- dnn/test/common/matrix_mul.cpp | 9 +- 17 files changed, 890 insertions(+), 152 deletions(-) create mode 100644 dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 7ef15b45..a97acc2f 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/aarch64/matrix_mul/algos.h" @@ -733,7 +734,9 @@ void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable( const KernSizeParam& kern_size_param) const { - return can_be_treated_as_int8x8x16(kern_size_param); + return can_be_treated_as_int8x8x16(kern_size_param) && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; } bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( @@ -796,7 +799,9 @@ void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable( const KernSizeParam& kern_size_param) const { - return can_be_treated_as_int8x8x16(kern_size_param); + return can_be_treated_as_int8x8x16(kern_size_param) && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; } bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index de2e0efe..8d8c2f93 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/armv7/matrix_mul/algos.h" @@ -526,6 +527,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, "AlgoInt8x8x16K4x8x8"_hash, armv7::matmul::gemm_s8x8x16_4x8, int8_t, int16_t); + +/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ + +namespace { +void kern_int8x8x16_mk4_k8x8x4(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto trA = kern_param.trA, trB = kern_param.trB; + + armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, kern_param.A_type, + kern_param.B_type, + kern_param.C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, + kern_param.workspace_ptr); + } + MIDOUT_END(); +} +} // anonymous namespace + +bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::usable( + const KernSizeParam& kern_size_param) const { + bool type_ok = can_be_treated_as_int8x8x16(kern_size_param); + + return type_ok && kern_size_param.format == param::MatrixMul::Format::MK4 && + kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + !kern_size_param.trA && !kern_size_param.trB && + kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; +} + +size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, + K = kern_size_param.K; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern( + const KernSizeParam&) const { + return kern_int8x8x16_mk4_k8x8x4; +} + +bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::preferred( + const KernSizeParam& kern_size_param) const { + return kern_size_param.K >= 4; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, + megdnn_armv7_matmul_kern, + "AlgoInt8x8x16MK4_8x8x4"_hash, + armv7::matmul::gemm_s8x8x16_mk4_8x8, + int8_t, int16_t, int16_t); + /* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ namespace { @@ -937,11 +1006,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( Bptr = kern_param.B(); auto Cptr = kern_param.C(); - armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, - C_type); - megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, - trB, strategy) + armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 9a509b46..60e35f1b 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -171,6 +172,18 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); }; +class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } + bool usable(const KernSizeParam&) const override; + bool preferred(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const KernSizeParam&) const override; + void* type() const override { return sm_arm_common_algo_type; } + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); +}; + class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { public: bool is_reproducible() const override { return true; } diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h index 820322e5..442dc67f 100644 --- a/dnn/src/armv7/matrix_mul/asm/common.h +++ b/dnn/src/armv7/matrix_mul/asm/common.h @@ -6,13 +6,15 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include #include #include #include +#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" @@ -172,7 +174,6 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "memory"); - } template @@ -183,12 +184,12 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, std::is_same::value || std::is_same::value, "interleave_4x4_4_b only support uint8_t and int8_t"); asm volatile( - "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 - "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 - "vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 - "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 - "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 - "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 + "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 + "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 + "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 "vst1.32 {d0-d1},[%[outptr]]!\n" @@ -323,10 +324,10 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, "vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17 "vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37 "vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57 - "vst1.32 {d0-d1}, [%[outptr]]! \n" - "vst1.32 {d16}, [%[outptr]]! \n" + "vst1.32 {d0-d1}, [%[outptr]]! \n" + "vst1.32 {d16}, [%[outptr]]! \n" "vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33 - "vst1.32 {d8-d9}, [%[outptr]]! \n" + "vst1.32 {d8-d9}, [%[outptr]]! \n" "vst1.32 {d17}, [%[outptr]]! \n" "vst1.32 {d2-d3}, [%[outptr]]!\n" "vst1.32 {d18}, [%[outptr]]!\n" @@ -810,15 +811,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, "interleave_12x4_1_h only support uint16_t and int16_t"); auto ldin_asm = ldin << 1; asm volatile( - "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 - "vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 - "vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 - "vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 - "vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 - "vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 - "vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 - "vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 - "vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 + "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 + "vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 + "vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 + "vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 + "vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 + "vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 + "vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 + "vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 + "vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 "vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3 "add %[inptr9], %[inptr9], %[ldin_asm]\n" "vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3 @@ -854,17 +855,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9), [outptr] "+r"(outptr) - :[ldin_asm] "r"(ldin_asm) + : [ldin_asm] "r"(ldin_asm) : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", "memory"); - inptr9 -= ldin_asm; - inptr9 += 4; + inptr9 -= ldin_asm; + inptr9 += 4; inptr10 += 4; inptr11 += 4; } - - template static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1038,7 +1037,7 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, "vst1.32 {d7}, [%[outptr]], %[stride]\n" : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r" (stride) + [outptr] "+r"(outptr), [stride] "+r"(stride) : : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); } @@ -1069,7 +1068,6 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1, : "d0", "d1", "d2", "d3", "memory"); } - template static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1082,9 +1080,9 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 - "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 - "vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 - + "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 + "vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 + "add %[inptr0],%[inptr0],#6 \n" "add %[inptr1],%[inptr1],#6 \n" "add %[inptr2],%[inptr2],#6 \n" @@ -1121,9 +1119,9 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 - "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 - "vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 - + "vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 + "vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 + "add %[inptr0],%[inptr0],#4 \n" "add %[inptr1],%[inptr1],#4 \n" "add %[inptr2],%[inptr2],#4 \n" @@ -1176,7 +1174,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { "vst1.32 {d6-d7}, [%[outptr]]! \n" "vst1.32 {d14-d15}, [%[outptr]]! \n" "vst1.32 {d22-d23}, [%[outptr]]! \n" - : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "memory"); @@ -1195,12 +1193,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { "vst1.32 {d4-d5}, [%[outptr]]! \n" "vst1.32 {d2-d3}, [%[outptr]]! \n" "vst1.32 {d6-d7}, [%[outptr]]! \n" - : [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "memory"); } - template static inline void transpose_4(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, T* outptr, @@ -1251,7 +1248,6 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1, } } - template static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -1375,7 +1371,68 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, : "q0", "q1", "q2", "q3", "memory"); } -} // armv7 +static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, + const int8_t* inptr1, + int16_t* outptr) { + int8x16_t row0 = vld1q_s8(inptr0); + int16x8_t row0_01 = vmovl_low_s8(row0); + int16x8_t row0_23 = vmovl_high_s8(row0); + int16x4_t row0_0 = vget_low_s16(row0_01); + int16x4_t row0_1 = vget_high_s16(row0_01); + int16x4_t row0_2 = vget_low_s16(row0_23); + int16x4_t row0_3 = vget_high_s16(row0_23); + + int8x16_t row1 = vld1q_s8(inptr1); + int16x8_t row1_01 = vmovl_low_s8(row1); + int16x8_t row1_23 = vmovl_high_s8(row1); + int16x4_t row1_0 = vget_low_s16(row1_01); + int16x4_t row1_1 = vget_high_s16(row1_01); + int16x4_t row1_2 = vget_low_s16(row1_23); + int16x4_t row1_3 = vget_high_s16(row1_23); + + vst1_s16(outptr, row0_0); + vst1_s16(outptr + 1 * 4, row1_0); + vst1_s16(outptr + 2 * 4, row0_1); + vst1_s16(outptr + 3 * 4, row1_1); + vst1_s16(outptr + 4 * 4, row0_2); + vst1_s16(outptr + 5 * 4, row1_2); + vst1_s16(outptr + 6 * 4, row0_3); + vst1_s16(outptr + 7 * 4, row1_3); +}; + +static inline void transpos_8x4_int8(const int8_t* inptr0, int8_t* outptr) { + int8x8x4_t input = vld4_s8(inptr0); + vst1_s8(outptr, input.val[0]); + vst1_s8(outptr + 1 * 8, input.val[1]); + vst1_s8(outptr + 2 * 8, input.val[2]); + vst1_s8(outptr + 3 * 8, input.val[3]); +} +static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, + int count) { + for (; count >= 32; count -= 32) { + int8x8_t in0 = vld1_s8(inptr); + int8x8_t in1 = vld1_s8(inptr + 1 * 8); + int8x8_t in2 = vld1_s8(inptr + 2 * 8); + int8x8_t in3 = vld1_s8(inptr + 3 * 8); + vst1q_s16(outptr, vmovl_s8(in0)); + vst1q_s16(outptr + 1 * 8, vmovl_s8(in1)); + vst1q_s16(outptr + 2 * 8, vmovl_s8(in2)); + vst1q_s16(outptr + 3 * 8, vmovl_s8(in3)); + inptr += 32; + outptr += 32; + } + for (; count >= 8; count -= 8) { + int8x8_t in0 = vld1_s8(inptr); + vst1q_s16(outptr, vmovl_s8(in0)); + inptr += 8; + outptr += 8; + } + for (; count > 0; --count) { + *outptr++ = (int16_t)(*inptr++); + } +} + +} // namespace armv7 } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h index cd3b9e22..0af6a9f6 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h @@ -102,60 +102,60 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "vld1.8 {d2}, [%[a_ptr]]!\n" "vld1.8 {d4}, [%[a_ptr]]!\n" "vld1.8 {d6}, [%[a_ptr]]!\n" + "vld1.8 {d18}, [%[b_ptr]]!\n" "vmovl.s8 q8, d16\n" "vmovl.s8 q0, d0\n" "vmovl.s8 q1, d2\n" "vmovl.s8 q2, d4\n" "vmovl.s8 q3, d6\n" - "vld1.8 {d18}, [%[b_ptr]]!\n" + "vmovl.s8 q9, d18\n" + "vld1.8 {d20}, [%[b_ptr]]!\n" "vmla.s16 q4, q8, d0[0]\n" "vmla.s16 q5, q8, d2[0]\n" "vmla.s16 q6, q8, d4[0]\n" "vmla.s16 q7, q8, d6[0]\n" - "vmovl.s8 q9, d18\n" - "vld1.8 {d20}, [%[b_ptr]]!\n" + "vmovl.s8 q10, d20\n" + "vld1.8 {d22}, [%[b_ptr]]!\n" "vmla.s16 q4, q9, d0[1]\n" "vmla.s16 q5, q9, d2[1]\n" "vmla.s16 q6, q9, d4[1]\n" "vmla.s16 q7, q9, d6[1]\n" - "vmovl.s8 q10, d20\n" - "vld1.8 {d22}, [%[b_ptr]]!\n" + "vmovl.s8 q11, d22\n" + "vld1.8 {d24}, [%[b_ptr]]!\n" "vmla.s16 q4, q10, d0[2]\n" "vmla.s16 q5, q10, d2[2]\n" "vmla.s16 q6, q10, d4[2]\n" "vmla.s16 q7, q10, d6[2]\n" - "vmovl.s8 q11, d22\n" - "vld1.8 {d24}, [%[b_ptr]]!\n" + "vmovl.s8 q12, d24\n" + "vld1.8 {d26}, [%[b_ptr]]!\n" "vmla.s16 q4, q11, d0[3]\n" "vmla.s16 q5, q11, d2[3]\n" "vmla.s16 q6, q11, d4[3]\n" "vmla.s16 q7, q11, d6[3]\n" - "vmovl.s8 q12, d24\n" - "vld1.8 {d26}, [%[b_ptr]]!\n" + "vmovl.s8 q13, d26\n" + "vld1.8 {d28}, [%[b_ptr]]!\n" "vmla.s16 q4, q12, d1[0]\n" "vmla.s16 q5, q12, d3[0]\n" "vmla.s16 q6, q12, d5[0]\n" "vmla.s16 q7, q12, d7[0]\n" - "vmovl.s8 q13, d26\n" - "vld1.8 {d28}, [%[b_ptr]]!\n" + "vmovl.s8 q14, d28\n" + "vld1.8 {d30}, [%[b_ptr]]!\n" "vmla.s16 q4, q13, d1[1]\n" "vmla.s16 q5, q13, d3[1]\n" "vmla.s16 q6, q13, d5[1]\n" "vmla.s16 q7, q13, d7[1]\n" - "vmovl.s8 q14, d28\n" - "vld1.8 {d30}, [%[b_ptr]]!\n" + "vmovl.s8 q15, d30\n" "vmla.s16 q4, q14, d1[2]\n" "vmla.s16 q5, q14, d3[2]\n" "vmla.s16 q6, q14, d5[2]\n" "vmla.s16 q7, q14, d7[2]\n" - "vmovl.s8 q15, d30\n" "vmla.s16 q4, q15, d1[3]\n" "vmla.s16 q5, q15, d3[3]\n" diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h new file mode 100644 index 00000000..59a2c2be --- /dev/null +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h @@ -0,0 +1,406 @@ +/** + * \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" + +namespace megdnn { +namespace armv7 { +namespace matmul_mk4_8x8x4 { + +//! optimize for A7 + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Lhs is stored in 16bit in q0, q1 + * A 8x8x8 cell of Rhs is stored in 8bit in q2, q3 + * A 8x8 block of accumulators is stored in 16bit in q8-q15 + * + * +--------+ + * | q4[0-8]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |q0[0]| | q8 [0-8]| + * |q0[1]| | q9 [0-8]| + * |q0[2]| | q10[0-8]| + * |q0[3]| | q11[0-8]| + * |q0[4]| | q12[0-8]| + * |q0[5]| | q13[0-8]| + * |q0[6]| | q14[0-8]| + * |q0[7]| | q15[0-8]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_8x8(const int16_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int remain_n) { + K /= 4; + const int16_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + int x0 = 0; + +// clang-format off +#define STORE_LINE(reg_index1, reg_index2) \ + "cmp %[x0], #0 \n" \ + "beq 101f\n" \ + "vst1.16 {d" reg_index1 "}, [r0]!\n" \ + "vst1.16 {d" reg_index2 "}, [r1]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[remain_n]\n" \ + STORE_LINE("16", "17") \ + STORE_LINE("18", "19") \ + STORE_LINE("20", "21") \ + STORE_LINE("22", "23") \ + STORE_LINE("24", "25") \ + STORE_LINE("26", "27") \ + STORE_LINE("28", "29") \ + STORE_LINE("30", "31") \ + "101:\n" + + // clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" + + "b 2f\n" + + "1:\n" + "veor.s32 q8 , q8 , q8 \n" + "veor.s32 q9 , q9 , q9 \n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + + "2: \n" + "vld1.8 {d4}, [%[b_ptr]]!\n" + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vmovl.s8 q2, d4\n" + "vld1.16 {d2, d3}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[b_ptr]]!\n" + //! k0 + "vmla.s16 q8, q0, d4[0]\n" + "vmla.s16 q9, q0, d4[1]\n" + "vmla.s16 q10, q0, d4[2]\n" + "vmla.s16 q11, q0, d4[3]\n" + "vmovl.s8 q3, d6\n" + "vmla.s16 q12, q0, d5[0]\n" + "vmla.s16 q13, q0, d5[1]\n" + "vmla.s16 q14, q0, d5[2]\n" + "vmla.s16 q15, q0, d5[3]\n" + //! k1 + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[b_ptr]]!\n" + "vmla.s16 q8, q1, d6[0]\n" + "vmla.s16 q9, q1, d6[1]\n" + "vmla.s16 q10, q1, d6[2]\n" + "vmla.s16 q11, q1, d6[3]\n" + "vmovl.s8 q2, d4\n" + "vmla.s16 q12, q1, d7[0]\n" + "vmla.s16 q13, q1, d7[1]\n" + "vmla.s16 q14, q1, d7[2]\n" + "vmla.s16 q15, q1, d7[3]\n" + //! k2 + "vld1.16 {d2, d3}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[b_ptr]]!\n" + "vmla.s16 q8, q0, d4[0]\n" + "vmla.s16 q9, q0, d4[1]\n" + "vmla.s16 q10, q0, d4[2]\n" + "vmla.s16 q11, q0, d4[3]\n" + "vmovl.s8 q3, d6\n" + "vmla.s16 q12, q0, d5[0]\n" + "vmla.s16 q13, q0, d5[1]\n" + "vmla.s16 q14, q0, d5[2]\n" + "vmla.s16 q15, q0, d5[3]\n" + //! k3 + "vmla.s16 q8, q1, d6[0]\n" + "vmla.s16 q9, q1, d6[1]\n" + "vmla.s16 q10, q1, d6[2]\n" + "vmla.s16 q11, q1, d6[3]\n" + "vmla.s16 q12, q1, d7[0]\n" + "vmla.s16 q13, q1, d7[1]\n" + "vmla.s16 q14, q1, d7[2]\n" + "vmla.s16 q15, q1, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" + "cmp %[remain_n], #8\n" + "bne 4f\n" + "vstr d16, [r0]\n" + "vstr d18, [r0, #8]\n" + "vstr d20, [r0, #16]\n" + "vstr d22, [r0, #24]\n" + "vstr d24, [r0, #32]\n" + "vstr d26, [r0, #40]\n" + "vstr d28, [r0, #48]\n" + "vstr d30, [r0, #56]\n" + + "vstr d17, [r1]\n" + "vstr d19, [r1, #8]\n" + "vstr d21, [r1, #16]\n" + "vstr d23, [r1, #24]\n" + "vstr d25, [r1, #32]\n" + "vstr d27, [r1, #40]\n" + "vstr d29, [r1, #48]\n" + "vstr d31, [r1, #56]\n" + + "b 101f\n" + + "4:\n " STORE_C + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); +#undef STORE_C +#undef STORE_LINE +} + +/** + * Overview of register layout: + * + * A 8x8x8 cell of Lhs is stored in 16bit in d0, d2 + * A 8x8x8 cell of Rhs is stored in 8bit in q2, q3 + * A 8x8 block of accumulators is stored in 16bit in q8-11 + * + * +--------+ + * | q4[0-8]| + * Rhs +--------+ + * Lhs | | + * + * +--------+ - - - - +--------- + * |d0[0]| | q8 [0-8]| + * |d0[1]| | q9 [0-8]| + * |d0[2]| | q10[0-8]| + * |d0[3]| | q11[0-8]| + * +--------+ - - - - +--------- + * + * Accumulator + */ +static void kern_4x8(const int16_t* packA, const int8_t* packB, int K, + int16_t* output, int LDC, bool is_first_k, int remain_n) { + K /= 4; + const int16_t* a_ptr = packA; + const int8_t* b_ptr = packB; + + LDC = LDC * sizeof(int16_t); + int x0 = 0; + +// clang-format off +#define STORE_LINE(reg_index1) \ + "cmp %[x0], #0 \n" \ + "beq 101f\n" \ + "vst1.16 {d" reg_index1 "}, [r0]!\n" \ + "subs %[x0], %[x0], #1\n" + +#define STORE_C \ + "mov %[x0], %[remain_n]\n" \ + STORE_LINE("16") \ + STORE_LINE("18") \ + STORE_LINE("20") \ + STORE_LINE("22") \ + STORE_LINE("24") \ + STORE_LINE("26") \ + STORE_LINE("28") \ + STORE_LINE("30") \ + "101:\n" + + // clang-format on + + register int16_t* outptr asm("r0") = output; + asm volatile( + //! load accumulator C + "add r1, r0, %[LDC]\n" + "cmp %[is_first_k], #1\n" + "beq 1f\n" + + "b 2f\n" + + "1:\n" + "veor.s32 q8 , q8 , q8 \n" + "veor.s32 q9 , q9 , q9 \n" + "veor.s32 q10, q10, q10\n" + "veor.s32 q11, q11, q11\n" + "veor.s32 q12, q12, q12\n" + "veor.s32 q13, q13, q13\n" + "veor.s32 q14, q14, q14\n" + "veor.s32 q15, q15, q15\n" + + "2: \n" + "vld1.8 {d4}, [%[b_ptr]]!\n" + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vmovl.s8 q2, d4\n" + "vld1.16 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[b_ptr]]!\n" + //! k0 + "vmla.s16 d16, d0, d4[0]\n" + "vmla.s16 d18, d0, d4[1]\n" + "vmla.s16 d20, d0, d4[2]\n" + "vmla.s16 d22, d0, d4[3]\n" + "vmovl.s8 q3, d6\n" + "vmla.s16 d24, d0, d5[0]\n" + "vmla.s16 d26, d0, d5[1]\n" + "vmla.s16 d28, d0, d5[2]\n" + "vmla.s16 d30, d0, d5[3]\n" + //! k1 + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vld1.8 {d4}, [%[b_ptr]]!\n" + "vmla.s16 d16, d2, d6[0]\n" + "vmla.s16 d18, d2, d6[1]\n" + "vmla.s16 d20, d2, d6[2]\n" + "vmla.s16 d22, d2, d6[3]\n" + "vmovl.s8 q2, d4\n" + "vmla.s16 d24, d2, d7[0]\n" + "vmla.s16 d26, d2, d7[1]\n" + "vmla.s16 d28, d2, d7[2]\n" + "vmla.s16 d30, d2, d7[3]\n" + //! k2 + "vld1.16 {d2}, [%[a_ptr]]!\n" + "vld1.8 {d6}, [%[b_ptr]]!\n" + "vmla.s16 d16, d0, d4[0]\n" + "vmla.s16 d18, d0, d4[1]\n" + "vmla.s16 d20, d0, d4[2]\n" + "vmla.s16 d22, d0, d4[3]\n" + "vmovl.s8 q3, d6\n" + "vmla.s16 d24, d0, d5[0]\n" + "vmla.s16 d26, d0, d5[1]\n" + "vmla.s16 d28, d0, d5[2]\n" + "vmla.s16 d30, d0, d5[3]\n" + //! k3 + "vmla.s16 d16, d2, d6[0]\n" + "vmla.s16 d18, d2, d6[1]\n" + "vmla.s16 d20, d2, d6[2]\n" + "vmla.s16 d22, d2, d6[3]\n" + "vmla.s16 d24, d2, d7[0]\n" + "vmla.s16 d26, d2, d7[1]\n" + "vmla.s16 d28, d2, d7[2]\n" + "vmla.s16 d30, d2, d7[3]\n" + + "subs %[K], %[K], #1\n" + "bne 2b\n" + + "3:\n" + "cmp %[remain_n], #8\n" + "bne 4f\n" + "vstr d16, [r0]\n" + "vstr d18, [r0, #8]\n" + "vstr d20, [r0, #16]\n" + "vstr d22, [r0, #24]\n" + "vstr d24, [r0, #32]\n" + "vstr d26, [r0, #40]\n" + "vstr d28, [r0, #48]\n" + "vstr d30, [r0, #56]\n" + "b 101f\n" + + "4:\n " STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), + [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", + "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); +#undef STORE_C +#undef STORE_LINE +} + +static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr, + const dt_int8* inptr, int ldin, + int m0, int mmax, int k0, int kmax) { + megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + constexpr int pack_m = 8; + constexpr int pack_k = 4; + constexpr int pack_size = 4; + const int m_size = mmax - m0; + const int m_end = m_size / pack_m * pack_m + m0; + const int remain_m = mmax - m_end; + + for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) { + const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0; + const int8_t* inptr1 = inptr0 + ldin; + prefetch_2x(inptr0); + prefetch_2x(inptr1); + + for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { + interleave_4x4_8x4_s8_s16(inptr0, inptr1, outptr); + inptr0 += pack_size * pack_size; + inptr1 += pack_size * pack_size; + outptr += pack_m * pack_k; + } + } + if (remain_m > 0) { + const int8_t* inptr0 = inptr + m_end / pack_size * ldin + k0; + const int k_size = kmax - k0; + memcpy_s8_s16(inptr0, outptr, k_size * pack_size); + } +} + +static void gemm_s8x8x16_mk4_8x8_pack_B_n(dt_int8* out, const dt_int8* in, + int ldin, int n0, int nmax, int k0, + int kmax) { + megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); + int8_t tmpbuff[32] = {0}; + + constexpr int pack_n = 8; + constexpr int pack_size = 4; + const int ksize = kmax - k0; + const int nsize = nmax - n0; + const int n_end = nsize / pack_n * pack_n + n0; + const int remain_n = nsize % pack_n; + int output_stride = ksize * pack_n; + int8_t* outptr_base = out; + + for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { + const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; + prefetch_3x(inptr); + + auto outptr = outptr_base; + for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { + transpos_8x4_int8(inptr, outptr); + inptr += pack_n * pack_size; + outptr += output_stride; + } + if (remain_n > 0) { + memcpy(tmpbuff, inptr, sizeof(int8_t) * remain_n * pack_size); + transpos_8x4_int8(tmpbuff, outptr); + outptr += output_stride; + } + outptr_base += pack_n * pack_size; + } +} + +} // namespace matmul_mk4_8x8x4 +} // namespace armv7 +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp index 93494ef4..0a1bac5f 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp @@ -6,14 +6,16 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "src/armv7/matrix_mul/int8x8x16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" #include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" #include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" +#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h" +#include "src/armv7/matrix_mul/int8x8x16/strategy.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -108,7 +110,7 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, } } -// ===========================gemm_s8x8x16_4x4================================== +// ===========================gemm_s8x8x16_4x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, @@ -179,4 +181,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, } } +// ===========================gemm_s8x8x16_mk4_8x8================================== +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8); + +void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin, + int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, + kmax); +} + +void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, + int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, + kmax); +} + +void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, + size_t M, size_t N, size_t K, dt_int16* C, + size_t LDC, bool is_first_k, const dt_int16*, + dt_int16*) const { + megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && + C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); + megdnn_assert(is_first_k == true, "only impl is_first_k"); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); + + constexpr size_t pack_size = 4; + constexpr size_t pack_m = 8; + constexpr size_t pack_n = 8; + const size_t remain_n = N % pack_n; + const size_t remain_m = M % pack_m; + + size_t m_idx = 0; + for (; m_idx + pack_m <= M; m_idx += pack_m) { + int16_t* output = C + (m_idx / pack_size * LDC); + + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * K; + } + packA += pack_m * K; + } + if (remain_m > 0) { + int16_t* output = C + (m_idx / pack_size * LDC); + size_t n_idx = 0; + const int8_t* cur_packB = packB; + for (; n_idx + pack_n <= N; n_idx += pack_n) { + matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, pack_n); + output += pack_n * pack_size; + cur_packB += pack_n * K; + } + if (remain_n > 0) { + matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, + is_first_k, remain_n); + output += remain_n * pack_size; + cur_packB += pack_n * K; + } + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h index 98d24bcd..7307b0c1 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "src/fallback/matrix_mul/gemm_common.h" @@ -21,6 +22,10 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true, MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, gemm_s8x8x16_4x8); +MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8, + 8, 4, false, false, + gemm_s8x8x16_mk4_8x8); + } // namespace matmul } // namespace armv7 } // namespace megdnn diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index 3f3a037b..904bba8b 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -6,10 +6,11 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "src/armv7/matrix_mul/opr_impl.h" #include "src/armv7/matrix_mul/algos.h" +#include "src/armv7/matrix_mul/opr_impl.h" #include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" @@ -21,7 +22,7 @@ using namespace armv7; class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF32 f32; AlgoF32MK4Pack4x12 f32_mk4_pack_4x12; - AlgoF32MK4_4x8 f32_mk4_4x8; + AlgoF32MK4_4x8 f32_mk4_4x8; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC AlgoF16K4x16x1 f16_k4x16x1; AlgoF16MK8_4x8 f16_mk8_4x8; @@ -38,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoQuint8K4x8x8 quint8_k4x8x8; AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; + AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4; AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; @@ -62,8 +64,10 @@ public: all_algos.emplace_back(&int8x8x32_k4x2x16); all_algos.emplace_back(&int8x8x32_k4x8x8); all_algos.emplace_back(&quint8_k4x8x8); + all_algos.emplace_back(&int8x8x16_mk4_8x8x4); all_algos.emplace_back(&int8x8x16_k4x2x16); all_algos.emplace_back(&int8x8x16_k4x8x8); + all_algos.emplace_back(&int16x16x32_k12x4x1); all_algos.emplace_back(&int16x16x32_mk8_4x8); } diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index d502b63c..7e573d07 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "src/arm_common/matrix_mul/opr_impl.h" @@ -19,26 +20,28 @@ public: using arm_common::MatrixMulImpl::MatrixMulImpl; SmallVector algo_pack() override; + private: - class AlgoF32; // Armv7 F32 - class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack - class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack - class AlgoF32Gemv; // Armv7 F32 Gemv - class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 - class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 + class AlgoF32; // Armv7 F32 + class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack + class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack + class AlgoF32Gemv; // Armv7 F32 Gemv + class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 + class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 - class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 - class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 - class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 - class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 - class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 + class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 + class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 + class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 + class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8 + class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 + class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 #endif #if __ARM_FEATURE_DOTPROD - class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 - class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 + class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 + class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 // DotProduct #endif diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index af9fcb60..6d8c2384 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -10,9 +10,9 @@ * implied. */ -#include "src/fallback/conv_bias/conv1x1/algos.h" #include "src/common/opr_delegate.h" #include "src/fallback/conv_bias/common.h" +#include "src/fallback/conv_bias/conv1x1/algos.h" #include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h" #include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h" #include "src/fallback/conv_bias/opr_impl.h" @@ -194,10 +194,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, PW = param.filter_meta.padding[1]; size_t SH = param.filter_meta.stride[0], SW = param.filter_meta.stride[1]; - if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) return false; - if (param.src_type.enumv() != param.filter_type.enumv()) { return false; } @@ -216,6 +214,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, //! is identity otherwise return false mean that 8x8x32 and 8x8x16 //! not support PostProcess if (param.dst_type.enumv() == DTypeEnum::Int16 || + param.dst_type.enumv() == DTypeEnum::QuantizedS16 || param.dst_type.enumv() == DTypeEnum::Int32 || param.dst_type.enumv() == DTypeEnum::QuantizedS32) { if (param.bias_mode != megdnn::BiasMode::NO_BIAS || diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index ad3e0f9b..f4fbf529 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include @@ -226,10 +227,10 @@ public: PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash); } else if (format == param::ConvBias::Format::NCHW44) { - #if MEGDNN_AARCH64 || MEGDNN_ARMV7 auto matmul_block = matmul_algo->get_inner_block_size(); - //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse + //! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 + //! im2col+pack fuse if ((matmul_block.m == 8 || matmul_block.m == 4) && matmul_block.n == 12 && matmul_block.k == 1 && param.filter_meta.spatial[0] == 3 && @@ -297,9 +298,21 @@ public: break; case StrategyType::INT8x8x16: - cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, - dt_int16, dt_int16, PostprocessMode::NO_PROCESS, - "DefaultStrategyType::INT8x8x16"_hash); + if (format == param::ConvBias::Format::NCHW) { + cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, + dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + "DefaultStrategyType::INT8x8x16"_hash); + } else if (format == param::ConvBias::Format::NCHW44) { + cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, + dt_int16, dt_int16, PostprocessMode::NO_PROCESS, + "DefaultStrategyType::INT8x8x16"_hash); + } else { + megdnn_throw( + ssprintf("Current only support layout " + "NCHW44/NCHW for im2col " + "algo, but got %d\n", + uint32_t(format))); + } break; #if MEGDNN_AARCH64 || MEGDNN_ARMV7 case StrategyType::QUINT8x8x32: @@ -421,10 +434,11 @@ public: dt_int32, dt_int8, PostprocessMode::QUANTIZED, "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); } else { - megdnn_throw(ssprintf("Current only support layout " - "NCHW44/NCHW/NCHW_DOT for im2col " - "algo, but got %d\n", - uint32_t(format))); + megdnn_throw( + ssprintf("Current only support layout " + "NCHW44/NCHW/NCHW_DOT for im2col " + "algo, but got %d\n", + uint32_t(format))); } break; } diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index 06d608ff..3184f2c6 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -6,11 +6,12 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once -#include "src/naive/matrix_mul/opr_impl.h" #include "src/common/utils.h" +#include "src/naive/matrix_mul/opr_impl.h" namespace megdnn { namespace fallback { @@ -66,7 +67,8 @@ public: }; typedef void (*kern_t)(const KernParam&); - typedef void (*kern_naked_t)(const KernParam& , const void* a_panel, const void *b_panel); + typedef void (*kern_naked_t)(const KernParam&, const void* a_panel, + const void* b_panel); class AlgoBase : public Algorithm { protected: virtual ~AlgoBase() = default; @@ -83,18 +85,19 @@ public: bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const { return param.A_type.enumv() == param.B_type.enumv() && - param.A_type.enumv() == DTypeEnum::Int8 && - param.C_type.enumv() == DTypeEnum::Int16 && - param.format == param::MatrixMul::Format::DEFAULT && - param.compute_mode == Param::ComputeMode::DEFAULT; + (param.A_type.enumv() == DTypeEnum::Int8 || + param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (param.C_type.enumv() == DTypeEnum::Int16 || + param.C_type.enumv() == DTypeEnum::QuantizedS16); } + public: - enum class AlgoSet:uint32_t { + enum class AlgoSet : uint32_t { ALGO_TYPE_GEMM = 0, ALGO_TYPE_GEMV = 1, }; - enum class PackMode:uint32_t { + enum class PackMode : uint32_t { DEFAULT = 0, NO_PACK = 1, ONLY_PACKA = 2, diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index eb91933b..2e52a54f 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -489,25 +489,26 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, const char* im2col_name, Handle* handle, - size_t kernel, size_t pack_size = 1) { - auto&& args = get_winograd_benchmark_args(kernel, pack_size); + size_t kernel, DType src_type, + DType dst_type) { + auto&& args = get_winograd_benchmark_args(kernel, 4); using namespace conv_bias; constexpr size_t RUN = 10; Benchmarker benchmark(handle); benchmark.set_display(false); benchmark.set_times(RUN); - benchmark.set_dtype(0, dtype::Int8()); - benchmark.set_dtype(1, dtype::Int8()); - benchmark.set_dtype(2, dtype::Int32()); - benchmark.set_dtype(4, dtype::Int32()); + benchmark.set_dtype(0, src_type); + benchmark.set_dtype(1, src_type); + benchmark.set_dtype(2, dst_type); + benchmark.set_dtype(4, dst_type); Benchmarker benchmark_im2col(handle); benchmark_im2col.set_display(false); benchmark_im2col.set_times(RUN); - benchmark_im2col.set_dtype(0, dtype::Int8()); - benchmark_im2col.set_dtype(1, dtype::Int8()); - benchmark_im2col.set_dtype(2, dtype::Int32()); - benchmark_im2col.set_dtype(4, dtype::Int32()); + benchmark_im2col.set_dtype(0, src_type); + benchmark_im2col.set_dtype(1, src_type); + benchmark_im2col.set_dtype(2, dst_type); + benchmark_im2col.set_dtype(4, dst_type); for (auto&& arg : args) { TensorLayout dst_layout; @@ -556,6 +557,7 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, computations / used_im2col, used / used_im2col); } } + #if MEGDNN_AARCH64 TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { printf("=========================compare " @@ -563,7 +565,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n"); BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16", "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16", - handle(), 3, 4); + handle(), 3, dtype::Int8(), dtype::Int32()); +} +#endif + +#if MEGDNN_ARMV7 +TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { + const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; + const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; + printf("compare %s vs %s \n", default_algo, mk4_algo); + BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, + dtype::Int8(), dtype::Int16()); } #endif @@ -1860,15 +1872,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { param.format = param::ConvBias::Format::NCHW44_DOT; //! channel bias - args.emplace_back(param, TensorShape{1, ic/4, h, w, 4}, - TensorShape{oc/4, ic/4, kernel, kernel, 4, 4}, - TensorShape{1, oc/4, 1, 1, 4}); + args.emplace_back(param, TensorShape{1, ic / 4, h, w, 4}, + TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, + TensorShape{1, oc / 4, 1, 1, 4}); }; for (size_t stride : {1, 2}) for (size_t kernel : {2, 3, 5, 7}) - for(size_t oc : {64}) + for (size_t oc : {64}) for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { - run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode); + run(oc, oc, 56, 56, kernel, kernel / 2, stride, + nonline_mode); } constexpr size_t RUN = 50; @@ -1880,7 +1893,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { benchmark0.set_display(false); benchmark0.set_times(RUN); benchmark0.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker("ARMDOTS8DIRECT_NCHW44")); + conv_bias::ConvBiasAlgoChecker( + "ARMDOTS8DIRECT_NCHW44")); Benchmarker benchmark1(handle()); benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) @@ -2002,15 +2016,20 @@ std::vector get_conv_bias_1x1_benchmark_args( void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, DType stype, DType matmul_dtype, DType bias_type, - DType conv_dtype) { + DType conv_dtype, bool is_mk4 = false) { using namespace conv_bias; + int pack_size = is_mk4 ? 4 : 1; std::vector conv_bias_1x1_args = - get_conv_bias_1x1_benchmark_args(); + get_conv_bias_1x1_benchmark_args(pack_size); + constexpr size_t RUNS = 50; param::MatrixMul param; param.transposeA = false; param.transposeB = false; + if (is_mk4) { + param.format = MatrixMul::Param::Format::MK4; + } Benchmarker benchmark_matmul(handle); benchmark_matmul.set_before_exec_callback( AlgoChecker(matmul_algo_name)); @@ -2038,8 +2057,8 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, size_t OH = arg.src[2]; size_t OW = arg.src[3]; size_t OC = arg.filter[0]; - size_t M = OC; - size_t K = IC; + size_t M = OC * pack_size; + size_t K = IC * pack_size; size_t N = OH * OW; float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; @@ -2047,6 +2066,10 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, TensorShape A, B; A = TensorShape{M, K}; B = TensorShape{K, N}; + if (is_mk4) { + A = TensorShape{M / 4, K / 4, 4, 4}; + B = TensorShape{K / 4, N, 4}; + } auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( {arg.src, arg.filter, arg.bias, {}, {}}) / @@ -2133,6 +2156,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) { dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); + benchmark_conv1x1("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), dtype::Int8{}, + dtype::Int16{}, dtype::Int16{}, dtype::Int16{}, true); #endif } @@ -2145,13 +2170,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) { conv_param.pad_h = 0; conv_param.pad_w = 0; conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; - auto run = [&](size_t M, size_t K){ + auto run = [&](size_t M, size_t K) { args.emplace_back(conv_param, TensorShape{1, K, 1, 1}, TensorShape{M, K, 1, 1}, TensorShape{}); }; for (size_t M : {4, 64, 1024, 4096}) for (size_t K : {128, 256, 1024, 4096}) - run(M, K); + run(M, K); constexpr size_t RUNS = 50; param::MatrixMul param; diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 1d1476df..c410c241 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -850,7 +850,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { param::ConvBias::Format::NCHW44); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { using namespace conv_bias; std::vector args = get_nchw44_conv_bias_args({3}, 1); Checker> checker( @@ -1131,7 +1132,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) { 1e-3f); } -TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { +TEST_F(ARM_COMMON_MULTI_THREADS, + CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { using namespace conv_bias; Checker> checker( @@ -2089,6 +2091,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; + std::vector args_nchw44 = + get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, + false, false, false, false, true); + std::vector args_nchw44_1x1s2 = + get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, + false, false, true); #define cb(name) \ checker_conv_bias( \ get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ @@ -2098,6 +2106,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ dtype::Int16{}, dtype::Int16{}, name); +#define cb_nchw44(name) \ + checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ + dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); \ + checker_conv_bias(args_nchw44_1x1s2, handle(), &rng, epsilon, \ + dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \ + dtype::Int16{}, name); + #if MEGDNN_AARCH64 cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); @@ -2106,8 +2121,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); + cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); #endif + #undef cb +#undef cb_nchw44 } #endif @@ -2516,19 +2534,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { UniformIntRNG rng{-50, 50}; float epsilon = 0.001; std::vector args = get_conv_bias_1x1_args(true, true); + std::vector args_nchw44 = get_nchw44_conv_bias_args( + {1}, 1, true, true, true, false, false, false, false, true); #define cb(name) \ checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); +#define cb_nchw44(name) \ + checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ + dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); + #if MEGDNN_AARCH64 cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); #elif MEGDNN_ARMV7 cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); + cb_nchw44("CONV1x1:ARMV7_INT8X8X16_MK4_K8X8X4:48"); #endif cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); + #undef cb +#undef cb_nchw44 std::vector gemv_args; for (auto&& arg : args) diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp index 0d17be3a..67fa6749 100644 --- a/dnn/test/armv7/matrix_mul.cpp +++ b/dnn/test/armv7/matrix_mul.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "test/armv7/fixture.h" #include "test/common/benchmarker.h" @@ -51,9 +52,15 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) { handle(), "ARMV7_INT8X8X16_K4X8X8"); } +TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) { + matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, + handle(), "ARMV7_INT8X8X16_MK4_K8X8X4", + param::MatrixMul::Format::MK4, 1); +} + TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { - matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, - handle(),"ARMV7_INT16X16X32_K12X4X1"); + matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, + handle(), "ARMV7_INT16X16X32_K12X4X1"); } TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { @@ -83,7 +90,8 @@ TEST_F(ARMV7, MATRIX_MUL_SDOT) { TEST_F(ARMV7, MATRIX_MUL_UDOT) { matrix_mul::check_matrix_mul( - dtype::Quantized8Asymm(4.0f, static_cast(10)), dtype::Quantized8Asymm(3.0f, static_cast(54)), + dtype::Quantized8Asymm(4.0f, static_cast(10)), + dtype::Quantized8Asymm(3.0f, static_cast(54)), dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4"); } @@ -103,7 +111,9 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) { #if MEGDNN_WITH_BENCHMARK namespace { -void run_8x8x16_benchmark(const char* algo, Handle* handle) { +void run_8x8x16_benchmark( + const char* algo, Handle* handle, + MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) { constexpr size_t RUNS = 50; param::MatrixMul param; Benchmarker benchmarker_int(handle); @@ -116,21 +126,31 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { .set_dtype(2, dtype::Int16{}) .set_param(param) .set_display(false); + param::MatrixMul target_param; + target_param.format = format; benchmarker_int_kern_4x2x16.set_before_exec_callback( AlgoChecker(algo)); benchmarker_int_kern_4x2x16.set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int16{}) - .set_param(param) + .set_param(target_param) .set_display(false); Benchmarker benchmarker_float(handle); benchmarker_float.set_display(false).set_times(RUNS); auto run = [&](size_t M, size_t N, size_t K) { auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; - auto int_kern_used = - benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS; + auto int_kern_used = 1e10; + if (format == MatrixMul::Param::Format::MK4) { + int_kern_used = benchmarker_int_kern_4x2x16.exec( + {{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / + RUNS; + } else { + int_kern_used = + benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / + RUNS; + } auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; float computations = 2.f * M * K * N * 1e-6; printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f " @@ -145,6 +165,7 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { }; run(256, 12 * 24, 256); + run(256, 256, 256); //////////////////////// gemv ////////////////////////// for (size_t M : {8, 64, 112, 256}) { @@ -185,7 +206,8 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) { "int: %f ms %f Gflops %s: \n" "speedup(%s/arm_common, %s/float): %f\n", M, K, N, float_used, computations / float_used, int_used, - computations / int_used,algo,algo,algo,float_used / int_used); + computations / int_used, algo, algo, algo, + float_used / int_used); }; run(256, 12 * 24, 256); @@ -231,7 +253,8 @@ void run_8x8x32_benchmark(const char* algo, Handle* handle) { "int: %f ms %f Gflops %s: \n" "speedup(%s/arm_common, %s/float): %f\n", M, K, N, float_used, computations / float_used, int_used, - computations / int_used,algo,algo,algo,float_used / int_used); + computations / int_used, algo, algo, algo, + float_used / int_used); }; run(256, 12 * 24, 256); @@ -252,9 +275,11 @@ void run_8x8x32_quint_benchmark(Handle* handle) { benchmarker_quint8_dot.set_before_exec_callback( AlgoChecker("AARCH32_QUINT8_K4X8X4")); benchmarker_quint8_dot.set_times(RUNS) - .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast(20))) - .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast(30))) - .set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) + .set_dtype(0, + dtype::Quantized8Asymm(2.3f, static_cast(20))) + .set_dtype(1, + dtype::Quantized8Asymm(3.1f, static_cast(30))) + .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) .set_param(param) .set_display(false); @@ -262,14 +287,17 @@ void run_8x8x32_quint_benchmark(Handle* handle) { benchmarker_quint8.set_before_exec_callback( AlgoChecker("ARMV7_QUINT8_K4X8X8")); benchmarker_quint8.set_times(RUNS) - .set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast(20))) - .set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast(30))) - .set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) + .set_dtype(0, + dtype::Quantized8Asymm(2.3f, static_cast(20))) + .set_dtype(1, + dtype::Quantized8Asymm(3.1f, static_cast(30))) + .set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) .set_param(param) .set_display(false); auto run = [&](size_t M, size_t N, size_t K) { - auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; + auto dot_used = + benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS; float computations = 2.f * M * K * N * 1e-6; printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n" @@ -351,11 +379,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) { run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle()); } - TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) { run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle()); } +TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) { + run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), + MatrixMul::Param::Format::MK4); +} + TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) { run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle()); } diff --git a/dnn/test/common/matrix_mul.cpp b/dnn/test/common/matrix_mul.cpp index 74cc7c83..4b42b551 100644 --- a/dnn/test/common/matrix_mul.cpp +++ b/dnn/test/common/matrix_mul.cpp @@ -6,12 +6,13 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "test/common/matrix_mul.h" #include "src/common/utils.h" #include "test/common/benchmarker.h" #include "test/common/checker.h" +#include "test/common/matrix_mul.h" using namespace megdnn; using namespace test; @@ -39,9 +40,9 @@ std::vector matrix_mul::get_matmul_args_no_mask() { std::vector matrix_mul::get_matmul_mk_packed_args( size_t nbase) { std::vector args; - for (size_t m : {1, 2, 3, 4, 5}) + for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11}) for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24}) - for (size_t k : {1, 2, 3, 4, 5, 9, 10}) + for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11}) args.emplace_back(m, n * nbase, k, 0); return args; }