diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index 836dfc61..d2bcc0e1 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -138,8 +138,6 @@ public: } } - //! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw - //! prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd matmul_algos = static_cast(matmul_opr) ->select_algo_type( {AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp index f79b2a5f..3188b23d 100644 --- a/dnn/src/fallback/matrix_mul/algos.cpp +++ b/dnn/src/fallback/matrix_mul/algos.cpp @@ -15,6 +15,7 @@ MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) MIDOUT_DECL(megdnn_fb_matmul_naive) MIDOUT_DECL(megdnn_fb_gi_exec_fp32) MIDOUT_DECL(megdnn_fb_gi_matmul_kern) +MIDOUT_DECL(megdnn_fb_gi_f32_4x12) using namespace megdnn; using namespace fallback; @@ -293,4 +294,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( const KernSizeParam&) const { return gi_f32_mk4_4x8_kern; } + +/* ===================== F32 algo ===================== */ +namespace { +void f32_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN(megdnn_fb_gi_f32_4x12, midout_iv("f32_kern"_hash)) { + auto M = kern_param.M, N = kern_param.N, K = kern_param.K; + auto trA = kern_param.trA, trB = kern_param.trB; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; + auto A_type = kern_param.A_type, B_type = kern_param.B_type, + C_type = kern_param.C_type; + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); + auto Cptr = kern_param.C(); + + matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); + } + MIDOUT_END(); +} + +} // anonymous namespace + +bool 
MatrixMulImpl::AlgoF32Gi4x12::usable(const KernSizeParam& kern_size_param) const { + return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32(); +} + +size_t MatrixMulImpl::AlgoF32Gi4x12::get_workspace( + const KernSizeParam& kern_size_param) const { + MIDOUT_BEGIN( + megdnn_fb_gi_f32_4x12, midout_iv("AlgoF32Gi4x12::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; + auto trA = kern_size_param.trA, trB = kern_size_param.trB; + auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, + C_type = kern_size_param.C_type; + matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .get_workspace_size(); + } + MIDOUT_END(); + return 0; +} + +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gi4x12::get_kern( + const KernSizeParam&) const { + return f32_kern; +} + +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32Gi4x12, megdnn_fb_gi_f32_4x12, "AlgoF32Gi4x12Impl"_hash, + matmul::fallback::gi_sgemm_4x12, float, float, AlgoDataType::FLOAT32, DEFAULT); + // vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h index 24635250..bef12025 100644 --- a/dnn/src/fallback/matrix_mul/algos.h +++ b/dnn/src/fallback/matrix_mul/algos.h @@ -97,6 +97,17 @@ public: MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) }; +class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase { +public: + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "FB_GI_F32_4x12"; } + bool usable(const KernSizeParam&) const override; + size_t get_workspace(const KernSizeParam&) const override; + kern_t get_kern(const 
KernSizeParam&) const override; + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); + MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_4x12) +}; + } // namespace fallback } // namespace megdnn diff --git a/dnn/src/fallback/matrix_mul/generic_strategy.h b/dnn/src/fallback/matrix_mul/generic_strategy.h index c6e5760d..4dc1aecd 100644 --- a/dnn/src/fallback/matrix_mul/generic_strategy.h +++ b/dnn/src/fallback/matrix_mul/generic_strategy.h @@ -8,6 +8,7 @@ namespace fallback { MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); MEGDNN_REG_GEMM_STRATEGY_NOPACK( float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12); } // namespace fallback } // namespace matmul diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/common.h b/dnn/src/fallback/matrix_mul/gi/fp32/common.h new file mode 100644 index 00000000..b969f698 --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/common.h @@ -0,0 +1,221 @@ +#pragma once + +#include "src/fallback/general_intrinsic/gi_float.h" + +namespace megdnn { +namespace matmul { +namespace fallback { + +/* ======================== transform ======================== */ +/** + * interleave_INTERLEAVE_UNROLLK_BATCH_type + * + * BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) * + * UNROLL_K = 16bytes(128bits, a vector size). 
+ * + * the elements traverse order: + * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[j, i] + */ + +template +static GI_FORCEINLINE void interleave_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support sizeof(T) == 4"); + GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0); + GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1); + GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2); + GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3); + inptr0 += 4; + inptr1 += 4; + inptr2 += 4; + inptr3 += 4; + + GiStoreFloat32(outptr, d0d1); + outptr += 4; + GiStoreFloat32(outptr, d2d3); + outptr += 4; + GiStoreFloat32(outptr, d4d5); + outptr += 4; + GiStoreFloat32(outptr, d6d7); + outptr += 4; +} + +template +static GI_FORCEINLINE void interleave_4x12_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_4x12_1_s only support sizeof(T) == 4"); + GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0); + inptr0 += 4; + GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0); + inptr0 += 4; + GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0); + inptr0 += 4; + + GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr1); + inptr1 += 4; + GI_FLOAT32_t d8d9 = GiLoadFloat32(inptr1); + inptr1 += 4; + GI_FLOAT32_t d10d11 = GiLoadFloat32(inptr1); + inptr1 += 4; + + GI_FLOAT32_t d12d13 = GiLoadFloat32(inptr2); + inptr2 += 4; + GI_FLOAT32_t d14d15 = GiLoadFloat32(inptr2); + inptr2 += 4; + GI_FLOAT32_t d16d17 = GiLoadFloat32(inptr2); + inptr2 += 4; + + GI_FLOAT32_t d18d19 = GiLoadFloat32(inptr3); + inptr3 += 4; + GI_FLOAT32_t d20d21 = GiLoadFloat32(inptr3); + inptr3 += 4; + GI_FLOAT32_t d22d23 = GiLoadFloat32(inptr3); + inptr3 += 4; + + GiStoreFloat32(outptr, d0d1); + outptr += 4; + GiStoreFloat32(outptr, d2d3); + outptr += 4; + GiStoreFloat32(outptr, d4d5); + outptr += 4; + GiStoreFloat32(outptr, d6d7); + outptr += 4; + GiStoreFloat32(outptr, d8d9); + outptr += 4; + 
GiStoreFloat32(outptr, d10d11); + outptr += 4; + GiStoreFloat32(outptr, d12d13); + outptr += 4; + GiStoreFloat32(outptr, d14d15); + outptr += 4; + GiStoreFloat32(outptr, d16d17); + outptr += 4; + GiStoreFloat32(outptr, d18d19); + outptr += 4; + GiStoreFloat32(outptr, d20d21); + outptr += 4; + GiStoreFloat32(outptr, d22d23); + outptr += 4; +} + +template +static GI_FORCEINLINE void interleave_1x12_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support sizeof(T) == 4"); + GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0); + inptr0 += 4; + GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr0); + inptr0 += 4; + GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr0); + inptr0 += 4; + + GiStoreFloat32(outptr, d0d1); + outptr += 4; + GiStoreFloat32(outptr, d2d3); + outptr += 4; + GiStoreFloat32(outptr, d4d5); + outptr += 4; +} + +template +static GI_FORCEINLINE void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support sizeof(T) == 4"); + GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0); + inptr0 += 4; + + GiStoreFloat32(outptr, d0d1); + outptr += 4; +} + +template +static GI_FORCEINLINE void interleave_helper( + const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) { + int k = 0; + for (; k < ksize; k++) { + *outptr++ = *inptr++; + } + for (; k < unroll_k; k++) { + *outptr++ = val; + } +} + +template +static GI_FORCEINLINE void interleave_1( + const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + } +} + +template +static GI_FORCEINLINE void interleave_4( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int unroll_k, int ksize, T val = 0) { + for (int k = 0; k < ksize; k += unroll_k) { + int size = std::min(unroll_k, ksize - k); + interleave_helper(inptr0, outptr, unroll_k, size, val); + 
interleave_helper(inptr1, outptr, unroll_k, size, val); + interleave_helper(inptr2, outptr, unroll_k, size, val); + interleave_helper(inptr3, outptr, unroll_k, size, val); + } +} + +/* ======================== transpose pack B ======================== */ +/** + * transpose_INTERLEAVE_UNROLLK_BATCH_type + * + * BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) * + * INTERLEAVE = 16bytes(128bits, a vector size). + * + * the elements traverse order: + * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j] + */ + +template +static GI_FORCEINLINE void transpose_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int stride = 16) { + static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4"); + + stride = stride / sizeof(float); + stride -= 2; + GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0); + GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1); + GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2); + GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3); + inptr0 += 4; + inptr1 += 4; + inptr2 += 4; + inptr3 += 4; + + GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3); + GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7); + + GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0])); + outptr += 2; + GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0])); + outptr += stride; + + GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0])); + outptr += 2; + GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0])); + outptr += stride; + + GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1])); + outptr += 2; + GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1])); + outptr += stride; + + GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1])); + outptr += 2; + GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1])); + outptr += stride; +} + +} // namespace fallback +} // namespace matmul +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp 
b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp new file mode 100644 index 00000000..d413252d --- /dev/null +++ b/dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp @@ -0,0 +1,950 @@ +#include "src/fallback/matrix_mul/generic_strategy.h" +#include "src/fallback/matrix_mul/gi/fp32/common.h" + +using namespace megdnn; +using namespace matmul::fallback; + +namespace { + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wuninitialized" + +#ifdef __GNUC__ +#ifndef __has_warning +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#else +#if __has_warning("-Wmaybe-uninitialized") +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif +#endif +#endif +void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain) { + const float* a_ptr = packA; + const float* b_ptr = packB; + int oddk = (K & 1); + K = ((K + 1) / 2) - 1; + + float* r0 = output; + float* r1 = r0 + LDC; + float* r2 = r1 + LDC; + float* r3 = r2 + LDC; + + GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19, + d20d21, d22d23, d24d25, d26d27, d28d29, d30d31; + + if (is_first_k) { + d8d9 = GiBroadcastFloat32(0.0f); + d10d11 = GiBroadcastFloat32(0.0f); + d12d13 = GiBroadcastFloat32(0.0f); + d14d15 = GiBroadcastFloat32(0.0f); + d16d17 = GiBroadcastFloat32(0.0f); + d18d19 = GiBroadcastFloat32(0.0f); + d20d21 = GiBroadcastFloat32(0.0f); + d22d23 = GiBroadcastFloat32(0.0f); + d24d25 = GiBroadcastFloat32(0.0f); + d26d27 = GiBroadcastFloat32(0.0f); + d28d29 = GiBroadcastFloat32(0.0f); + d30d31 = GiBroadcastFloat32(0.0f); + } else { + if (m_remain == 4) { + d8d9 = GiLoadFloat32(r0); + d10d11 = GiLoadFloat32(r0 + 4); + d12d13 = GiLoadFloat32(r0 + 8); + + d14d15 = GiLoadFloat32(r1); + d16d17 = GiLoadFloat32(r1 + 4); + d18d19 = GiLoadFloat32(r1 + 8); + + d20d21 = GiLoadFloat32(r2); + d22d23 = GiLoadFloat32(r2 + 4); + d24d25 = GiLoadFloat32(r2 + 8); + + d26d27 = GiLoadFloat32(r3); + d28d29 = 
GiLoadFloat32(r3 + 4); + d30d31 = GiLoadFloat32(r3 + 8); + } else if (m_remain == 3) { + d8d9 = GiLoadFloat32(r0); + d10d11 = GiLoadFloat32(r0 + 4); + d12d13 = GiLoadFloat32(r0 + 8); + + d14d15 = GiLoadFloat32(r1); + d16d17 = GiLoadFloat32(r1 + 4); + d18d19 = GiLoadFloat32(r1 + 8); + + d20d21 = GiLoadFloat32(r2); + d22d23 = GiLoadFloat32(r2 + 4); + d24d25 = GiLoadFloat32(r2 + 8); + } else if (m_remain == 2) { + d8d9 = GiLoadFloat32(r0); + d10d11 = GiLoadFloat32(r0 + 4); + d12d13 = GiLoadFloat32(r0 + 8); + + d14d15 = GiLoadFloat32(r1); + d16d17 = GiLoadFloat32(r1 + 4); + d18d19 = GiLoadFloat32(r1 + 8); + } else if (m_remain == 1) { + d8d9 = GiLoadFloat32(r0); + d10d11 = GiLoadFloat32(r0 + 4); + d12d13 = GiLoadFloat32(r0 + 8); + } + } + d2d3 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + + for (; K > 0; K--) { + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); + d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); + d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); + d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); + d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); + d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); + d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); + d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d2d3 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); + d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); + d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); + d16d17 = GiSimdFmaLane(d16d17, 
d4d5, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); + d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); + d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); + d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); + d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + + d2d3 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d6d7 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + } + + if (1 == oddk) { + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); + d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); + d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); + d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); + d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); + d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); + d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); + d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + + } else { + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + + d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0); + d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0); + d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0); + d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1); + d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1); + d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1); + d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2); + d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2); + d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2); + d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3); + d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3); + d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3); + + d0d1 = GiLoadFloat32(a_ptr); + a_ptr = a_ptr + 4; + d2d3 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d4d5 = GiLoadFloat32(b_ptr); + b_ptr = b_ptr + 4; + d6d7 = 
GiLoadFloat32(b_ptr);
+        b_ptr = b_ptr + 4;
+
+        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
+        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
+        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
+        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
+        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
+        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
+        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
+        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
+        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
+        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
+        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);
+    }
+
+    if (m_remain == 4) {
+        GiStoreFloat32(r0, d8d9);
+        GiStoreFloat32(r0 + 4, d10d11);
+        GiStoreFloat32(r0 + 8, d12d13);
+
+        GiStoreFloat32(r1, d14d15);
+        GiStoreFloat32(r1 + 4, d16d17);
+        GiStoreFloat32(r1 + 8, d18d19);
+
+        GiStoreFloat32(r2, d20d21);
+        GiStoreFloat32(r2 + 4, d22d23);
+        GiStoreFloat32(r2 + 8, d24d25);
+
+        GiStoreFloat32(r3, d26d27);
+        GiStoreFloat32(r3 + 4, d28d29);
+        GiStoreFloat32(r3 + 8, d30d31);
+    } else if (m_remain == 3) {
+        GiStoreFloat32(r0, d8d9);
+        GiStoreFloat32(r0 + 4, d10d11);
+        GiStoreFloat32(r0 + 8, d12d13);
+
+        GiStoreFloat32(r1, d14d15);
+        GiStoreFloat32(r1 + 4, d16d17);
+        GiStoreFloat32(r1 + 8, d18d19);
+
+        GiStoreFloat32(r2, d20d21);
+        GiStoreFloat32(r2 + 4, d22d23);
+        GiStoreFloat32(r2 + 8, d24d25);
+    } else if (m_remain == 2) {
+        GiStoreFloat32(r0, d8d9);
+        GiStoreFloat32(r0 + 4, d10d11);
+        GiStoreFloat32(r0 + 8, d12d13);
+
+        GiStoreFloat32(r1, d14d15);
+        GiStoreFloat32(r1 + 4, d16d17);
+        GiStoreFloat32(r1 + 8, d18d19);
+    } else if (m_remain == 1) {
+        GiStoreFloat32(r0, d8d9);
+        GiStoreFloat32(r0 + 4, d10d11);
+        GiStoreFloat32(r0 + 8, d12d13);
+    }
+}
+
+//! Load the first \p n_remain floats of one C row into a vector.
+//! For n_remain == 4 the row is loaded directly; for 1..3 it is staged through
+//! a zero-filled scratch buffer so the unused lanes hold 0. Any other value
+//! yields an all-zero vector. \p row is taken by value on purpose: the caller's
+//! row pointer must still address the start of the row when it is stored back.
+static GI_FORCEINLINE GI_FLOAT32_t kern_4x4_load_row(const float* row, int n_remain) {
+    if (n_remain == 4) {
+        return GiLoadFloat32(row);
+    }
+    float tmp[4] = {0.0f, 0.0f, 0.0f, 0.0f};
+    if (n_remain > 0 && n_remain < 4) {
+        memcpy(tmp, row, sizeof(float) * n_remain);
+    }
+    return GiLoadFloat32(tmp);
+}
+
+//! Store the first \p n_remain lanes of \p acc to one C row.
+//! Only n_remain in 1..4 writes anything; lanes >= n_remain are never written,
+//! so memory right of the valid columns is left untouched.
+static GI_FORCEINLINE void kern_4x4_store_row(
+        float* row, GI_FLOAT32_t acc, int n_remain) {
+    if (n_remain == 4) {
+        GiStoreFloat32(row, acc);
+    } else if (n_remain > 0 && n_remain < 4) {
+        float tmp[4];
+        GiStoreFloat32(tmp, acc);
+        memcpy(row, tmp, sizeof(float) * n_remain);
+    }
+}
+
+//! 4x4 micro kernel for the column tail of the 4x12 strategy.
+//!
+//! Accumulates C[0:m_remain, 0:n_remain] from packed panels that hold four
+//! floats per k step each (packA: one value per output row, packB: one value
+//! per output column).
+//!
+//! \param packA      packed A panel, 4 floats per k step
+//! \param packB      packed B panel, 4 floats per k step
+//! \param K          number of k steps; the pipelined loop below expects K >= 1
+//! \param output     pointer to C[0][0] of this tile
+//! \param LDC        distance between consecutive C rows, in floats
+//! \param is_first_k true: start from zero; false: accumulate onto existing C
+//! \param m_remain   valid rows (1..4); other values neither load nor store
+//! \param n_remain   valid columns (1..4); other values neither load nor store
+//!
+//! The row pointers are never advanced. Previously the partial-column load
+//! path advanced them by n_remain, so when is_first_k was false the results
+//! were written n_remain floats to the right of where they belong.
+void kern_4x4(
+        const float* packA, const float* packB, int K, float* output, int LDC,
+        bool is_first_k, int m_remain, int n_remain) {
+    const float* a_ptr = packA;
+    const float* b_ptr = packB;
+    int oddk = (K & 1);
+    K = ((K + 1) / 2) - 1;
+
+    float* r0 = output;
+    float* r1 = r0 + LDC;
+    float* r2 = r1 + LDC;
+    float* r3 = r2 + LDC;
+    //! rows that are really loaded / stored; 0 when m_remain is out of range
+    const int rows = (m_remain >= 1 && m_remain <= 4) ? m_remain : 0;
+
+    GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7;
+    //! one accumulator per output row; rows >= m_remain stay zero-based and
+    //! are computed but never stored
+    GI_FLOAT32_t d8d9 = GiBroadcastFloat32(0.0f);
+    GI_FLOAT32_t d10d11 = GiBroadcastFloat32(0.0f);
+    GI_FLOAT32_t d12d13 = GiBroadcastFloat32(0.0f);
+    GI_FLOAT32_t d14d15 = GiBroadcastFloat32(0.0f);
+
+    if (!is_first_k) {
+        if (rows > 0) {
+            d8d9 = kern_4x4_load_row(r0, n_remain);
+        }
+        if (rows > 1) {
+            d10d11 = kern_4x4_load_row(r1, n_remain);
+        }
+        if (rows > 2) {
+            d12d13 = kern_4x4_load_row(r2, n_remain);
+        }
+        if (rows > 3) {
+            d14d15 = kern_4x4_load_row(r3, n_remain);
+        }
+    }
+
+    //! software pipelined main loop: two k steps per iteration, the operands
+    //! of the next step are loaded while the current one is accumulated
+    d0d1 = GiLoadFloat32(a_ptr);
+    a_ptr = a_ptr + 4;
+    d4d5 = GiLoadFloat32(b_ptr);
+    b_ptr = b_ptr + 4;
+
+    for (; K > 0; K--) {
+        d2d3 = GiLoadFloat32(a_ptr);
+        a_ptr = a_ptr + 4;
+        d6d7 = GiLoadFloat32(b_ptr);
+        b_ptr = b_ptr + 4;
+
+        d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
+        d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
+        d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);
+
+        d0d1 = GiLoadFloat32(a_ptr);
+        a_ptr = a_ptr + 4;
+        d4d5 = GiLoadFloat32(b_ptr);
+        b_ptr = b_ptr + 4;
+
+        d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1);
+        d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2);
+        d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3);
+    }
+
+    if (1 == oddk) {
+        //! odd K: exactly one preloaded k step is left
+        d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
+        d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
+        d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);
+
+    } else {
+        //! even K: the preloaded step plus one more that still has to be loaded
+        d2d3 = GiLoadFloat32(a_ptr);
+        a_ptr = a_ptr + 4;
+        d6d7 = GiLoadFloat32(b_ptr);
+        b_ptr = b_ptr + 4;
+
+        d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
+        d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
+        d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);
+
+        d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0);
+        d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1);
+        d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2);
+        d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3);
+    }
+
+    if (rows > 0) {
+        kern_4x4_store_row(r0, d8d9, n_remain);
+    }
+    if (rows > 1) {
+        kern_4x4_store_row(r1, d10d11, n_remain);
+    }
+    if (rows > 2) {
+        kern_4x4_store_row(r2, d12d13, n_remain);
+    }
+    if (rows > 3) {
+        kern_4x4_store_row(r3, d14d15, n_remain);
+    }
+}
+#pragma GCC diagnostic pop
+
+void gi_sgemm_4x12_pack_A_n(
+        float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
+        int kmax) {
+    float zerobuff[4];
+    std::memset(zerobuff, 0,
sizeof(float) * 4); + + int y = y0; + for (; y + 3 < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + int K = (kmax - k0); + for (; K > 3; K -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (K > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K); + } + } +} + +void gi_sgemm_4x12_pack_A_t( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize4 = (ksize << 2); + float* outptr_base = out; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + 
interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + int x = x0; + auto outptr = outptr_base; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + outptr_base += 4; + } +} + +void gi_sgemm_4x12_pack_B_n( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { + int ksize = kmax - k0; + int ksize12 = ksize * 12; + int ksize4 = (ksize << 2); + float* outptr_base = out; + float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; + + int k = k0; + for (; k + 3 < kmax; k += 4) { + const float* inptr = in + k * ldin + x0; + const float* inptr1 = inptr + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); + } + + outptr_base += 12 * 4; + outptr_base4 += 4 * 4; + } + + for (; k < kmax; k++) { + const float* inptr = in + k * ldin + x0; + int x = x0; + auto outptr = outptr_base; + for (; x + 12 <= xmax; x += 12) { + auto outptr_interleave = outptr; + interleave_1x12_1_s(inptr, outptr_interleave); + outptr += ksize12; + } + outptr = outptr_base4; + for (; x + 4 <= xmax; x += 4) { + auto outptr_interleave = outptr; + interleave_1x4_1_s(inptr, outptr_interleave); + outptr += ksize4; + } + + if (x < xmax) { + interleave_1(inptr, outptr, 4, xmax - x); + } + + 
outptr_base += 12; + outptr_base4 += 4; + } +} + +void gi_sgemm_4x12_pack_B_t( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { + float* outptr = out; + const float* inptr = in; + float zerobuff[4]; + std::memset(zerobuff, 0, sizeof(float) * 4); + int K12 = 12 * (kmax - k0); + + int y = y0; + + for (; y + 12 <= ymax; y += 12) { + int yi = y; + for (; yi < y + 12; yi += 4) { + const float* inptr0 = inptr + yi * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + float* outptr_inner = outptr + yi - y; + + int x = (kmax - k0); + for (; x > 3; x -= 4) { + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 48); + } + for (; x > 0; x--) { + *outptr_inner++ = *inptr0++; + *outptr_inner++ = *inptr1++; + *outptr_inner++ = *inptr2++; + *outptr_inner++ = *inptr3++; + outptr_inner += 8; + } + } + outptr += K12; + } + + for (; y < ymax; y += 4) { + const float* inptr0 = inptr + y * ldin + k0; + const float* inptr1 = inptr0 + ldin; + const float* inptr2 = inptr1 + ldin; + const float* inptr3 = inptr2 + ldin; + + /* Cope with ragged cases by copying from a buffer of zeroes instead + */ + int x = (kmax - k0); + for (; x > 3; x -= 4) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr); + } + + if (x > 0) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + /* Everything falls through in here */ + case 2: + inptr1 = zerobuff; + MEGDNN_FALLTHRU + case 1: + inptr2 = zerobuff; + MEGDNN_FALLTHRU + case 0: + inptr3 = zerobuff; + break; + default: + megdnn_assert(0); + } + } + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x); + } + } +} + +} // namespace + 
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12); + +void gi_sgemm_4x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose_A) const { + if (transpose_A) { + gi_sgemm_4x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); + } else { + gi_sgemm_4x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); + } +} + +void gi_sgemm_4x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { + if (transpose_B) { + gi_sgemm_4x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); + } else { + gi_sgemm_4x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); + } +} + +void gi_sgemm_4x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); + MEGDNN_MARK_USED_VAR(A_dtype); + MEGDNN_MARK_USED_VAR(B_dtype); + MEGDNN_MARK_USED_VAR(C_dtype); + + constexpr size_t A_INTERLEAVE = 4; + constexpr size_t B_INTERLEAVE = 12; + const int K12 = K * 12; + const int K4 = K * 4; + + size_t m = 0; + for (; m < M; m += A_INTERLEAVE) { + float* output = C + (m * LDC); + + size_t n = 0; + const float* cur_packB = packB; + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); + output += B_INTERLEAVE; + cur_packB += K12; + } + + for (; n < N; n += 4) { + kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); + output += 4; + cur_packB += K4; + } + + packA += K4; + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index 71385eca..bb51963c 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -28,16 +28,18 @@ class 
MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoNaive naive; AlgoF32GiGemvMK4 f32_gemv_mk4; AlgoF32GiMK4_4x8 f32_mk4_4x8; + AlgoF32Gi4x12 f32_4x8; SmallVector m_all_algos; AlgoBase::Mapper m_all_algos_map; public: AlgoPack() { + m_all_algos.emplace_back(&f32_gemv_mk4); + m_all_algos.emplace_back(&f32_mk4_4x8); + m_all_algos.emplace_back(&f32_4x8); m_all_algos.emplace_back(&gemv); m_all_algos.emplace_back(&f32_k8x12x1); m_all_algos.emplace_back(&naive); - m_all_algos.emplace_back(&f32_gemv_mk4); - m_all_algos.emplace_back(&f32_mk4_4x8); for (auto&& algo : m_all_algos) { m_all_algos_map.emplace(algo->info().desc, algo); } diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index ea0b7064..1af625cf 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -103,6 +103,7 @@ public: FB_NAIVE, FB_GI_F32_GEMV_MK4, FB_GI_F32_MK4_4x8, + FB_GI_F32_4x12, #if MEGDNN_X86 //! x86 @@ -232,6 +233,7 @@ private: class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 + class AlgoF32Gi4x12; // fallback F32 gi Gemm class AlgoGemv; class AlgoNaive; class AlgoPack; diff --git a/dnn/test/fallback/matrix_mul.cpp b/dnn/test/fallback/matrix_mul.cpp index c1a121bc..1d50cb04 100644 --- a/dnn/test/fallback/matrix_mul.cpp +++ b/dnn/test/fallback/matrix_mul.cpp @@ -42,6 +42,12 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { "FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); } +TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) { + matrix_mul::check_matrix_mul( + dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), + "FB_GI_F32_4x12"); +} + TEST_F(FALLBACK, MATRIX_MUL_RECORD) { TaskRecordChecker checker(1); using Param = MatrixMul::Param;