move AlgoF32GemvMK4 from arm_common to fallback
GitOrigin-RevId: 6c065abf99
release-1.10
@@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) | |||||
return f32_gemv_kern; | return f32_gemv_kern; | ||||
} | } | ||||
/* ================== F32 Gemv MK4 algo ================== */ | |||||
namespace { | |||||
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const { | |||||
// enumerate the M, N, K, only usable when preferred | |||||
auto M = kern_size_param.M; | |||||
auto N = kern_size_param.N; | |||||
auto K = kern_size_param.K; | |||||
auto LDB = kern_size_param.LDB; | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||||
} | |||||
bool MatrixMulImpl::AlgoF32GemvMK4::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||||
return true; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||||
const KernSizeParam&) const { | |||||
return f32_gemv_mk4_kern; | |||||
} | |||||
/* ===================== F32 Gevm algo ===================== */ | /* ===================== F32 Gevm algo ===================== */ | ||||
namespace { | namespace { | ||||
template <typename stype, typename dtype> | template <typename stype, typename dtype> | ||||
@@ -95,22 +95,6 @@ public: | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||||
public: | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||||
}; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | ||||
public: | public: | ||||
@@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | ||||
#endif | #endif | ||||
AlgoGevm gevm; | AlgoGevm gevm; | ||||
AlgoF32GemvMK4 f32_gemv_mk4; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | ||||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | ||||
@@ -42,7 +41,6 @@ public: | |||||
#endif | #endif | ||||
m_all_algos.emplace_back(&int8x8x32_gemv); | m_all_algos.emplace_back(&int8x8x32_gemv); | ||||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | ||||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||||
m_all_algos.emplace_back(&gevm); | m_all_algos.emplace_back(&gevm); | ||||
for (auto&& algo : m_all_algos) { | for (auto&& algo : m_all_algos) { | ||||
@@ -34,7 +34,6 @@ public: | |||||
protected: | protected: | ||||
class AlgoF32Gemv; // Arm_common F32 Gemv | class AlgoF32Gemv; // Arm_common F32 Gemv | ||||
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | |||||
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | ||||
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 | class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 | ||||
class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | ||||
@@ -17,11 +17,15 @@ | |||||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | #include "src/naive/matrix_mul/matrix_mul_helper.h" | ||||
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | ||||
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | ||||
MIDOUT_DECL(megdnn_fb_matmul_naive) | MIDOUT_DECL(megdnn_fb_matmul_naive) | ||||
MIDOUT_DECL(megdnn_fb_gi_exec_fp32) | |||||
MIDOUT_DECL(megdnn_fb_gi_matmul_kern) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace fallback; | using namespace fallback; | ||||
@@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c | |||||
return kern_naive; | return kern_naive; | ||||
} | } | ||||
/* ================== F32 Gemv MK4 gi algo ================== */ | |||||
namespace { | |||||
void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32GiGemvMK4::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
// enumerate the M, N, K, only usable when preferred | |||||
auto M = kern_size_param.M; | |||||
auto N = kern_size_param.N; | |||||
auto K = kern_size_param.K; | |||||
auto LDB = kern_size_param.LDB; | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||||
} | |||||
bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||||
return true; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern( | |||||
const KernSizeParam&) const { | |||||
return gi_f32_gemv_mk4_kern; | |||||
} | |||||
/* ================== F32 Gemm MK4 gi algo ================== */ | |||||
namespace { | |||||
void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||||
auto Cptr = kern_param.C<float>(); | |||||
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||||
M, N, K, trA, trB, strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB; | |||||
} | |||||
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fb_gi_matmul_kern, | |||||
midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved< | |||||
matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||||
M, N, K, trA, trB, strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( | |||||
const KernSizeParam&) const { | |||||
return gi_f32_mk4_4x8_kern; | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -80,6 +80,34 @@ public: | |||||
DEFAULT) | DEFAULT) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase { | |||||
public: | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | |||||
const char* name() const override { return "FB_GI_F32_GEMV_MK4"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4) | |||||
}; | |||||
class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase { | |||||
public: | |||||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
const char* name() const override { return "FB_GI_F32_MK4_4x8"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override; | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | |||||
}; | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -16,6 +16,8 @@ namespace matmul { | |||||
namespace fallback { | namespace fallback { | ||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | ||||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | |||||
} // namespace fallback | } // namespace fallback | ||||
} // namespace matmul | } // namespace matmul | ||||
@@ -0,0 +1,101 @@ | |||||
/** | |||||
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||||
#include "include/megdnn/oprs.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_fp32_gi_sgemv) | |||||
using namespace megdnn; | |||||
using namespace fallback; | |||||
namespace { | |||||
void sgemv_gi_naive_n_mk4( | |||||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||||
constexpr size_t PACK_SIZE = 4; | |||||
megdnn_assert( | |||||
N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0); | |||||
auto Aptr = A; | |||||
auto Cptr = C; | |||||
size_t m = 0; | |||||
while (m < M) { | |||||
auto Aptr0 = Aptr; | |||||
auto Cptr0 = Cptr; | |||||
GI_FLOAT32_t c[4]; | |||||
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f); | |||||
UNROLL_CALL_RAW(4, INIT) | |||||
#undef INIT | |||||
auto Bptr = B; | |||||
size_t k = 0; | |||||
while (k < K) { | |||||
GI_FLOAT32_t b = GiLoadFloat32(Bptr); | |||||
GI_FLOAT32_V2_t a[2]; | |||||
#if defined(GI_TEST_NAIVE) | |||||
#define LOAD_A(step) \ | |||||
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||||
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||||
#elif defined(__arm__) || defined(__aarch64__) | |||||
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); | |||||
#else | |||||
#define LOAD_A(step) \ | |||||
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||||
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||||
#endif | |||||
UNROLL_CALL_RAW(2, LOAD_A) | |||||
#undef LOAD_A | |||||
#define COMPT(step) \ | |||||
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4); | |||||
UNROLL_CALL_RAW(4, COMPT) | |||||
#undef COMPT | |||||
Bptr += Bstride; | |||||
Aptr0 += PACK_SIZE * PACK_SIZE; | |||||
k += PACK_SIZE; | |||||
} | |||||
#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]); | |||||
UNROLL_CALL_RAW(2, ADD_C, 2) | |||||
UNROLL_CALL_RAW(1, ADD_C, 1) | |||||
#undef ADD_C | |||||
GiStoreFloat32(Cptr0, c[0]); | |||||
Aptr += Astride; | |||||
Cptr += Cstride; | |||||
m += PACK_SIZE; | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
void gi_gemv_like_mk4( | |||||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1 && Bstride == 4); | |||||
MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) { | |||||
return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,25 @@ | |||||
/** | |||||
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
namespace megdnn { | |||||
namespace fallback { | |||||
void gi_gemv_like_mk4( | |||||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||||
} // namespace fallback | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,349 @@ | |||||
/** | |||||
* \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/common/utils.h" | |||||
#include "src/fallback/general_intrinsic/gi_float.h" | |||||
#include "src/fallback/matrix_mul/generic_strategy.h" | |||||
using namespace megdnn; | |||||
using namespace matmul::fallback; | |||||
namespace { | |||||
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||||
LDB = LDB - 4; | |||||
K = K - 4; | |||||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f); | |||||
GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f); | |||||
GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f); | |||||
GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f); | |||||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||||
for (; K > 0; K -= 4) { | |||||
d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||||
B = B + LDB; | |||||
d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||||
} | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||||
d16d17 = GiAddFloat32(d16d17, d20d21); | |||||
d18d19 = GiAddFloat32(d18d19, d22d23); | |||||
d16d17 = GiAddFloat32(d16d17, d18d19); | |||||
GiStoreFloat32(C, d16d17); | |||||
C = C + 4; | |||||
} | |||||
void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||||
LDB = (LDB - 16); | |||||
K = K - 4; | |||||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||||
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||||
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||||
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||||
for (; K > 0; K -= 4) { | |||||
d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||||
B = B + LDB; | |||||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||||
d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||||
d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||||
d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||||
d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||||
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||||
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||||
d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||||
} | |||||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||||
GiStoreFloat32(C, d16d17); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d18d19); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d20d21); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d22d23); | |||||
C = C + 4; | |||||
} | |||||
void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||||
LDB -= 32; | |||||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||||
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||||
d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||||
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||||
d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||||
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||||
GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||||
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||||
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||||
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||||
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||||
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||||
GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||||
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||||
GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||||
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||||
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||||
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||||
B = B + LDB; | |||||
K = K - 4; | |||||
for (; K > 0; K -= 4) { | |||||
d8d9 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d10d11 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d12d13 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d14d15 = GiLoadFloat32(A); | |||||
A = A + 4; | |||||
d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||||
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||||
d0d1 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d2d3 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||||
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||||
d4d5 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d6d7 = GiLoadFloat32(B); | |||||
B = B + 4; | |||||
d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0); | |||||
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||||
d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0); | |||||
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||||
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||||
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||||
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||||
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||||
d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0); | |||||
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||||
d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0); | |||||
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||||
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||||
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||||
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||||
B = B + LDB; | |||||
} | |||||
GiStoreFloat32(C, d16d17); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d18d19); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d20d21); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d22d23); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d24d25); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d26d27); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d28d29); | |||||
C = C + 4; | |||||
GiStoreFloat32(C, d30d31); | |||||
C = C + 4; | |||||
} | |||||
} // namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8); | |||||
void gi_sgemm_nopack_4x8::kern( | |||||
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||||
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||||
constexpr size_t MB = 4; | |||||
constexpr size_t KB = 4; | |||||
constexpr size_t NB = 8; | |||||
constexpr size_t NB_HALF = 4; | |||||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||||
for (size_t m = 0; m < M; m += MB) { | |||||
float* output = C + (m / MB) * LDC; | |||||
const float* cur_B = B; | |||||
size_t n = 0; | |||||
for (; n + NB - 1 < N; n += NB) { | |||||
kern_4x8(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB; | |||||
output += MB * NB; | |||||
} | |||||
if (N - n >= 4) { | |||||
kern_4x4(A, cur_B, LDB, K, output); | |||||
cur_B += KB * NB_HALF; | |||||
output += MB * NB_HALF; | |||||
n += 4; | |||||
} | |||||
while (n < N) { | |||||
kern_4x1(A, cur_B, LDB, K, output); | |||||
cur_B += KB; | |||||
output += MB; | |||||
n++; | |||||
} | |||||
A += LDA; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoF32K8x12x1 f32_k8x12x1; | AlgoF32K8x12x1 f32_k8x12x1; | ||||
AlgoGemv gemv; | AlgoGemv gemv; | ||||
AlgoNaive naive; | AlgoNaive naive; | ||||
AlgoF32GiGemvMK4 f32_gemv_mk4; | |||||
AlgoF32GiMK4_4x8 f32_mk4_4x8; | |||||
SmallVector<AlgoBase*> m_all_algos; | SmallVector<AlgoBase*> m_all_algos; | ||||
AlgoBase::Mapper m_all_algos_map; | AlgoBase::Mapper m_all_algos_map; | ||||
@@ -44,6 +46,8 @@ public: | |||||
m_all_algos.emplace_back(&gemv); | m_all_algos.emplace_back(&gemv); | ||||
m_all_algos.emplace_back(&f32_k8x12x1); | m_all_algos.emplace_back(&f32_k8x12x1); | ||||
m_all_algos.emplace_back(&naive); | m_all_algos.emplace_back(&naive); | ||||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||||
m_all_algos.emplace_back(&f32_mk4_4x8); | |||||
for (auto&& algo : m_all_algos) { | for (auto&& algo : m_all_algos) { | ||||
m_all_algos_map.emplace(algo->info().desc, algo); | m_all_algos_map.emplace(algo->info().desc, algo); | ||||
} | } | ||||
@@ -112,6 +112,8 @@ public: | |||||
FB_F32K8x12x1 = 1 << 0, | FB_F32K8x12x1 = 1 << 0, | ||||
FB_GEMV, | FB_GEMV, | ||||
FB_NAIVE, | FB_NAIVE, | ||||
FB_GI_F32_GEMV_MK4, | |||||
FB_GI_F32_MK4_4x8, | |||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
//! x86 | //! x86 | ||||
@@ -131,7 +133,6 @@ public: | |||||
ARM_COMMON_INT8X8X32_GEMV, | ARM_COMMON_INT8X8X32_GEMV, | ||||
ARM_COMMON_INT8X8X32_GEMV_MK4, | ARM_COMMON_INT8X8X32_GEMV_MK4, | ||||
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | ||||
ARM_COMMON_F32_GEMV_MK4, | |||||
ARM_COMMON_F16_GEMV, | ARM_COMMON_F16_GEMV, | ||||
ARM_COMMON_GEVM, | ARM_COMMON_GEVM, | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
@@ -236,7 +237,9 @@ public: | |||||
}; | }; | ||||
private: | private: | ||||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||||
class AlgoGemv; | class AlgoGemv; | ||||
class AlgoNaive; | class AlgoNaive; | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||||
checker.execl({AL, BL, CL}); | checker.execl({AL, BL, CL}); | ||||
} | } | ||||
} | } | ||||
TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { | |||||
matrix_mul::check_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||||
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | ||||
TaskRecordChecker<MatrixMul> checker(1); | TaskRecordChecker<MatrixMul> checker(1); | ||||
using Param = MatrixMul::Param; | using Param = MatrixMul::Param; | ||||