Browse Source

feat(fallback): implement gi matmul FB_GI_F32_4x12 algo

GitOrigin-RevId: 16255e7a72
release-1.10
Megvii Engine Team 3 years ago
parent
commit
f249d387de
9 changed files with 1253 additions and 4 deletions
  1. +0
    -2
      dnn/src/fallback/conv_bias/opr_impl.cpp
  2. +58
    -0
      dnn/src/fallback/matrix_mul/algos.cpp
  3. +11
    -0
      dnn/src/fallback/matrix_mul/algos.h
  4. +1
    -0
      dnn/src/fallback/matrix_mul/generic_strategy.h
  5. +221
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/common.h
  6. +950
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp
  7. +4
    -2
      dnn/src/fallback/matrix_mul/opr_impl.cpp
  8. +2
    -0
      dnn/src/fallback/matrix_mul/opr_impl.h
  9. +6
    -0
      dnn/test/fallback/matrix_mul.cpp

+ 0
- 2
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -138,8 +138,6 @@ public:
}
}

//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw
//! prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd
matmul_algos = static_cast<fallback::MatrixMulImpl*>(matmul_opr)
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::DEFAULT});


+ 58
- 0
dnn/src/fallback/matrix_mul/algos.cpp View File

@@ -15,6 +15,7 @@ MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)
MIDOUT_DECL(megdnn_fb_gi_exec_fp32)
MIDOUT_DECL(megdnn_fb_gi_matmul_kern)
MIDOUT_DECL(megdnn_fb_gi_f32_4x12)

using namespace megdnn;
using namespace fallback;
@@ -293,4 +294,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const KernSizeParam&) const {
return gi_f32_mk4_4x8_kern;
}

/* ===================== F32 algo ===================== */
namespace {
//! Kernel entry registered for AlgoF32Gi4x12: unpack the runtime parameters
//! and drive the packed 4x12 GEMM through the generic GemmInterleaved driver
//! (which handles cache blocking, packing and the workspace layout).
void f32_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_fb_gi_f32_4x12, midout_iv("f32_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
        auto Cptr = kern_param.C<float>();

        //! strategy carries dtypes; GemmInterleaved does the blocking and
        //! calls gi_sgemm_4x12::pack_A/pack_B/kern.
        matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type);
        megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_4x12>(
                M, N, K, trA, trB, strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}

} // anonymous namespace

bool MatrixMulImpl::AlgoF32Gi4x12::usable(const KernSizeParam& kern_size_param) const {
    //! only plain fp32 GEMM in the DEFAULT compute mode and DEFAULT
    //! (row-major) matrix format is handled by this kernel
    if (kern_size_param.compute_mode != Param::ComputeMode::DEFAULT)
        return false;
    if (kern_size_param.format != param::MatrixMul::Format::DEFAULT)
        return false;
    const auto& a_type = kern_size_param.A_type;
    return a_type == dtype::Float32() && kern_size_param.B_type == a_type &&
           kern_size_param.C_type == a_type;
}

//! Workspace = packed-A panel + packed-B panel sizes as computed by the
//! GemmInterleaved driver for the 4x12 strategy (same construction as the
//! kernel itself, so the two always agree).
size_t MatrixMulImpl::AlgoF32Gi4x12::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(
            megdnn_fb_gi_f32_4x12, midout_iv("AlgoF32Gi4x12::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        matmul::fallback::gi_sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type);
        return megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_4x12>(
                       M, N, K, trA, trB, strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    //! reached only when midout has pruned this region out of the build
    return 0;
}

//! The same kernel serves all sizes; remainders are handled inside f32_kern.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gi4x12::get_kern(
        const KernSizeParam&) const {
    return f32_kern;
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
AlgoF32Gi4x12, megdnn_fb_gi_f32_4x12, "AlgoF32Gi4x12Impl"_hash,
matmul::fallback::gi_sgemm_4x12, float, float, AlgoDataType::FLOAT32, DEFAULT);

// vim: syntax=cpp.doxygen

+ 11
- 0
dnn/src/fallback/matrix_mul/algos.h View File

@@ -97,6 +97,17 @@ public:
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
};

//! Plain fp32 GEMM (param format DEFAULT) written with the general-intrinsic
//! (GI) layer so it runs on any fallback target; output is computed in
//! 4-row x 12-column C tiles.
class MatrixMulImpl::AlgoF32Gi4x12 final : public AlgoBase {
public:
    //! deterministic: same inputs always produce bitwise-identical output
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return "FB_GI_F32_4x12"; }
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    //! also expose this GEMM to the im2col conv_bias path
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
    MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_4x12)
};

} // namespace fallback
} // namespace megdnn



+ 1
- 0
dnn/src/fallback/matrix_mul/generic_strategy.h View File

@@ -8,6 +8,7 @@ namespace fallback {
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, gi_sgemm_4x12);

} // namespace fallback
} // namespace matmul


+ 221
- 0
dnn/src/fallback/matrix_mul/gi/fp32/common.h View File

@@ -0,0 +1,221 @@
#pragma once

#include "src/fallback/general_intrinsic/gi_float.h"

