Browse Source

feat(fallback): imp gi matmul AlgoF32GiMK4_4x8 algo,

move AlgoF32GemvMK4 from arm_common to fallback

GitOrigin-RevId: 6c065abf99
release-1.10
Megvii Engine Team 3 years ago
parent
commit
0f1afb0935
13 changed files with 620 additions and 61 deletions
  1. +0
    -40
      dnn/src/arm_common/matrix_mul/algos.cpp
  2. +0
    -16
      dnn/src/arm_common/matrix_mul/algos.h
  3. +0
    -2
      dnn/src/arm_common/matrix_mul/opr_impl.cpp
  4. +0
    -1
      dnn/src/arm_common/matrix_mul/opr_impl.h
  5. +99
    -0
      dnn/src/fallback/matrix_mul/algos.cpp
  6. +28
    -0
      dnn/src/fallback/matrix_mul/algos.h
  7. +2
    -0
      dnn/src/fallback/matrix_mul/generic_strategy.h
  8. +101
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp
  9. +25
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h
  10. +349
    -0
      dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp
  11. +4
    -0
      dnn/src/fallback/matrix_mul/opr_impl.cpp
  12. +5
    -2
      dnn/src/fallback/matrix_mul/opr_impl.h
  13. +7
    -0
      dnn/test/fallback/matrix_mul.cpp

+ 0
- 40
dnn/src/arm_common/matrix_mul/algos.cpp View File

@@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&)
return f32_gemv_kern;
}

/* ================== F32 Gemv MK4 algo ================== */
namespace {
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;

return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}

bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
const KernSizeParam&) const {
return f32_gemv_mk4_kern;
}

/* ===================== F32 Gevm algo ===================== */
namespace {
template <typename stype, typename dtype>


+ 0
- 16
dnn/src/arm_common/matrix_mul/algos.h View File

@@ -95,22 +95,6 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
};

class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4)
};

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
public:


+ 0
- 2
dnn/src/arm_common/matrix_mul/opr_impl.cpp View File

@@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
#endif
AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;

SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map;
@@ -42,7 +41,6 @@ public:
#endif
m_all_algos.emplace_back(&int8x8x32_gemv);
m_all_algos.emplace_back(&int8x8x32_gemv_mk4);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&gevm);

for (auto&& algo : m_all_algos) {


+ 0
- 1
dnn/src/arm_common/matrix_mul/opr_impl.h View File

@@ -34,7 +34,6 @@ public:

protected:
class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44
class AlgoGevm; // Arm_common Gevm(support int8 and fp32)


+ 99
- 0
dnn/src/fallback/matrix_mul/algos.cpp View File

@@ -17,11 +17,15 @@

#include "src/naive/matrix_mul/matrix_mul_helper.h"

#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"

#include "midout.h"

MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL(megdnn_fb_matmul_naive)
MIDOUT_DECL(megdnn_fb_gi_exec_fp32)
MIDOUT_DECL(megdnn_fb_gi_matmul_kern)

using namespace megdnn;
using namespace fallback;
@@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c
return kern_naive;
}

/* ================== F32 Gemv MK4 gi algo ================== */
namespace {
void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoF32GiGemvMK4::usable(
const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;

return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}

bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern(
const KernSizeParam&) const {
return gi_f32_gemv_mk4_kern;
}

/* ================== F32 Gemm MK4 gi algo ================== */
namespace {
void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto trA = kern_param.trA, trB = kern_param.trB;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
auto A_type = kern_param.A_type, B_type = kern_param.B_type,
C_type = kern_param.C_type;
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>();
auto Cptr = kern_param.C<float>();

matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type);
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>(
M, N, K, trA, trB, strategy)
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr);
}
MIDOUT_END();
}

} // anonymous namespace
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable(
const KernSizeParam& kern_size_param) const {
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB;
}

size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace(
const KernSizeParam& kern_size_param) const {
MIDOUT_BEGIN(
megdnn_fb_gi_matmul_kern,
midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
auto trA = kern_size_param.trA, trB = kern_size_param.trB;
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
C_type = kern_size_param.C_type;
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type);
return megdnn::matmul::GemmInterleaved<
matmul::fallback::gi_sgemm_nopack_4x8, false>(
M, N, K, trA, trB, strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const KernSizeParam&) const {
return gi_f32_mk4_4x8_kern;
}
// vim: syntax=cpp.doxygen

+ 28
- 0
dnn/src/fallback/matrix_mul/algos.h View File

@@ -80,6 +80,34 @@ public:
DEFAULT)
};

