From 7ca3d579dbef8c809773a70ca1cedeb0011e4fb9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 16 Jul 2020 13:27:13 +0800 Subject: [PATCH] feat(dnn): make mk4 and mk8 matmul for winograd both on aarch64 and armv7 supports n=1 GitOrigin-RevId: 0f64b9f70f5f010696cec06ba37113c0d7dd178f --- dnn/src/aarch64/matrix_mul/algos.cpp | 9 +- .../aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp | 83 ++++++++++++++++- .../aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp | 85 +++++++++++++---- .../aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp | 89 +++++++++++++++++- .../arm_common/conv_bias/f16/strategy_2x3_8x8.cpp | 2 +- dnn/src/armv7/matrix_mul/algos.cpp | 9 +- dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp | 76 +++++++++++++++- dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp | 67 +++++++++++++- .../matrix_mul/int16x16x32/strategy_mk8_4x8.cpp | 101 ++++++++++++++++++++- dnn/src/fallback/conv_bias/winograd/winograd.h | 3 - dnn/test/aarch64/matrix_mul.cpp | 9 +- dnn/test/armv7/matrix_mul.cpp | 6 +- 12 files changed, 480 insertions(+), 59 deletions(-) diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index a97acc2f..7d88e3e9 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -214,8 +214,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable( kern_size_param.B_type == dtype::Float32() && kern_size_param.A_type == dtype::Float32() && kern_size_param.format == param::MatrixMul::Format::MK4 && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( @@ -330,8 +329,7 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable( kern_size_param.B_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float16() && kern_size_param.format == param::MatrixMul::Format::MK8 && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( @@ -918,8 +916,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable( kern_size_param.B_type == dtype::Int16() && kern_size_param.A_type == dtype::Int16() && kern_size_param.format == param::MatrixMul::Format::MK8 && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp index fd3d7e30..9de2d3e9 100644 --- a/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp @@ -21,6 +21,76 @@ using namespace aarch64::matmul; namespace { +void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { + LDB *= sizeof(dt_float16); + asm volatile( + ".arch armv8.2-a+fp16\n" + + "subs %w[K], %w[K], #8\n" + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + "eor v28.16b, v28.16b, v28.16b\n" + "eor v29.16b, v29.16b, v29.16b\n" + "eor v30.16b, v30.16b, v30.16b\n" + "eor v31.16b, v31.16b, v31.16b\n" + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + + "fmla v24.8h, v16.8h, v0.h[0]\n" + "fmla v25.8h, v17.8h, v0.h[1]\n" + "fmla v26.8h, v18.8h, v0.h[2]\n" + "fmla v27.8h, v19.8h, v0.h[3]\n" + + "beq 2f\n" + + "1:\n" + + "ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [%[a_ptr]], 64\n" + "fmla v28.8h, v20.8h, v0.h[4]\n" + "fmla v29.8h, v21.8h, v0.h[5]\n" + "fmla v30.8h, v22.8h, v0.h[6]\n" + "fmla v31.8h, v23.8h, v0.h[7]\n" + + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + + "ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [%[a_ptr]], 64\n" + "fmla v24.8h, v16.8h, v0.h[0]\n" + "fmla v25.8h, v17.8h, v0.h[1]\n" + "fmla v26.8h, v18.8h, v0.h[2]\n" + "fmla v27.8h, v19.8h, v0.h[3]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + + "fmla v28.8h, v20.8h, v0.h[4]\n" + "fmla v29.8h, v21.8h, v0.h[5]\n" + "fmla v30.8h, v22.8h, v0.h[6]\n" + "fmla v31.8h, v23.8h, v0.h[7]\n" + + "fadd v24.8h, v24.8h, v25.8h\n" + "fadd v26.8h, v26.8h, v27.8h\n" + "fadd v28.8h, v28.8h, v29.8h\n" + "fadd v30.8h, v30.8h, v31.8h\n" + "fadd v24.8h, v24.8h, v26.8h\n" + "fadd v28.8h, v28.8h, v30.8h\n" + "fadd v24.8h, v24.8h, v28.8h\n" + + "st1 {v24.4s}, [%[output]], 16\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", + "memory"); +} + // Overview of register layout: // // A 8x1 cell of Rhs is stored in 16bit in v0-v3 @@ -416,7 +486,7 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, constexpr static size_t NB = 8; constexpr static size_t CALCBLK = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) for (size_t m = 0; m < M; m += MB) { @@ -428,8 +498,17 @@ void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, cur_B += KB * NB; output += MB * NB; } - if (n < N) { + if (N - n >= 4) { kern_8x4(A, cur_B, LDB, K, output); + cur_B += KB * CALCBLK; + output += MB * CALCBLK; + n += 4; + } + while (n < N) { + kern_8x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; } A += LDA; } diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp index bfb51f6e..fa88550e 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp @@ -20,6 +20,54 @@ using namespace aarch64::matmul; namespace { +void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, + float* output) { + LDB *= sizeof(float); + asm volatile( + "subs %w[K], %w[K], #4\n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[a_ptr]], 64\n" + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + "prfm pstl1keep, [%[b_ptr]]\n" + + "fmla v16.4s, v4.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + + "beq 2f\n" + + "1:\n" + "ld1 {v4.4s, v5.4s}, [%[a_ptr]], 32\n" + "fmla v18.4s, v6.4s, v0.s[2]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + "prfm pstl1keep, [%[b_ptr]]\n" + "ld1 {v6.4s, v7.4s}, [%[a_ptr]], 32\n" + "fmla v16.4s, v4.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + + "subs %w[K], %w[K], #4\n" + "bne 1b\n" + + "2:\n" + + "fmla v18.4s, v6.4s, v0.s[2]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "fadd v16.4s, v16.4s, v18.4s\n" + "fadd v17.4s, v17.4s, v19.4s\n" + "fadd v16.4s, v16.4s, v17.4s\n" + + "st1 {v16.4s}, [%[output]], 16\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", + "memory"); +} + // Overview of register layout: // // A 4x4 block of A is stored in register v4-v7 @@ -117,7 +165,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", + "v18", "v19", "cc", "memory"); } // Overview of register layout: @@ -535,7 +584,7 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, constexpr static size_t NB = 16; constexpr static size_t CALCBLK = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/4, k/4, 4, 4) * (k/4, n, 4) = (m/4, n, 4) for (size_t m = 0; m < M; m += MB) { @@ -547,21 +596,23 @@ void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, cur_B += KB * NB; output += MB * NB; } - switch (N - n) { - case 4: - kern_4x4(A, cur_B, LDB, K, output); - break; - case 8: - kern_4x8(A, cur_B, LDB, K, output); - break; - case 12: - kern_4x8(A, cur_B, LDB, K, output); - cur_B += KB * CALCBLK * 2; - output += MB * CALCBLK * 2; - kern_4x4(A, cur_B, LDB, K, output); - break; - default: - break; + if (N - n >= 8) { + kern_4x8(A, cur_B, LDB, K, output); + cur_B += KB * CALCBLK * 2; + output += MB * CALCBLK * 2; + n += 8; + } + if (N - n >= 4) { + kern_4x4(A, cur_B, LDB, K, output); + cur_B += KB * CALCBLK; + output += MB * CALCBLK; + n += 4; + } + while (n < N) { + kern_4x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; } A += LDA; } diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp index 7a4d7d90..e358d411 100644 --- a/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp +++ b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp @@ -20,6 +20,82 @@ using namespace aarch64::matmul; namespace { +void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { + //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 + //! here. + LDB *= sizeof(dt_int16); + + asm volatile( + "subs %w[K], %w[K], #8\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + + "smull v16.4s, v24.4h, v0.h[0]\n" + "smull2 v17.4s, v24.8h, v0.h[0]\n" + "smull v18.4s, v25.4h, v0.h[1]\n" + "smull2 v19.4s, v25.8h, v0.h[1]\n" + "smull v20.4s, v26.4h, v0.h[2]\n" + "smull2 v21.4s, v26.8h, v0.h[2]\n" + "smull v22.4s, v27.4h, v0.h[3]\n" + "smull2 v23.4s, v27.8h, v0.h[3]\n" + + "beq 2f\n" + + "1:\n" + "ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [%[a_ptr]], 64\n" + "smlal v16.4s, v28.4h, v0.h[4]\n" + "smlal2 v17.4s, v28.8h, v0.h[4]\n" + "smlal v18.4s, v29.4h, v0.h[5]\n" + "smlal2 v19.4s, v29.8h, v0.h[5]\n" + "smlal v20.4s, v30.4h, v0.h[6]\n" + "smlal2 v21.4s, v30.8h, v0.h[6]\n" + "smlal v22.4s, v31.4h, v0.h[7]\n" + "smlal2 v23.4s, v31.8h, v0.h[7]\n" + + "ld1 {v0.4s}, [%[b_ptr]], %x[LDB]\n" + "ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [%[a_ptr]], 64\n" + + "smlal v16.4s, v24.4h, v0.h[0]\n" + "smlal2 v17.4s, v24.8h, v0.h[0]\n" + "smlal v18.4s, v25.4h, v0.h[1]\n" + "smlal2 v19.4s, v25.8h, v0.h[1]\n" + "smlal v20.4s, v26.4h, v0.h[2]\n" + "smlal2 v21.4s, v26.8h, v0.h[2]\n" + "smlal v22.4s, v27.4h, v0.h[3]\n" + "smlal2 v23.4s, v27.8h, v0.h[3]\n" + + "subs %w[K], %w[K], #8\n" + "bne 1b\n" + + "2:\n" + "smlal v16.4s, v28.4h, v0.h[4]\n" + "smlal2 v17.4s, v28.8h, v0.h[4]\n" + "smlal v18.4s, v29.4h, v0.h[5]\n" + "smlal2 v19.4s, v29.8h, v0.h[5]\n" + "smlal v20.4s, v30.4h, v0.h[6]\n" + "smlal2 v21.4s, v30.8h, v0.h[6]\n" + "smlal v22.4s, v31.4h, v0.h[7]\n" + "smlal2 v23.4s, v31.8h, v0.h[7]\n" + + "add v16.4s, v16.4s, v18.4s\n" + "add v20.4s, v20.4s, v22.4s\n" + "add v17.4s, v17.4s, v19.4s\n" + "add v21.4s, v21.4s, v23.4s\n" + "add v16.4s, v16.4s, v20.4s\n" + "add v17.4s, v17.4s, v21.4s\n" + + "st1 {v16.4s, v17.4s}, [%[output]], 32\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "v0", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); +} + // Overview of register layout: // // A 8x1 cell of Lhs is stored in 16bit in v24-v27 @@ -636,7 +712,7 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, constexpr static size_t NB = 8; constexpr static size_t CALCBLK = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) for (size_t m = 0; m < M; m += MB) { @@ -648,8 +724,17 @@ void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, cur_B += KB * NB; output += MB * NB; } - if (n < N) { + if (N - n >= 4) { kern_8x4(A, cur_B, LDB, K, output); + cur_B += KB * CALCBLK; + output += MB * CALCBLK; + n += 4; + } + while (n < N) { + kern_8x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; } A += LDA; } diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp index e25e8f11..23adba86 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp +++ b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp @@ -390,7 +390,7 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp16_F23_8x8, cb, __fp16, __fp16, + megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16, bmode, nonline_mode, output_transform_buf, bias, output, transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index 8d8c2f93..712753a0 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -875,8 +875,7 @@ bool MatrixMulImpl::AlgoF32MK4_4x8::usable( kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float32() && - kern_size_param.N % 4 == 0 && !kern_size_param.trA && - !kern_size_param.trB; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( @@ -911,8 +910,7 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable( kern_size_param.A_type == dtype::Int16() && kern_size_param.B_type == dtype::Int16() && kern_size_param.C_type == dtype::Int32() && - kern_size_param.N % 4 == 0 && !kern_size_param.trA && - !kern_size_param.trB; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( @@ -969,8 +967,7 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable( kern_size_param.B_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float16() && kern_size_param.format == param::MatrixMul::Format::MK8 && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp index 90903f63..a8cb2b56 100644 --- a/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp @@ -21,6 +21,66 @@ using namespace armv7::matmul; namespace { +void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { + LDB = (LDB - 4) * sizeof(dt_float16); + asm volatile( + "subs %[K], #8\n" + + "vld1.32 {d0}, [%[b_ptr]]!\n" + "vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + + "vmul.f16 q12, q4, d0[0]\n" + "vmul.f16 q13, q5, d0[1]\n" + "vmul.f16 q14, q6, d0[2]\n" + "vmul.f16 q15, q7, d0[3]\n" + + "beq 2f\n" + + "1:\n" + "vmla.f16 q12, q8, d1[0]\n" + "vld1.32 {d0}, [%[b_ptr]]!\n" + "vmla.f16 q13, q9, d1[1]\n" + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + "vmla.f16 q14, q10, d1[2]\n" + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vmla.f16 q15, q11, d1[3]\n" + + "vmla.f16 q12, q4, d0[0]\n" + "vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" + "vmla.f16 q13, q5, d0[1]\n" + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vmla.f16 q14, q6, d0[2]\n" + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + "vmla.f16 q15, q7, d0[3]\n" + + "subs %[K], #8\n" + "bne 1b\n" + + "2:\n" + "vmla.f16 q12, q8, d1[0]\n" + "vmla.f16 q13, q9, d1[1]\n" + "vmla.f16 q14, q10, d1[2]\n" + "vmla.f16 q15, q11, d1[3]\n" + + "vadd.f16 q12, q12, q14\n" + "vadd.f16 q13, q13, q15\n" + "vadd.f16 q12, q12, q13\n" + + "vst1.32 {d24, d25}, [%[output]]!\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", + "d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); +} + // Overview of register layout: // // A 8x1 cell of Rhs is stored in 16bit in v4-v11 @@ -45,7 +105,7 @@ namespace { // | v3[0-7]| |v15[0-7]| // +--------+ +--------+--------+ // Accumulator -void kern_4x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, +void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, dt_float16* output) { //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 //! here. @@ -179,19 +239,25 @@ void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA, constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 4; - constexpr static size_t CALCBLK = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) for (size_t m = 0; m < M; m += MB) { dt_float16* output = C + (m / MB) * LDC; const dt_float16* cur_B = B; - for (size_t n = 0; n < N; n += NB) { - kern_4x8(A, cur_B, LDB, K, output); + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_8x4(A, cur_B, LDB, K, output); cur_B += KB * NB; output += MB * NB; } + while (n < N) { + kern_8x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; + } A += LDA; } } diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp index 096e53c9..de897ae7 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp @@ -20,6 +20,58 @@ using namespace armv7::matmul; namespace { +void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { + LDB = (LDB - 4) * sizeof(float); + asm volatile( + "subs %[K], %[K], #4\n" + + "vld1.32 {d8-d11}, [%[A]]!\n" + "vld1.32 {d12-d15}, [%[A]]!\n" + "veor q8, q8 \n" + "veor q9, q9 \n" + "veor q10, q10 \n" + "veor q11, q11 \n" + + "vld1.32 {d0-d1}, [%[B]]!\n" + + "vmla.f32 q8, q4, d0[0]\n" + "vmla.f32 q9, q5, d0[1]\n" + + "beq 2f\n" + + "1:\n" + + "vld1.32 {d8-d11}, [%[A]]!\n" + "vmla.f32 q10, q6, d1[0]\n" + "vmla.f32 q11, q7, d1[1]\n" + + "add %[B], %[B], %[LDB]\n" + "vld1.32 {d0-d1}, [%[B]]!\n" + "vld1.32 {d12-d15}, [%[A]]!\n" + + "vmla.f32 q8, q4, d0[0]\n" + "vmla.f32 q9, q5, d0[1]\n" + + "subs %[K], %[K], #4\n" + "bne 1b\n" + + "2:\n" + + "vmla.f32 q10, q6, d1[0]\n" + "vmla.f32 q11, q7, d1[1]\n" + "vadd.f32 q8, q8, q10\n" + "vadd.f32 q9, q9, q11\n" + "vadd.f32 q8, q8, q9\n" + + "vst1.32 {d16, d17}, [%[C]]!\n" + + : [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C) + : [ LDB ] "r"(LDB) + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc", + "memory"); +} + // Overview of register layout: // // A 8x4 cell of Rhs is stored in 32bit in q0-q3, load 4 register each time @@ -268,9 +320,9 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, constexpr size_t MB = 4; constexpr size_t KB = 4; constexpr size_t NB = 8; - constexpr size_t CALCBLK = 4; + constexpr size_t NB_HALF = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) for (size_t m = 0; m < M; m += MB) { @@ -282,8 +334,17 @@ void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, cur_B += KB * NB; output += MB * NB; } - if (n < N) { + if (N - n >= 4) { kern_4x4(A, cur_B, LDB, K, output); + cur_B += KB * NB_HALF; + output += MB * NB_HALF; + n += 4; + } + while (n < N) { + kern_4x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; } A += LDA; } diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp index 335c9b67..62243f42 100644 --- a/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp @@ -20,6 +20,91 @@ using namespace armv7::matmul; namespace { +void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { + //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 + //! here. + LDB = (LDB - 4) * sizeof(dt_int16); + + asm volatile( + "subs %[K], #8\n" + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + + "vld1.32 {d0}, [%[b_ptr]]!\n" + "vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" + + "vmull.s16 q12, d8, d0[0]\n" + "vmull.s16 q13, d9, d0[0]\n" + "vmull.s16 q14, d10, d0[1]\n" + "vmull.s16 q15, d11, d0[1]\n" + + "vmlal.s16 q12, d12, d0[2]\n" + "vmlal.s16 q13, d13, d0[2]\n" + "vmlal.s16 q14, d14, d0[3]\n" + "vmlal.s16 q15, d15, d0[3]\n" + + "beq 2f\n" + + "1:\n" + + "vld1.32 {d8, d9, d10, d11}, [%[a_ptr]]!\n" + "vld1.32 {d12, d13, d14, d15}, [%[a_ptr]]!\n" + "vld1.32 {d0}, [%[b_ptr]]!\n" + + "vmlal.s16 q12, d16, d1[0]\n" + "vmlal.s16 q13, d17, d1[0]\n" + "vmlal.s16 q14, d18, d1[1]\n" + "vmlal.s16 q15, d19, d1[1]\n" + + "vmlal.s16 q12, d20, d1[2]\n" + "vmlal.s16 q13, d21, d1[2]\n" + "vmlal.s16 q14, d22, d1[3]\n" + "vmlal.s16 q15, d23, d1[3]\n" + + "vld1.32 {d1}, [%[b_ptr]], %[LDB]\n" + "vld1.32 {d16, d17, d18, d19}, [%[a_ptr]]!\n" + "vld1.32 {d20, d21, d22, d23}, [%[a_ptr]]!\n" + + "vmlal.s16 q12, d8, d0[0]\n" + "vmlal.s16 q13, d9, d0[0]\n" + "vmlal.s16 q14, d10, d0[1]\n" + "vmlal.s16 q15, d11, d0[1]\n" + + "vmlal.s16 q12, d12, d0[2]\n" + "vmlal.s16 q13, d13, d0[2]\n" + "vmlal.s16 q14, d14, d0[3]\n" + "vmlal.s16 q15, d15, d0[3]\n" + + "subs %[K], %[K], #8\n" + "bne 1b\n" + + "2:\n" + "vmlal.s16 q12, d16, d1[0]\n" + "vmlal.s16 q13, d17, d1[0]\n" + "vmlal.s16 q14, d18, d1[1]\n" + "vmlal.s16 q15, d19, d1[1]\n" + + "vmlal.s16 q12, d20, d1[2]\n" + "vmlal.s16 q13, d21, d1[2]\n" + "vmlal.s16 q14, d22, d1[3]\n" + "vmlal.s16 q15, d23, d1[3]\n" + + "vadd.s32 q12, q12, q14\n" + "vadd.s32 q13, q13, q15\n" + + "vst1.32 {d24, d25, d26, d27}, [%[output]]!\n" + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [output] "+r"(output), [LDB] "+r"(LDB) + : + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", + "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", + "d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); +} + // Overview of register layout: // // A 4x8 cell of Rhs is stored in 16bit in q0-q3 @@ -40,7 +125,7 @@ namespace { // | q3[0-7]| |q14[0-3]|v15[0-3]| // +--------+ +--------+--------+ // Accumulator -void kern_4x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, +void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, dt_int32* output) { //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 //! here. @@ -247,19 +332,25 @@ void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 4; - constexpr static size_t CALCBLK = 4; - megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0 && N % CALCBLK == 0); + megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); //! (m/8, k/8, 8, 8) * (k/8, n, 8) = (m/8, n, 8) for (size_t m = 0; m < M; m += MB) { dt_int32* output = C + (m / MB) * LDC; const dt_int16* cur_B = B; - for (size_t n = 0; n < N; n += NB) { - kern_4x8(A, cur_B, LDB, K, output); + size_t n = 0; + for (; n + NB - 1 < N; n += NB) { + kern_8x4(A, cur_B, LDB, K, output); cur_B += KB * NB; output += MB * NB; } + while (n < N) { + kern_8x1(A, cur_B, LDB, K, output); + cur_B += KB; + output += MB; + n++; + } A += LDA; } } diff --git a/dnn/src/fallback/conv_bias/winograd/winograd.h b/dnn/src/fallback/conv_bias/winograd/winograd.h index 868605e9..eb1796ee 100644 --- a/dnn/src/fallback/conv_bias/winograd/winograd.h +++ b/dnn/src/fallback/conv_bias/winograd/winograd.h @@ -427,9 +427,6 @@ public: "The winograd remain oc is not times of OC_BLOCK_SIZE"); if (format == param::MatrixMul::Format::MK4 || format == param::MatrixMul::Format::MK8) { -#if !MEGDNN_X86 - nr_tiles_in_unit = round_up(nr_tiles_in_unit, 4); -#endif megdnn_assert(nr_tiles_in_unit <= unit_tile_size, "nr_tiles_in_unit: %zu TILE_SIZE:%zu", nr_tiles_in_unit, unit_tile_size); diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index d4dfae37..2cb5ec47 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -38,10 +38,9 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { } TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { - //! nbase should be 4 in order to test the last rest 4 in N dim matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), - "AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 4); + "AARCH64_F32_MK4_4x16", param::MatrixMul::Format::MK4, 1); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -52,10 +51,9 @@ TEST_F(AARCH64, MATRIX_MUL_F16_K8X24X1) { } TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { - //! nbase should be 4 in order to test the last rest 4 in N dim matrix_mul::check_matrix_mul( dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), - "AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 4); + "AARCH64_F16_MK8_8X8", param::MatrixMul::Format::MK8, 1); } #endif @@ -116,10 +114,9 @@ TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_K12X8X1) { } TEST_F(AARCH64, MATRIX_MUL_INT16x16x32_MK8) { - //! nbase should be 4 in order to test the last rest 4 in N dim matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, handle(), "AARCH64_INT16X16X32_MK8_8X8", - param::MatrixMul::Format::MK8, 4); + param::MatrixMul::Format::MK8, 1); } //! FIXME: need to add tests of GEMV and QUINT8 diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp index 67fa6749..2e7baa28 100644 --- a/dnn/test/armv7/matrix_mul.cpp +++ b/dnn/test/armv7/matrix_mul.cpp @@ -26,7 +26,7 @@ TEST_F(ARMV7, MATRIX_MUL) { TEST_F(ARMV7, MATRIX_MUL_MK4) { matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), - "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); + "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); } TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) { @@ -66,7 +66,7 @@ TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, handle(), "ARMV7_INT16X16X32_MK8_4X8", - param::MatrixMul::Format::MK8, 4); + param::MatrixMul::Format::MK8, 1); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -78,7 +78,7 @@ TEST_F(ARMV7, MATRIX_MUL_FP16) { TEST_F(ARMV7, MATRIX_MUL_F16_MK8) { matrix_mul::check_matrix_mul( dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, handle(), - "AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 4); + "AARCH32_F16_MK8_4X8", param::MatrixMul::Format::MK8, 1); } #endif