namespace megdnn {
namespace matmul {
namespace fallback {

/* ======================== transform ======================== */
/**
* interleave_INTERLEAVE_UNROLLK_BATCH_type
*
* BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) *
* UNROLL_K = 16bytes(128bits, a vector size).
*
* the elements traverse order:
* rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[j, i]
*/

template <typename T>
static GI_FORCEINLINE void interleave_4x4_1_s(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr) {
    static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support sizeof(T) == 4");
    //! copy one 4-element vector from each of the 4 rows, in row order;
    //! all loads are issued before any store, and each row pointer advances
    //! by 4 elements
    const T** rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    GI_FLOAT32_t vec[4];
    for (int r = 0; r < 4; ++r) {
        vec[r] = GiLoadFloat32(*rows[r]);
        *rows[r] += 4;
    }
    for (int r = 0; r < 4; ++r) {
        GiStoreFloat32(outptr, vec[r]);
        outptr += 4;
    }
}

template <typename T>
static GI_FORCEINLINE void interleave_4x12_1_s(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr) {
    static_assert(sizeof(T) == 4, "interleave_4x12_1_s only support sizeof(T) == 4");
    //! copy 12 consecutive elements (three vectors) from each of the 4 rows;
    //! output order is row0's 12 elements, then row1's, etc.  All 12 loads
    //! are issued before any store, matching the hand-unrolled original.
    const T** rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    GI_FLOAT32_t vec[12];
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 3; ++c) {
            vec[r * 3 + c] = GiLoadFloat32(*rows[r]);
            *rows[r] += 4;
        }
    }
    for (int i = 0; i < 12; ++i) {
        GiStoreFloat32(outptr, vec[i]);
        outptr += 4;
    }
}

template <typename T>
static GI_FORCEINLINE void interleave_1x12_1_s(const T*& inptr0, T*& outptr) {
    static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support sizeof(T) == 4");
    //! copy 12 consecutive elements (three vectors) of a single row;
    //! loads complete before stores begin, as in the unrolled original
    GI_FLOAT32_t vec[3];
    for (int i = 0; i < 3; ++i) {
        vec[i] = GiLoadFloat32(inptr0);
        inptr0 += 4;
    }
    for (int i = 0; i < 3; ++i) {
        GiStoreFloat32(outptr, vec[i]);
        outptr += 4;
    }
}

template <typename T>
static GI_FORCEINLINE void interleave_1x4_1_s(const T*& inptr0, T*& outptr) {
    static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support sizeof(T) == 4");
    //! copy one 4-element vector of a single row
    GI_FLOAT32_t vec = GiLoadFloat32(inptr0);
    inptr0 += 4;
    GiStoreFloat32(outptr, vec);
    outptr += 4;
}

template <typename T>
static GI_FORCEINLINE void interleave_helper(
        const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) {
    //! copy ksize valid elements from the row, then pad the chunk up to
    //! unroll_k elements with val (no padding when ksize >= unroll_k)
    int emitted = 0;
    while (emitted < ksize) {
        *outptr++ = *inptr++;
        ++emitted;
    }
    while (emitted < unroll_k) {
        *outptr++ = val;
        ++emitted;
    }
}

template <typename T>
static GI_FORCEINLINE void interleave_1(
        const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) {
    //! emit one row as unroll_k-sized chunks; the last chunk is padded
    //! with val when ksize is not a multiple of unroll_k
    int k = 0;
    while (k < ksize) {
        const int chunk = std::min(unroll_k, ksize - k);
        interleave_helper(inptr0, outptr, unroll_k, chunk, val);
        k += unroll_k;
    }
}

template <typename T>
static GI_FORCEINLINE void interleave_4(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr, int unroll_k, int ksize, T val = 0) {
    //! emit 4 rows chunk by chunk: for every unroll_k-wide chunk of K, the
    //! chunks of row0..row3 are written back-to-back (val-padded at the edge)
    const T** rows[4] = {&inptr0, &inptr1, &inptr2, &inptr3};
    for (int k = 0; k < ksize; k += unroll_k) {
        const int chunk = std::min(unroll_k, ksize - k);
        for (int r = 0; r < 4; ++r) {
            interleave_helper(*rows[r], outptr, unroll_k, chunk, val);
        }
    }
}

/* ======================== transpose pack B ======================== */
/**
* transpose_INTERLEAVE_UNROLLK_BATCH_type
*
* BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) *
* INTERLEAVE = 16bytes(128bits, a vector size).
*
* the elements traverse order:
* rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j]
*/

//! Transpose a 4x4 tile: read 4 elements from each of 4 row pointers and
//! write 4 output rows of {in0[i], in1[i], in2[i], in3[i]}.
//! @param stride distance in BYTES between consecutive output rows
//!        (default 16 bytes = contiguous 4-float rows).
template <typename T>
static GI_FORCEINLINE void transpose_4x4_1_s(
        const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3,
        T*& outptr, int stride = 16) {
    static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4");

    //! bytes -> elements; then subtract the 2 elements already advanced by
    //! the first half-row store below, so "outptr += stride" lands exactly
    //! at the start of the next output row
    stride = stride / sizeof(float);
    stride -= 2;
    GI_FLOAT32_t d0d1 = GiLoadFloat32(inptr0);
    GI_FLOAT32_t d2d3 = GiLoadFloat32(inptr1);
    GI_FLOAT32_t d4d5 = GiLoadFloat32(inptr2);
    GI_FLOAT32_t d6d7 = GiLoadFloat32(inptr3);
    inptr0 += 4;
    inptr1 += 4;
    inptr2 += 4;
    inptr3 += 4;

    //! zip interleaves lanes pairwise:
    //! q0q1.val[0] = {a0, b0, a1, b1}, q0q1.val[1] = {a2, b2, a3, b3}
    //! q2q3.val[0] = {c0, d0, c1, d1}, q2q3.val[1] = {c2, d2, c3, d3}
    //! so each output row is assembled from one low/high half of each zip
    GI_FLOAT32_V2_t q0q1 = GiZipqFloat32(d0d1, d2d3);
    GI_FLOAT32_V2_t q2q3 = GiZipqFloat32(d4d5, d6d7);

    //! row 0: {a0, b0} then {c0, d0}
    GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[0]));
    outptr += 2;
    GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[0]));
    outptr += stride;

    //! row 1: {a1, b1, c1, d1}
    GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[0]));
    outptr += 2;
    GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[0]));
    outptr += stride;

    //! row 2: {a2, b2, c2, d2}
    GiSt1Float32(outptr, GiGetLowFloat32(q0q1.val[1]));
    outptr += 2;
    GiSt1Float32(outptr, GiGetLowFloat32(q2q3.val[1]));
    outptr += stride;

    //! row 3: {a3, b3, c3, d3}
    GiSt1Float32(outptr, GiGetHighFloat32(q0q1.val[1]));
    outptr += 2;
    GiSt1Float32(outptr, GiGetHighFloat32(q2q3.val[1]));
    outptr += stride;
}

} // namespace fallback
} // namespace matmul
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 950
- 0
dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp View File

