move AlgoF32GemvMK4 from arm_common to fallback
GitOrigin-RevId: 6c065abf99
release-1.10
@@ -239,46 +239,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) | |||
return f32_gemv_kern; | |||
} | |||
/* ================== F32 Gemv MK4 algo ================== */ | |||
namespace { | |||
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||
auto Cptr = kern_param.C<dt_float32>(); | |||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const { | |||
// enumerate the M, N, K, only usable when preferred | |||
auto M = kern_size_param.M; | |||
auto N = kern_size_param.N; | |||
auto K = kern_size_param.K; | |||
auto LDB = kern_size_param.LDB; | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||
} | |||
bool MatrixMulImpl::AlgoF32GemvMK4::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||
return true; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||
const KernSizeParam&) const { | |||
return f32_gemv_mk4_kern; | |||
} | |||
/* ===================== F32 Gevm algo ===================== */ | |||
namespace { | |||
template <typename stype, typename dtype> | |||
@@ -95,22 +95,6 @@ public: | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||
}; | |||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | |||
public: | |||
@@ -26,7 +26,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | |||
#endif | |||
AlgoGevm gevm; | |||
AlgoF32GemvMK4 f32_gemv_mk4; | |||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos; | |||
fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; | |||
@@ -42,7 +41,6 @@ public: | |||
#endif | |||
m_all_algos.emplace_back(&int8x8x32_gemv); | |||
m_all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||
m_all_algos.emplace_back(&gevm); | |||
for (auto&& algo : m_all_algos) { | |||
@@ -34,7 +34,6 @@ public: | |||
protected: | |||
class AlgoF32Gemv; // Arm_common F32 Gemv | |||
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | |||
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | |||
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 | |||
class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | |||
@@ -17,11 +17,15 @@ | |||
#include "src/naive/matrix_mul/matrix_mul_helper.h" | |||
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_fb_matmul_f32_kern) | |||
MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like) | |||
MIDOUT_DECL(megdnn_fb_matmul_naive) | |||
MIDOUT_DECL(megdnn_fb_gi_exec_fp32) | |||
MIDOUT_DECL(megdnn_fb_gi_matmul_kern) | |||
using namespace megdnn; | |||
using namespace fallback; | |||
@@ -205,4 +209,99 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoNaive::get_kern(const KernSizeParam&) c | |||
return kern_naive; | |||
} | |||
/* ================== F32 Gemv MK4 gi algo ================== */ | |||
namespace { | |||
void gi_f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_fb_gi_exec_fp32, midout_iv("f32_gemv_mk4_gi_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
const auto Aptr = kern_param.A<dt_float32>(), Bptr = kern_param.B<dt_float32>(); | |||
auto Cptr = kern_param.C<dt_float32>(); | |||
gi_gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoF32GiGemvMK4::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
// enumerate the M, N, K, only usable when preferred | |||
auto M = kern_size_param.M; | |||
auto N = kern_size_param.N; | |||
auto K = kern_size_param.K; | |||
auto LDB = kern_size_param.LDB; | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||
} | |||
bool MatrixMulImpl::AlgoF32GiGemvMK4::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||
return true; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiGemvMK4::get_kern( | |||
const KernSizeParam&) const { | |||
return gi_f32_gemv_mk4_kern; | |||
} | |||
/* ================== F32 Gemm MK4 gi algo ================== */ | |||
namespace { | |||
void gi_f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_fb_gi_matmul_kern, midout_iv("gi_f32_mk4_4x8_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto trA = kern_param.trA, trB = kern_param.trB; | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||
C_type = kern_param.C_type; | |||
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||
auto Cptr = kern_param.C<float>(); | |||
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||
megdnn::matmul::GemmInterleaved<matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||
M, N, K, trA, trB, strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoF32GiMK4_4x8::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.B_type == kern_size_param.A_type && | |||
kern_size_param.C_type == kern_size_param.A_type && | |||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||
!kern_size_param.trB; | |||
} | |||
size_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_workspace( | |||
const KernSizeParam& kern_size_param) const { | |||
MIDOUT_BEGIN( | |||
megdnn_fb_gi_matmul_kern, | |||
midout_iv("AlgoF32GiMK4_4x8::get_workspace"_hash)) { | |||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
C_type = kern_size_param.C_type; | |||
matmul::fallback::gi_sgemm_nopack_4x8 strategy(A_type, B_type, C_type); | |||
return megdnn::matmul::GemmInterleaved< | |||
matmul::fallback::gi_sgemm_nopack_4x8, false>( | |||
M, N, K, trA, trB, strategy) | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern( | |||
const KernSizeParam&) const { | |||
return gi_f32_mk4_4x8_kern; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -80,6 +80,34 @@ public: | |||
DEFAULT) | |||
}; | |||
class MatrixMulImpl::AlgoF32GiGemvMK4 : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { | |||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
} | |||
const char* name() const override { return "FB_GI_F32_GEMV_MK4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_GEMV_MK4) | |||
}; | |||
class MatrixMulImpl::AlgoF32GiMK4_4x8 final : public AlgoBase { | |||
public: | |||
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
const char* name() const override { return "FB_GI_F32_MK4_4x8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
MEGDNN_DECL_ALGO_TYPE(FB_GI_F32_MK4_4x8) | |||
}; | |||
} // namespace fallback | |||
} // namespace megdnn | |||
@@ -16,6 +16,8 @@ namespace matmul { | |||
namespace fallback { | |||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||
MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
float, float, float, 4, 8, 1, false, true, gi_sgemm_nopack_4x8); | |||
} // namespace fallback | |||
} // namespace matmul | |||
@@ -0,0 +1,101 @@ | |||
/** | |||
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/fallback/matrix_mul/gi/fp32/exec_sgemv.h" | |||
#include "include/megdnn/oprs.h" | |||
#include "src/common/unroll_macro.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megdnn_fp32_gi_sgemv) | |||
using namespace megdnn; | |||
using namespace fallback; | |||
namespace { | |||
void sgemv_gi_naive_n_mk4( | |||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
constexpr size_t PACK_SIZE = 4; | |||
megdnn_assert( | |||
N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0); | |||
auto Aptr = A; | |||
auto Cptr = C; | |||
size_t m = 0; | |||
while (m < M) { | |||
auto Aptr0 = Aptr; | |||
auto Cptr0 = Cptr; | |||
GI_FLOAT32_t c[4]; | |||
#define INIT(step) c[step] = GiBroadcastFloat32(0.0f); | |||
UNROLL_CALL_RAW(4, INIT) | |||
#undef INIT | |||
auto Bptr = B; | |||
size_t k = 0; | |||
while (k < K) { | |||
GI_FLOAT32_t b = GiLoadFloat32(Bptr); | |||
GI_FLOAT32_V2_t a[2]; | |||
#if defined(GI_TEST_NAIVE) | |||
#define LOAD_A(step) \ | |||
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||
#elif defined(__arm__) || defined(__aarch64__) | |||
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); | |||
#else | |||
#define LOAD_A(step) \ | |||
a[step].val[0] = GiLoadFloat32(Aptr0 + step * 8); \ | |||
a[step].val[1] = GiLoadFloat32(Aptr0 + step * 8 + 4); | |||
#endif | |||
UNROLL_CALL_RAW(2, LOAD_A) | |||
#undef LOAD_A | |||
#define COMPT(step) \ | |||
c[step] = GiSimdFmaLane(c[step], a[step / 2].val[step % 2], b, step % 4); | |||
UNROLL_CALL_RAW(4, COMPT) | |||
#undef COMPT | |||
Bptr += Bstride; | |||
Aptr0 += PACK_SIZE * PACK_SIZE; | |||
k += PACK_SIZE; | |||
} | |||
#define ADD_C(step, stride) c[step] = GiAddFloat32(c[step], c[step + stride]); | |||
UNROLL_CALL_RAW(2, ADD_C, 2) | |||
UNROLL_CALL_RAW(1, ADD_C, 1) | |||
#undef ADD_C | |||
GiStoreFloat32(Cptr0, c[0]); | |||
Aptr += Astride; | |||
Cptr += Cstride; | |||
m += PACK_SIZE; | |||
} | |||
} | |||
} // namespace | |||
namespace megdnn { | |||
namespace fallback { | |||
void gi_gemv_like_mk4( | |||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { | |||
megdnn_assert(N == 1 && Bstride == 4); | |||
MIDOUT_BEGIN(megdnn_fp32_gi_sgemv, midout_iv("F32_GEMV_NCHW_GI_44_N"_hash)) { | |||
return sgemv_gi_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,25 @@ | |||
/** | |||
* \file dnn/src/fallback/matrix_mul/gi/fp32/exec_sgemv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
namespace megdnn { | |||
namespace fallback { | |||
void gi_gemv_like_mk4( | |||
const float* __restrict A, const float* __restrict B, float* __restrict C, | |||
size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,349 @@ | |||
/** | |||
* \file dnn/src/fallback/matrix_mul/gi/fp32/strategy_mk4_4x8.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2022 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "src/common/utils.h" | |||
#include "src/fallback/general_intrinsic/gi_float.h" | |||
#include "src/fallback/matrix_mul/generic_strategy.h" | |||
using namespace megdnn; | |||
using namespace matmul::fallback; | |||
namespace { | |||
void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
LDB = LDB - 4; | |||
K = K - 4; | |||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d16d17 = GiBroadcastFloat32(0.0f); | |||
GI_FLOAT32_t d18d19 = GiBroadcastFloat32(0.0f); | |||
GI_FLOAT32_t d20d21 = GiBroadcastFloat32(0.0f); | |||
GI_FLOAT32_t d22d23 = GiBroadcastFloat32(0.0f); | |||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||
for (; K > 0; K -= 4) { | |||
d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||
B = B + LDB; | |||
d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d0d1, 1); | |||
} | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d0d1, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d0d1, 3); | |||
d16d17 = GiAddFloat32(d16d17, d20d21); | |||
d18d19 = GiAddFloat32(d18d19, d22d23); | |||
d16d17 = GiAddFloat32(d16d17, d18d19); | |||
GiStoreFloat32(C, d16d17); | |||
C = C + 4; | |||
} | |||
void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
LDB = (LDB - 16); | |||
K = K - 4; | |||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
for (; K > 0; K -= 4) { | |||
d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
B = B + LDB; | |||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||
d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
} | |||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
GiStoreFloat32(C, d16d17); | |||
C = C + 4; | |||
GiStoreFloat32(C, d18d19); | |||
C = C + 4; | |||
GiStoreFloat32(C, d20d21); | |||
C = C + 4; | |||
GiStoreFloat32(C, d22d23); | |||
C = C + 4; | |||
} | |||
void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { | |||
LDB -= 32; | |||
GI_FLOAT32_t d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
GI_FLOAT32_t d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d16d17 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
GI_FLOAT32_t d18d19 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d20d21 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
GI_FLOAT32_t d22d23 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
GI_FLOAT32_t d24d25 = GiSimdFmaLane(vfzero, d8d9, d0d1, 0); | |||
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||
GI_FLOAT32_t d26d27 = GiSimdFmaLane(vfzero, d8d9, d2d3, 0); | |||
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||
GI_FLOAT32_t d28d29 = GiSimdFmaLane(vfzero, d8d9, d4d5, 0); | |||
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||
GI_FLOAT32_t d30d31 = GiSimdFmaLane(vfzero, d8d9, d6d7, 0); | |||
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||
B = B + LDB; | |||
K = K - 4; | |||
for (; K > 0; K -= 4) { | |||
d8d9 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d10d11 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d12d13 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d14d15 = GiLoadFloat32(A); | |||
A = A + 4; | |||
d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d16d17 = GiSimdFmaLane(d16d17, d8d9, d0d1, 0); | |||
d16d17 = GiSimdFmaLane(d16d17, d10d11, d0d1, 1); | |||
d18d19 = GiSimdFmaLane(d18d19, d8d9, d2d3, 0); | |||
d16d17 = GiSimdFmaLane(d16d17, d12d13, d0d1, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d10d11, d2d3, 1); | |||
d16d17 = GiSimdFmaLane(d16d17, d14d15, d0d1, 3); | |||
d18d19 = GiSimdFmaLane(d18d19, d12d13, d2d3, 2); | |||
d18d19 = GiSimdFmaLane(d18d19, d14d15, d2d3, 3); | |||
d0d1 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d2d3 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d20d21 = GiSimdFmaLane(d20d21, d8d9, d4d5, 0); | |||
d20d21 = GiSimdFmaLane(d20d21, d10d11, d4d5, 1); | |||
d22d23 = GiSimdFmaLane(d22d23, d8d9, d6d7, 0); | |||
d20d21 = GiSimdFmaLane(d20d21, d12d13, d4d5, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d10d11, d6d7, 1); | |||
d20d21 = GiSimdFmaLane(d20d21, d14d15, d4d5, 3); | |||
d22d23 = GiSimdFmaLane(d22d23, d12d13, d6d7, 2); | |||
d22d23 = GiSimdFmaLane(d22d23, d14d15, d6d7, 3); | |||
d4d5 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d6d7 = GiLoadFloat32(B); | |||
B = B + 4; | |||
d24d25 = GiSimdFmaLane(d24d25, d8d9, d0d1, 0); | |||
d24d25 = GiSimdFmaLane(d24d25, d10d11, d0d1, 1); | |||
d26d27 = GiSimdFmaLane(d26d27, d8d9, d2d3, 0); | |||
d24d25 = GiSimdFmaLane(d24d25, d12d13, d0d1, 2); | |||
d26d27 = GiSimdFmaLane(d26d27, d10d11, d2d3, 1); | |||
d24d25 = GiSimdFmaLane(d24d25, d14d15, d0d1, 3); | |||
d26d27 = GiSimdFmaLane(d26d27, d12d13, d2d3, 2); | |||
d26d27 = GiSimdFmaLane(d26d27, d14d15, d2d3, 3); | |||
d28d29 = GiSimdFmaLane(d28d29, d8d9, d4d5, 0); | |||
d28d29 = GiSimdFmaLane(d28d29, d10d11, d4d5, 1); | |||
d30d31 = GiSimdFmaLane(d30d31, d8d9, d6d7, 0); | |||
d28d29 = GiSimdFmaLane(d28d29, d12d13, d4d5, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d10d11, d6d7, 1); | |||
d28d29 = GiSimdFmaLane(d28d29, d14d15, d4d5, 3); | |||
d30d31 = GiSimdFmaLane(d30d31, d12d13, d6d7, 2); | |||
d30d31 = GiSimdFmaLane(d30d31, d14d15, d6d7, 3); | |||
B = B + LDB; | |||
} | |||
GiStoreFloat32(C, d16d17); | |||
C = C + 4; | |||
GiStoreFloat32(C, d18d19); | |||
C = C + 4; | |||
GiStoreFloat32(C, d20d21); | |||
C = C + 4; | |||
GiStoreFloat32(C, d22d23); | |||
C = C + 4; | |||
GiStoreFloat32(C, d24d25); | |||
C = C + 4; | |||
GiStoreFloat32(C, d26d27); | |||
C = C + 4; | |||
GiStoreFloat32(C, d28d29); | |||
C = C + 4; | |||
GiStoreFloat32(C, d30d31); | |||
C = C + 4; | |||
} | |||
} // namespace | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gi_sgemm_nopack_4x8); | |||
void gi_sgemm_nopack_4x8::kern( | |||
const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||
size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||
constexpr size_t MB = 4; | |||
constexpr size_t KB = 4; | |||
constexpr size_t NB = 8; | |||
constexpr size_t NB_HALF = 4; | |||
megdnn_assert(!trA && !trB && M % MB == 0 && K % KB == 0); | |||
for (size_t m = 0; m < M; m += MB) { | |||
float* output = C + (m / MB) * LDC; | |||
const float* cur_B = B; | |||
size_t n = 0; | |||
for (; n + NB - 1 < N; n += NB) { | |||
kern_4x8(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB; | |||
output += MB * NB; | |||
} | |||
if (N - n >= 4) { | |||
kern_4x4(A, cur_B, LDB, K, output); | |||
cur_B += KB * NB_HALF; | |||
output += MB * NB_HALF; | |||
n += 4; | |||
} | |||
while (n < N) { | |||
kern_4x1(A, cur_B, LDB, K, output); | |||
cur_B += KB; | |||
output += MB; | |||
n++; | |||
} | |||
A += LDA; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -36,6 +36,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32K8x12x1 f32_k8x12x1; | |||
AlgoGemv gemv; | |||
AlgoNaive naive; | |||
AlgoF32GiGemvMK4 f32_gemv_mk4; | |||
AlgoF32GiMK4_4x8 f32_mk4_4x8; | |||
SmallVector<AlgoBase*> m_all_algos; | |||
AlgoBase::Mapper m_all_algos_map; | |||
@@ -44,6 +46,8 @@ public: | |||
m_all_algos.emplace_back(&gemv); | |||
m_all_algos.emplace_back(&f32_k8x12x1); | |||
m_all_algos.emplace_back(&naive); | |||
m_all_algos.emplace_back(&f32_gemv_mk4); | |||
m_all_algos.emplace_back(&f32_mk4_4x8); | |||
for (auto&& algo : m_all_algos) { | |||
m_all_algos_map.emplace(algo->info().desc, algo); | |||
} | |||
@@ -112,6 +112,8 @@ public: | |||
FB_F32K8x12x1 = 1 << 0, | |||
FB_GEMV, | |||
FB_NAIVE, | |||
FB_GI_F32_GEMV_MK4, | |||
FB_GI_F32_MK4_4x8, | |||
#if MEGDNN_X86 | |||
//! x86 | |||
@@ -131,7 +133,6 @@ public: | |||
ARM_COMMON_INT8X8X32_GEMV, | |||
ARM_COMMON_INT8X8X32_GEMV_MK4, | |||
ARM_COMMON_INT8X8X32_GEMV_MK4_DOT, | |||
ARM_COMMON_F32_GEMV_MK4, | |||
ARM_COMMON_F16_GEMV, | |||
ARM_COMMON_GEVM, | |||
#if MEGDNN_AARCH64 | |||
@@ -236,7 +237,9 @@ public: | |||
}; | |||
private: | |||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
class AlgoF32K8x12x1; // Fallback F32 Kernel 8x12x1 | |||
class AlgoF32GiGemvMK4; // fallback F32 gi Gemv NCHW44 | |||
class AlgoF32GiMK4_4x8; // fallback F32 gi Gemm NCHW44 | |||
class AlgoGemv; | |||
class AlgoNaive; | |||
class AlgoPack; | |||
@@ -45,6 +45,13 @@ TEST_F(FALLBACK, MATRIX_MUL) { | |||
checker.execl({AL, BL, CL}); | |||
} | |||
} | |||
TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||
"FB_GI_F32_MK4_4x8", param::MatrixMul::Format::MK4, 1); | |||
} | |||
TEST_F(FALLBACK, MATRIX_MUL_RECORD) { | |||
TaskRecordChecker<MatrixMul> checker(1); | |||
using Param = MatrixMul::Param; | |||