class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "FB_GI_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4)
};

class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override { return "FB_GI_F32_MK4_4x8"; }
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8)
};

} // namespace fallback
} // namespace megdnn



+ 2
- 0
dnn/src/fallback/matrix_mul/generic_strategy.h View File

@@ -16,6 +16,8 @@ namespace matmul {
namespace fallback {

MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8);

} // namespace fallback
} // namespace matmul


+ 101
- 0
dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp View File

@@ -0,0 +1,101 @@
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h"
#include "include/megdnn/oprs.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"

#include "midout.h"
MIDOUT_DECL(megdnn_fp32_gi_sgemv)

using namespace megdnn;
using namespace fallback;

namespace {

void sgemv_gi_naive_n_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(
N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0);
auto Aptr = A;
auto Cptr = C;
size_t m = 0;
while (m < M) {
auto Aptr0 = Aptr;
auto Cptr0 = Cptr;
GI_FLOAT32_t c[4];
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f);
UNROLL_CALL_RAW(4, INIT)
#undef INIT
auto Bptr = B;
size_t k = 0;
while (k < K) {
GI_FLOAT32_t b = GiLoadFloat32(Bptr);
GI_FLOAT32_V2_t a[2];
#if defined(GI_TEST_NAIVE)
#define LOAD_A(step) \
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
#elif defined(__arm__) || defined(__aarch64__)
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
#else
#define LOAD_A(step) \
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4);
#endif
UNROLL_CALL_RAW(2, LOAD_A)
#undef LOAD_A

#define COMPT(step) \
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4);
UNROLL_CALL_RAW(4, COMPT)
#undef COMPT
Bptr += Bstride;
Aptr0 += PACK_SIZE * PACK_SIZE;
k += PACK_SIZE;
}

#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]);
UNROLL_CALL_RAW(2, ADD_C, 2)
UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C
GiStoreFloat32(Cptr0, c[0]);

Aptr += Astride;
Cptr += Cstride;
m += PACK_SIZE;
}
}

} // namespace

namespace megdnn {
namespace fallback {

void gi_gemv_like_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 4);
MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) {
return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
}

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 25
- 0
dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h View File

@@ -0,0 +1,25 @@
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once

#include <cstddef>