@@ -0,0 +1,950 @@
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/fallback/matrix_mul/gi/fp32/common.h"

using namespace megdnn;
using namespace matmul::fallback;

namespace {

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuninitialized"

#ifdef __GNUC__
#ifndef __has_warning
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#else
#if __has_warning("-Wmaybe-uninitialized")
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#endif
//! Compute one 4x12 tile of C: C[0..m_remain, 0..12) (+)= A_panel * B_panel.
//! packA holds 4 interleaved A rows (one 4-float column per K step); packB
//! holds 12 B columns per K step (three 4-float vectors).  The kernel always
//! writes 12 full columns, so callers only use it when 12 columns fit in C;
//! m_remain in [1, 4] selects how many C rows are read/written.
//! The K loop is software-pipelined two steps deep.
//! NOTE: when !is_first_k and m_remain < 4, accumulators of the missing rows
//! stay uninitialized — they are computed but never stored back; the
//! surrounding -Wmaybe-uninitialized pragmas silence the resulting warnings.
void kern_4x12(
        const float* packA, const float* packB, int K, float* output, int LDC,
        bool is_first_k, int m_remain) {
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    int oddk = (K & 1);
    //! number of 2-step iterations, minus one step kept for the epilogue
    K = ((K + 1) / 2) - 1;

    //! row pointers into the C tile
    float* r0 = output;
    float* r1 = r0 + LDC;
    float* r2 = r1 + LDC;
    float* r3 = r2 + LDC;

    //! d0d1: A column; d2d3/d4d5/d6d7: 12 B values; d8d9..d30d31: the
    //! 4x12 accumulator tile (3 vectors per C row)
    GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15, d16d17, d18d19,
            d20d21, d22d23, d24d25, d26d27, d28d29, d30d31;

    if (is_first_k) {
        //! first K block: start accumulation from zero
        d8d9 = GiBroadcastFloat32(0.0f);
        d10d11 = GiBroadcastFloat32(0.0f);
        d12d13 = GiBroadcastFloat32(0.0f);
        d14d15 = GiBroadcastFloat32(0.0f);
        d16d17 = GiBroadcastFloat32(0.0f);
        d18d19 = GiBroadcastFloat32(0.0f);
        d20d21 = GiBroadcastFloat32(0.0f);
        d22d23 = GiBroadcastFloat32(0.0f);
        d24d25 = GiBroadcastFloat32(0.0f);
        d26d27 = GiBroadcastFloat32(0.0f);
        d28d29 = GiBroadcastFloat32(0.0f);
        d30d31 = GiBroadcastFloat32(0.0f);
    } else {
        //! later K blocks: reload the partial sums of the valid rows only
        if (m_remain == 4) {
            d8d9 = GiLoadFloat32(r0);
            d10d11 = GiLoadFloat32(r0 + 4);
            d12d13 = GiLoadFloat32(r0 + 8);

            d14d15 = GiLoadFloat32(r1);
            d16d17 = GiLoadFloat32(r1 + 4);
            d18d19 = GiLoadFloat32(r1 + 8);

            d20d21 = GiLoadFloat32(r2);
            d22d23 = GiLoadFloat32(r2 + 4);
            d24d25 = GiLoadFloat32(r2 + 8);

            d26d27 = GiLoadFloat32(r3);
            d28d29 = GiLoadFloat32(r3 + 4);
            d30d31 = GiLoadFloat32(r3 + 8);
        } else if (m_remain == 3) {
            d8d9 = GiLoadFloat32(r0);
            d10d11 = GiLoadFloat32(r0 + 4);
            d12d13 = GiLoadFloat32(r0 + 8);

            d14d15 = GiLoadFloat32(r1);
            d16d17 = GiLoadFloat32(r1 + 4);
            d18d19 = GiLoadFloat32(r1 + 8);

            d20d21 = GiLoadFloat32(r2);
            d22d23 = GiLoadFloat32(r2 + 4);
            d24d25 = GiLoadFloat32(r2 + 8);
        } else if (m_remain == 2) {
            d8d9 = GiLoadFloat32(r0);
            d10d11 = GiLoadFloat32(r0 + 4);
            d12d13 = GiLoadFloat32(r0 + 8);

            d14d15 = GiLoadFloat32(r1);
            d16d17 = GiLoadFloat32(r1 + 4);
            d18d19 = GiLoadFloat32(r1 + 8);
        } else if (m_remain == 1) {
            d8d9 = GiLoadFloat32(r0);
            d10d11 = GiLoadFloat32(r0 + 4);
            d12d13 = GiLoadFloat32(r0 + 8);
        }
    }
    //! pipeline prologue: preload the first 12 B values
    d2d3 = GiLoadFloat32(b_ptr);
    b_ptr = b_ptr + 4;
    d4d5 = GiLoadFloat32(b_ptr);
    b_ptr = b_ptr + 4;
    d6d7 = GiLoadFloat32(b_ptr);
    b_ptr = b_ptr + 4;

    //! main loop: two K steps per iteration; each step does 12 lane-FMAs
    //! (accumulator[row][colvec] += B[colvec] * A[lane=row])
    for (; K > 0; K--) {
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;

        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);

        //! second K step of this iteration
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d2d3 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d4d5 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d6d7 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;

        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);

        //! preload B for the next iteration (or the epilogue)
        d2d3 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d4d5 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d6d7 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
    }

    if (1 == oddk) {
        //! odd K: one remaining step (B already preloaded)
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;

        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);

    } else {
        //! even K: two remaining steps
        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;

        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);

        d0d1 = GiLoadFloat32(a_ptr);
        a_ptr = a_ptr + 4;
        d2d3 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d4d5 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;
        d6d7 = GiLoadFloat32(b_ptr);
        b_ptr = b_ptr + 4;

        d8d9 = GiSimdFmaLane(d8d9, d2d3, d0d1, 0);
        d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 0);
        d12d13 = GiSimdFmaLane(d12d13, d6d7, d0d1, 0);
        d14d15 = GiSimdFmaLane(d14d15, d2d3, d0d1, 1);
        d16d17 = GiSimdFmaLane(d16d17, d4d5, d0d1, 1);
        d18d19 = GiSimdFmaLane(d18d19, d6d7, d0d1, 1);
        d20d21 = GiSimdFmaLane(d20d21, d2d3, d0d1, 2);
        d22d23 = GiSimdFmaLane(d22d23, d4d5, d0d1, 2);
        d24d25 = GiSimdFmaLane(d24d25, d6d7, d0d1, 2);
        d26d27 = GiSimdFmaLane(d26d27, d2d3, d0d1, 3);
        d28d29 = GiSimdFmaLane(d28d29, d4d5, d0d1, 3);
        d30d31 = GiSimdFmaLane(d30d31, d6d7, d0d1, 3);
    }

    //! write back only the valid rows of the tile
    if (m_remain == 4) {
        GiStoreFloat32(r0, d8d9);
        GiStoreFloat32(r0 + 4, d10d11);
        GiStoreFloat32(r0 + 8, d12d13);

        GiStoreFloat32(r1, d14d15);
        GiStoreFloat32(r1 + 4, d16d17);
        GiStoreFloat32(r1 + 8, d18d19);

        GiStoreFloat32(r2, d20d21);
        GiStoreFloat32(r2 + 4, d22d23);
        GiStoreFloat32(r2 + 8, d24d25);

        GiStoreFloat32(r3, d26d27);
        GiStoreFloat32(r3 + 4, d28d29);
        GiStoreFloat32(r3 + 8, d30d31);
    } else if (m_remain == 3) {
        GiStoreFloat32(r0, d8d9);
        GiStoreFloat32(r0 + 4, d10d11);
        GiStoreFloat32(r0 + 8, d12d13);

        GiStoreFloat32(r1, d14d15);
        GiStoreFloat32(r1 + 4, d16d17);
        GiStoreFloat32(r1 + 8, d18d19);

        GiStoreFloat32(r2, d20d21);
        GiStoreFloat32(r2 + 4, d22d23);
        GiStoreFloat32(r2 + 8, d24d25);
    } else if (m_remain == 2) {
        GiStoreFloat32(r0, d8d9);
        GiStoreFloat32(r0 + 4, d10d11);
        GiStoreFloat32(r0 + 8, d12d13);

        GiStoreFloat32(r1, d14d15);
        GiStoreFloat32(r1 + 4, d16d17);
        GiStoreFloat32(r1 + 8, d18d19);
    } else if (m_remain == 1) {
        GiStoreFloat32(r0, d8d9);
        GiStoreFloat32(r0 + 4, d10d11);
        GiStoreFloat32(r0 + 8, d12d13);
    }
}