namespace megdnn {
namespace fallback {

void gi_gemv_like_mk4(
const float* __restrict A, const float* __restrict B, float* __restrict C,
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride);

} // namespace fallback
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 349
- 0
dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp View File

@@ -0,0 +1,349 @@
/**
* \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/utils.h"
#include "src/fallback/general_intrinsic/gi_float.h"
#include "src/fallback/matrix_mul/generic_strategy.h"

using namespace megdnn;
using namespace matmul::fallback;

namespace {

void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = LDB - 4;
K = K - 4;

GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f);
GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f);

GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;

d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);

for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);

B = B + LDB;
d0d1 = GiLoadFloat32(B);
B = B + 4;
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;

d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1);
}

d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3);
d16d17 = GiAddFloat32(d16d17, d20d21);
d18d19 = GiAddFloat32(d18d19, d22d23);
d16d17 = GiAddFloat32(d16d17, d18d19);

GiStoreFloat32(C, d16d17);
C = C + 4;
}

void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB = (LDB - 16);
K = K - 4;

GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;

GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;

GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);

d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);

for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;

d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);

B = B + LDB;

d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d2d3 = GiLoadFloat32(B);
B = B + 4;
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);
d4d5 = GiLoadFloat32(B);
B = B + 4;

d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d6d7 = GiLoadFloat32(B);
B = B + 4;
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);

d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;

d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
}

d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);

d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);

GiStoreFloat32(C, d16d17);
C = C + 4;
GiStoreFloat32(C, d18d19);
C = C + 4;
GiStoreFloat32(C, d20d21);
C = C + 4;
GiStoreFloat32(C, d22d23);
C = C + 4;
}

void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) {
LDB -= 32;
GI_FLOAT32_t d8d9 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d10d11 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d12d13 = GiLoadFloat32(A);
A = A + 4;
GI_FLOAT32_t d14d15 = GiLoadFloat32(A);
A = A + 4;

GI_FLOAT32_t d0d1 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d4d5 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);

d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);

B = B + LDB;
K = K - 4;
for (; K > 0; K -= 4) {
d8d9 = GiLoadFloat32(A);
A = A + 4;
d10d11 = GiLoadFloat32(A);
A = A + 4;
d12d13 = GiLoadFloat32(A);
A = A + 4;
d14d15 = GiLoadFloat32(A);
A = A + 4;

d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0);
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1);
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0);
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2);
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1);
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3);
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2);
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3);
d0d1 = GiLoadFloat32(B);
B = B + 4;
d2d3 = GiLoadFloat32(B);
B = B + 4;
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0);
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1);
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0);
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2);
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1);
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3);
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2);
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3);

d4d5 = GiLoadFloat32(B);
B = B + 4;
d6d7 = GiLoadFloat32(B);
B = B + 4;
d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0);
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1);
d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0);
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2);
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1);
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3);
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2);
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3);
d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0);
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1);
d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0);
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2);
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1);
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3);
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2);
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3);
B = B + LDB;
}
GiStoreFloat32(C, d16d17);
C = C + 4;
GiStoreFloat32(C, d18d19);
C = C + 4;
GiStoreFloat32(C, d20d21);
C = C + 4;
GiStoreFloat32(C, d22d23);
C = C + 4;
GiStoreFloat32(C, d24d25);
C = C + 4;
GiStoreFloat32(C, d26d27);
C = C + 4;
GiStoreFloat32(C, d28d29);
C = C + 4;
GiStoreFloat32(C, d30d31);
C = C + 4;
}

} // namespace

MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8);

void gi_sgemm_nopack_4x8::kern(
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC,
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const {
constexpr size_t MB = 4;
constexpr size_t KB = 4;
constexpr size_t NB = 8;
constexpr size_t NB_HALF = 4;

megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0);

for (size_t m = 0; m < M; m += MB) {
float* output = C + (m / MB) * LDC;
const float* cur_B = B;
size_t n = 0;
for (; n + NB - 1 < N; n += NB) {
kern_4x8(A, cur_B, LDB, K, output);
cur_B += KB * NB;
output += MB * NB;
}
if (N - n >= 4) {
kern_4x4(A, cur_B, LDB, K, output);
cur_B += KB * NB_HALF;
output += MB * NB_HALF;
n += 4;
}
while (n < N) {
kern_4x1(A, cur_B, LDB, K, output);
cur_B += KB;
output += MB;
n++;
}
A += LDA;
}
}

// vim: syntax=cpp.doxygen

+ 4
- 0
dnn/src/fallback/matrix_mul/opr_impl.cpp View File

@@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32K8x12x1 f32_k8x12x1;
AlgoGemv gemv;
AlgoNaive naive;
AlgoF32GiGemvMK4 f32_gemv_mk4;
AlgoF32GiMK4_4x8 f32_mk4_4x8;
SmallVector<AlgoBase*> m_all_algos;
AlgoBase::Mapper m_all_algos_map;

@@ -44,6 +46,8 @@ public:
m_all_algos.emplace_back(&gemv);
m_all_algos.emplace_back(&f32_k8x12x1);
m_all_algos.emplace_back(&naive);
m_all_algos.emplace_back(&f32_gemv_mk4);
m_all_algos.emplace_back(&f32_mk4_4x8);
for (auto&& algo : m_all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}


+ 5
- 2
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -112,6 +112,8 @@ public:
FB_F32K8x12x1 = 1 << 0,
FB_GEMV,
FB_NAIVE,
FB_GI_F32_GEMV_MK4,
FB_GI_F32_MK4_4x8,

#if MEGDNN_X86
//! x86
@@ -131,7 +133,6 @@ public:
ARM_COMMON_INT8X8X32_GEMV,
ARM_COMMON_INT8X8X32_GEMV_MK4,
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT,
ARM_COMMON_F32_GEMV_MK4,
ARM_COMMON_F16_GEMV,
ARM_COMMON_GEVM,
#if MEGDNN_AARCH64
@@ -236,7 +237,9 @@ public:
};

private:
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44
class AlgoGemv;
class AlgoNaive;
class AlgoPack;


+ 7
- 0
dnn/test/fallback/matrix_mul.cpp View File

@@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) {
checker.execl({AL, BL, CL});
}
}

TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1);
}

TEST_F(FALLBACK, MATRIX_MUL_RECORD) {
TaskRecordChecker<MatrixMul> checker(1);
using Param = MatrixMul::Param;


Loading…
Cancel
Save