void kern_4x4(
const float* packA, const float* packB, int K, float* output, int LDC,
bool is_first_k, int m_remain, int n_remain) {
const float* a_ptr = packA;
const float* b_ptr = packB;
int oddk = (K & 1);
K = ((K + 1) / 2) - 1;

float* r0 = output;
float* r1 = r0 + LDC;
float* r2 = r1 + LDC;
float* r3 = r2 + LDC;
size_t d_size = sizeof(float);

GI_FLOAT32_t d0d1, d2d3, d4d5, d6d7, d8d9, d10d11, d12d13, d14d15;
float tmp[4];
if (is_first_k) {
d8d9 = GiBroadcastFloat32(0.0f);
d10d11 = GiBroadcastFloat32(0.0f);
d12d13 = GiBroadcastFloat32(0.0f);
d14d15 = GiBroadcastFloat32(0.0f);
} else {
if (m_remain == 4) {
if (n_remain == 4) {
d8d9 = GiLoadFloat32(r0);
d10d11 = GiLoadFloat32(r1);
d12d13 = GiLoadFloat32(r2);
d14d15 = GiLoadFloat32(r3);
} else if (n_remain == 3) {
memcpy(tmp, r0, d_size * 3);
r0 += 3;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 3);
r1 += 3;
d10d11 = GiLoadFloat32(tmp);

memcpy(tmp, r2, d_size * 3);
r2 += 3;
d12d13 = GiLoadFloat32(tmp);

memcpy(tmp, r3, d_size * 3);
r3 += 3;
d14d15 = GiLoadFloat32(tmp);
} else if (n_remain == 2) {
memcpy(tmp, r0, d_size * 2);
r0 += 2;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 2);
r1 += 2;
d10d11 = GiLoadFloat32(tmp);

memcpy(tmp, r2, d_size * 2);
r2 += 2;
d12d13 = GiLoadFloat32(tmp);

memcpy(tmp, r3, d_size * 2);
r3 += 2;
d14d15 = GiLoadFloat32(tmp);
} else if (n_remain == 1) {
tmp[0] = *r0;
r0++;
d8d9 = GiLoadFloat32(tmp);

tmp[0] = *r1;
r1++;
d10d11 = GiLoadFloat32(tmp);

tmp[0] = *r2;
r2++;
d12d13 = GiLoadFloat32(tmp);

tmp[0] = *r3;
r3++;
d14d15 = GiLoadFloat32(tmp);
}
} else if (m_remain == 3) {
if (n_remain == 4) {
d8d9 = GiLoadFloat32(r0);
d10d11 = GiLoadFloat32(r1);
d12d13 = GiLoadFloat32(r2);
} else if (n_remain == 3) {
memcpy(tmp, r0, d_size * 3);
r0 += 3;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 3);
r1 += 3;
d10d11 = GiLoadFloat32(tmp);

memcpy(tmp, r2, d_size * 3);
r2 += 3;
d12d13 = GiLoadFloat32(tmp);
} else if (n_remain == 2) {
memcpy(tmp, r0, d_size * 2);
r0 += 2;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 2);
r1 += 2;
d10d11 = GiLoadFloat32(tmp);

memcpy(tmp, r2, d_size * 2);
r2 += 2;
d12d13 = GiLoadFloat32(tmp);
} else if (n_remain == 1) {
tmp[0] = *r0;
r0++;
d8d9 = GiLoadFloat32(tmp);

tmp[0] = *r1;
r1++;
d10d11 = GiLoadFloat32(tmp);

tmp[0] = *r2;
r2++;
d12d13 = GiLoadFloat32(tmp);
}
} else if (m_remain == 2) {
if (n_remain == 4) {
d8d9 = GiLoadFloat32(r0);
d10d11 = GiLoadFloat32(r1);
} else if (n_remain == 3) {
memcpy(tmp, r0, d_size * 3);
r0 += 3;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 3);
r1 += 3;
d10d11 = GiLoadFloat32(tmp);
} else if (n_remain == 2) {
memcpy(tmp, r0, d_size * 2);
r0 += 2;
d8d9 = GiLoadFloat32(tmp);

memcpy(tmp, r1, d_size * 2);
r1 += 2;
d10d11 = GiLoadFloat32(tmp);
} else if (n_remain == 1) {
tmp[0] = *r0;
r0++;
d8d9 = GiLoadFloat32(tmp);

tmp[0] = *r1;
r1++;
d10d11 = GiLoadFloat32(tmp);
}
} else if (m_remain == 1) {
if (n_remain == 4) {
d8d9 = GiLoadFloat32(r0);
} else if (n_remain == 3) {
memcpy(tmp, r0, d_size * 3);
r0 += 3;
d8d9 = GiLoadFloat32(tmp);
} else if (n_remain == 2) {
memcpy(tmp, r0, d_size * 2);
r0 += 2;
d8d9 = GiLoadFloat32(tmp);
} else if (n_remain == 1) {
tmp[0] = *r0;
r0++;
d8d9 = GiLoadFloat32(tmp);
}
}
}

d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;

for (; K > 0; K--) {
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;

d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);

d0d1 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d4d5 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;

d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0);
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2);
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3);
}

if (1 == oddk) {
d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);

} else {
d2d3 = GiLoadFloat32(a_ptr);
a_ptr = a_ptr + 4;
d6d7 = GiLoadFloat32(b_ptr);
b_ptr = b_ptr + 4;

d8d9 = GiSimdFmaLane(d8d9, d4d5, d0d1, 0);
d10d11 = GiSimdFmaLane(d10d11, d4d5, d0d1, 1);
d12d13 = GiSimdFmaLane(d12d13, d4d5, d0d1, 2);
d14d15 = GiSimdFmaLane(d14d15, d4d5, d0d1, 3);

d8d9 = GiSimdFmaLane(d8d9, d6d7, d2d3, 0);
d10d11 = GiSimdFmaLane(d10d11, d6d7, d2d3, 1);
d12d13 = GiSimdFmaLane(d12d13, d6d7, d2d3, 2);
d14d15 = GiSimdFmaLane(d14d15, d6d7, d2d3, 3);
}

if (m_remain == 4) {
if (n_remain == 4) {
GiStoreFloat32(r0, d8d9);
r0 = r0 + 4;
GiStoreFloat32(r1, d10d11);
r1 = r1 + 4;
GiStoreFloat32(r2, d12d13);
r2 = r2 + 4;
GiStoreFloat32(r3, d14d15);
r3 = r3 + 4;
} else if (n_remain == 3) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 3);
r0 += 3;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 3);
r1 += 3;

GiStoreFloat32(tmp, d12d13);
memcpy(r2, tmp, d_size * 3);
r2 += 3;

GiStoreFloat32(tmp, d14d15);
memcpy(r3, tmp, d_size * 3);
r3 += 3;
} else if (n_remain == 2) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 2);
r0 += 2;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 2);
r1 += 2;

GiStoreFloat32(tmp, d12d13);
memcpy(r2, tmp, d_size * 2);
r2 += 2;

GiStoreFloat32(tmp, d14d15);
memcpy(r3, tmp, d_size * 2);
r3 += 2;
} else if (n_remain == 1) {
GiStoreFloat32(tmp, d8d9);
*r0 = tmp[0];
r0++;

GiStoreFloat32(tmp, d10d11);
*r1 = tmp[0];
r1++;

GiStoreFloat32(tmp, d12d13);
*r2 = tmp[0];
r2++;

GiStoreFloat32(tmp, d14d15);
*r3 = tmp[0];
r3++;
}
} else if (m_remain == 3) {
if (n_remain == 4) {
GiStoreFloat32(r0, d8d9);
r0 = r0 + 4;
GiStoreFloat32(r1, d10d11);
r1 = r1 + 4;
GiStoreFloat32(r2, d12d13);
r2 = r2 + 4;
} else if (n_remain == 3) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 3);
r0 += 3;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 3);
r1 += 3;

GiStoreFloat32(tmp, d12d13);
memcpy(r2, tmp, d_size * 3);
r2 += 3;
} else if (n_remain == 2) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 2);
r0 += 2;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 2);
r1 += 2;

GiStoreFloat32(tmp, d12d13);
memcpy(r2, tmp, d_size * 2);
r2 += 2;
} else if (n_remain == 1) {
GiStoreFloat32(tmp, d8d9);
*r0 = tmp[0];
r0++;

GiStoreFloat32(tmp, d10d11);
*r1 = tmp[0];
r1++;

GiStoreFloat32(tmp, d12d13);
*r2 = tmp[0];
r2++;
}
} else if (m_remain == 2) {
if (n_remain == 4) {
GiStoreFloat32(r0, d8d9);
r0 = r0 + 4;
GiStoreFloat32(r1, d10d11);
r1 = r1 + 4;
} else if (n_remain == 3) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 3);
r0 += 3;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 3);
r1 += 3;
} else if (n_remain == 2) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 2);
r0 += 2;

GiStoreFloat32(tmp, d10d11);
memcpy(r1, tmp, d_size * 2);
r1 += 2;
} else if (n_remain == 1) {
GiStoreFloat32(tmp, d8d9);
*r0 = tmp[0];
r0++;

GiStoreFloat32(tmp, d10d11);
*r1 = tmp[0];
r1++;
}
} else if (m_remain == 1) {
if (n_remain == 4) {
GiStoreFloat32(r0, d8d9);
r0 = r0 + 4;
} else if (n_remain == 3) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 3);
r0 += 3;
} else if (n_remain == 2) {
GiStoreFloat32(tmp, d8d9);
memcpy(r0, tmp, d_size * 2);
r0 += 2;
} else if (n_remain == 1) {
GiStoreFloat32(tmp, d8d9);
*r0 = tmp[0];
r0++;
}
}
}
#pragma GCC diagnostic pop

//! Pack non-transposed A (row-major, rows [y0, ymax), cols [k0, kmax)) into
//! 4-row interleaved panels: every K step contributes one 4-float column
//! {row0[k], row1[k], row2[k], row3[k]}, which is exactly what kern_4x12 /
//! kern_4x4 read per step.  Rows past ymax are padded from a zero buffer.
void gi_sgemm_4x12_pack_A_n(
        float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
        int kmax) {
    float zerobuff[4];
    std::memset(zerobuff, 0, sizeof(float) * 4);

    int y = y0;
    //! full 4-row groups: bulk-transpose 4x4 tiles, then interleave the
    //! <=3 leftover K elements one at a time
    for (; y + 3 < ymax; y += 4) {
        const float* inptr0 = inptr + y * ldin + k0;
        const float* inptr1 = inptr0 + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;

        int K = (kmax - k0);
        for (; K > 3; K -= 4) {
            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
        }

        //! K is now in [0, 3]; interleave_4 is a no-op when K == 0
        interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K);
    }

    //! ragged tail (ymax - y in [1, 3]): out-of-range row pointers are
    //! redirected to zerobuff; the switch runs every iteration because
    //! transpose_4x4_1_s advances whatever pointer it was given
    for (; y < ymax; y += 4) {
        const float* inptr0 = inptr + y * ldin + k0;
        const float* inptr1 = inptr0 + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;

        int K = (kmax - k0);
        for (; K > 3; K -= 4) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 1:
                        inptr2 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }

            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
        }

        if (K > 0) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 1:
                        inptr2 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, K);
        }
    }
}

//! Pack transposed A (stored KxM row-major, K rows [k0, kmax), M cols
//! [x0, xmax)).  Since A is transposed, each stored row already is one K
//! step, so interleaving (not transposing) produces the panel layout.
//! Panels of 4 M-columns are laid out back to back, each ksize4 floats long.
void gi_sgemm_4x12_pack_A_t(
        float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
    int ksize = kmax - k0;
    //! floats per 4-column panel (4 * ksize)
    int ksize4 = (ksize << 2);
    float* outptr_base = out;

    int k = k0;
    //! 4 K rows at a time
    for (; k + 3 < kmax; k += 4) {
        const float* inptr = in + k * ldin + x0;
        const float* inptr1 = inptr + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;

        int x = x0;
        auto outptr = outptr_base;
        for (; x + 4 <= xmax; x += 4) {
            //! write into this panel's slot for the current k group, then
            //! jump a whole panel ahead for the next 4 columns
            auto outptr_interleave = outptr;
            interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
            outptr += ksize4;
        }

        if (x < xmax) {
            //! ragged columns: pad each 4-wide chunk with zeros
            interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
        }

        //! advance to this k-group's slot in the next pass: 4 rows * 4 cols
        outptr_base += 4 * 4;
    }

    //! leftover K rows one at a time
    for (; k < kmax; k++) {
        const float* inptr = in + k * ldin + x0;
        int x = x0;
        auto outptr = outptr_base;
        for (; x + 4 <= xmax; x += 4) {
            auto outptr_interleave = outptr;
            interleave_1x4_1_s(inptr, outptr_interleave);
            outptr += ksize4;
        }

        if (x < xmax) {
            interleave_1(inptr, outptr, 4, xmax - x);
        }

        outptr_base += 4;
    }
}

//! Pack non-transposed B (row-major KxN, K rows [k0, kmax), N cols
//! [x0, xmax)) into 12-wide panels first, then 4-wide panels for the
//! remaining columns.  All 12-wide panels precede all 4-wide panels in the
//! output buffer (outptr_base4 starts where the 12-wide region ends).
void gi_sgemm_4x12_pack_B_n(
        float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) {
    int ksize = kmax - k0;
    //! floats per 12-wide / 4-wide panel
    int ksize12 = ksize * 12;
    int ksize4 = (ksize << 2);
    float* outptr_base = out;
    //! start of the 4-wide panel region, after all full 12-wide panels
    float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;

    int k = k0;
    //! 4 K rows at a time
    for (; k + 3 < kmax; k += 4) {
        const float* inptr = in + k * ldin + x0;
        const float* inptr1 = inptr + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;

        int x = x0;
        auto outptr = outptr_base;
        for (; x + 12 <= xmax; x += 12) {
            auto outptr_interleave = outptr;
            interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
            outptr += ksize12;
        }
        //! switch to the 4-wide panel region for the leftover columns
        outptr = outptr_base4;
        for (; x + 4 <= xmax; x += 4) {
            auto outptr_interleave = outptr;
            interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave);
            outptr += ksize4;
        }

        if (x < xmax) {
            //! <4 ragged columns, zero-padded to a full 4-wide chunk
            interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x);
        }

        //! advance both region cursors past this k group (4 rows each)
        outptr_base += 12 * 4;
        outptr_base4 += 4 * 4;
    }

    //! leftover K rows one at a time
    for (; k < kmax; k++) {
        const float* inptr = in + k * ldin + x0;
        int x = x0;
        auto outptr = outptr_base;
        for (; x + 12 <= xmax; x += 12) {
            auto outptr_interleave = outptr;
            interleave_1x12_1_s(inptr, outptr_interleave);
            outptr += ksize12;
        }
        outptr = outptr_base4;
        for (; x + 4 <= xmax; x += 4) {
            auto outptr_interleave = outptr;
            interleave_1x4_1_s(inptr, outptr_interleave);
            outptr += ksize4;
        }

        if (x < xmax) {
            interleave_1(inptr, outptr, 4, xmax - x);
        }

        outptr_base += 12;
        outptr_base4 += 4;
    }
}

//! Pack transposed B (stored NxK row-major, N rows [y0, ymax), K cols
//! [k0, kmax)): transpose groups of 12 N-rows into 12-wide panels, then
//! handle the remaining rows as zero-padded 4-wide panels.
void gi_sgemm_4x12_pack_B_t(
        float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) {
    float* outptr = out;
    const float* inptr = in;
    float zerobuff[4];
    std::memset(zerobuff, 0, sizeof(float) * 4);
    //! floats per 12-wide panel
    int K12 = 12 * (kmax - k0);

    int y = y0;

    //! full 12-row groups, processed as three 4-row sub-groups
    for (; y + 12 <= ymax; y += 12) {
        int yi = y;
        for (; yi < y + 12; yi += 4) {
            const float* inptr0 = inptr + yi * ldin + k0;
            const float* inptr1 = inptr0 + ldin;
            const float* inptr2 = inptr1 + ldin;
            const float* inptr3 = inptr2 + ldin;
            //! this sub-group owns columns [yi - y, yi - y + 4) of the panel
            float* outptr_inner = outptr + yi - y;

            int x = (kmax - k0);
            //! stride 48 bytes = 12 floats: one panel row per K step
            for (; x > 3; x -= 4) {
                transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 48);
            }
            //! scalar tail: 4 values per K step, then skip the other
            //! 8 columns of the 12-wide panel row
            for (; x > 0; x--) {
                *outptr_inner++ = *inptr0++;
                *outptr_inner++ = *inptr1++;
                *outptr_inner++ = *inptr2++;
                *outptr_inner++ = *inptr3++;
                outptr_inner += 8;
            }
        }
        outptr += K12;
    }

    //! remaining rows in 4-wide panels, zero-padded past ymax
    for (; y < ymax; y += 4) {
        const float* inptr0 = inptr + y * ldin + k0;
        const float* inptr1 = inptr0 + ldin;
        const float* inptr2 = inptr1 + ldin;
        const float* inptr3 = inptr2 + ldin;

        /* Cope with ragged cases by copying from a buffer of zeroes instead
         */
        int x = (kmax - k0);
        for (; x > 3; x -= 4) {
            //! re-checked each pass: transpose advances the zerobuff pointer
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 1:
                        inptr2 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }

            transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr);
        }

        if (x > 0) {
            if ((y + 3) >= ymax) {
                switch ((y + 3) - ymax) {
                    /* Everything falls through in here */
                    case 2:
                        inptr1 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 1:
                        inptr2 = zerobuff;
                        MEGDNN_FALLTHRU
                    case 0:
                        inptr3 = zerobuff;
                        break;
                    default:
                        megdnn_assert(0);
                }
            }
            interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 1, x);
        }
    }
}

} // namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL(gi_sgemm_4x12);

void gi_sgemm_4x12::pack_A(
        float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax,
        bool transpose_A) const {
    //! select the packing routine matching A's storage order
    auto pack_impl = transpose_A ? gi_sgemm_4x12_pack_A_t : gi_sgemm_4x12_pack_A_n;
    pack_impl(out, in, ldin, y0, ymax, k0, kmax);
}

void gi_sgemm_4x12::pack_B(
        float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax,
        bool transpose_B) const {
    //! select the packing routine matching B's storage order
    auto pack_impl = transpose_B ? gi_sgemm_4x12_pack_B_t : gi_sgemm_4x12_pack_B_n;
    pack_impl(out, in, ldin, x0, xmax, k0, kmax);
}

//! Top-level packed GEMM: C(MxN) (+)= A_panelised * B_panelised, tiled as
//! 4 rows x 12 columns, with leftover columns handled by the 4x4 kernel.
//! is_first_k tells the tile kernels whether to zero-init or accumulate.
void gi_sgemm_4x12::kern(
        const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C,
        size_t LDC, bool is_first_k, const float*, float*) const {
    megdnn_assert(
            A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() &&
            A_dtype.enumv() == DTypeEnum::Float32);
    MEGDNN_MARK_USED_VAR(A_dtype);
    MEGDNN_MARK_USED_VAR(B_dtype);
    MEGDNN_MARK_USED_VAR(C_dtype);

    constexpr size_t A_INTERLEAVE = 4;
    constexpr size_t B_INTERLEAVE = 12;
    //! packed-panel strides per K: 12 floats (12-wide B), 4 floats (4-wide
    //! B panel and A panel)
    const int K12 = K * 12;
    const int K4 = K * 4;

    size_t m = 0;
    for (; m < M; m += A_INTERLEAVE) {
        float* output = C + (m * LDC);

        size_t n = 0;
        const float* cur_packB = packB;
        //! full 12-column tiles
        for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
            kern_4x12(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, 4));
            output += B_INTERLEAVE;
            cur_packB += K12;
        }

        //! remaining columns in 4-wide tiles (packed after the 12-wide
        //! region by gi_sgemm_4x12_pack_B_*)
        for (; n < N; n += 4) {
            kern_4x4(
                    packA, cur_packB, K, output, LDC, is_first_k,
                    std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
            output += 4;
            cur_packB += K4;
        }

        //! next 4-row A panel
        packA += K4;
    }
}

// vim: syntax=cpp.doxygen

+ 4
- 2
dnn/src/fallback/matrix_mul/opr_impl.cpp View File

@@ -28,16 +28,18 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoNaive naive;
AlgoF32GiGemvMK4 f32_gemv_mk4;
AlgoF32GiMK4_4x8 f32_mk4_4x8;
AlgoF32Gi4x12 f32_4x8;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;

public:
AlgoPack() {
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
m_all_algos.emplace_back(&f32_4x8);
m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}


+ 2
- 0
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -103,6 +103,7 @@ public:
FB_NAIVE,
FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8,
FB_GI_F32_4x12,

#if MEGDNN_X86
//! x86
@@ -232,6 +233,7 @@ private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoF32Gi4x12; // fallback F32 gi Gemm
class AlgoGemv;
class AlgoNaive;
class AlgoPack;


+ 6
- 0
dnn/test/fallback/matrix_mul.cpp View File

@@ -42,6 +42,12 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}

//! correctness of the GI fp32 4x12 GEMM algo against the reference impl
//! NOTE(review): "MATRIX_MULF" looks like a typo for "MATRIX_MUL_F" —
//! renaming would change the test id, so it is only flagged here.
TEST_F(FALLBACK, MATRIX_MULF_GI_F32_4x12) {
    matrix_mul::check_matrix_mul(
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
            "FB_GI_F32_4x12");
}

TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(1);
using Param = MatrixMul::Param;


Loading…
Cancel
Save