GitOrigin-RevId: a049c33f2b
tags/v1.0.0-rc1
@@ -23,6 +23,9 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
#if MGB_ENABLE_CPUINFO | |||||
#include "cpuinfo.h" | |||||
#endif | |||||
#include "midout.h" | #include "midout.h" | ||||
MIDOUT_DECL(megdnn_aarch64_matmul_kern) | MIDOUT_DECL(megdnn_aarch64_matmul_kern) | ||||
@@ -80,6 +83,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
}; | }; | ||||
return f32_kern_8x12; | return f32_kern_8x12; | ||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | ||||
@@ -837,6 +841,159 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, | |||||
aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | ||||
int16_t); | int16_t); | ||||
/* ===================== Int8x8x16 K16x12x4 algo ===================== */ | |||||
namespace { | |||||
void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||||
midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<dt_int8>(), | |||||
Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int16>(); | |||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, | |||||
B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved< | |||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, | |||||
strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||||
kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||||
} | |||||
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred( | |||||
const KernSizeParam&) const { | |||||
#if !MGB_ENABLE_CPUINFO | |||||
return false; | |||||
#else | |||||
auto arch = cpuinfo_get_current_core()->uarch; | |||||
bool little_core = arch == cpuinfo_uarch_cortex_a53 || | |||||
arch == cpuinfo_uarch_cortex_a55; | |||||
return little_core; | |||||
#endif | |||||
} | |||||
size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||||
midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | |||||
K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, | |||||
B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved< | |||||
matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, | |||||
strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | |||||
const KernSizeParam&) const { | |||||
return int8x8x16_mk4_16x12x4_kern; | |||||
} | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||||
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, | |||||
"AlgoInt8x8x16MK4_16x12x4Impl"_hash, | |||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); | |||||
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ | |||||
namespace { | |||||
void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||||
midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<dt_int8>(), | |||||
Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int16>(); | |||||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, | |||||
B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved< | |||||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, | |||||
strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||||
kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||||
} | |||||
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred( | |||||
const KernSizeParam&) const { | |||||
#if !MGB_ENABLE_CPUINFO | |||||
return false; | |||||
#else | |||||
auto arch = cpuinfo_get_current_core()->uarch; | |||||
bool little_core = arch == cpuinfo_uarch_cortex_a53 || | |||||
arch == cpuinfo_uarch_cortex_a55; | |||||
return !little_core; | |||||
#endif | |||||
} | |||||
size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||||
midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | |||||
K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, | |||||
B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved< | |||||
matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, | |||||
strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( | |||||
const KernSizeParam&) const { | |||||
return int8x8x16_mk4_4x4x8_kern; | |||||
} | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, | |||||
megdnn_aarch64_matmul_kern, | |||||
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash, | |||||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, | |||||
int8_t, int16_t); | |||||
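// ---------------------------------------------------------------------------
// Editor's sketch, not part of this patch: the two MK4 int8x8x16 algorithms
// added above are complementary -- preferred() selects the A53-tuned 16x12x4
// kernel on in-order "little" cores (Cortex-A53/A55) and the A72-tuned 4x4x8
// kernel everywhere else.  A minimal standalone illustration of that
// cpuinfo-based choice follows; the helper name is hypothetical, while the
// cpuinfo calls and algorithm names are taken from this diff.
#if MGB_ENABLE_CPUINFO
static inline const char* pick_int8x8x16_mk4_algo() {
    //! cpuinfo must have been initialized before the current core is queried
    if (!cpuinfo_initialize())
        return "AARCH64_INT8X8X16_MK4_4X4X8";
    auto uarch = cpuinfo_get_current_core()->uarch;
    bool little_core = uarch == cpuinfo_uarch_cortex_a53 ||
                       uarch == cpuinfo_uarch_cortex_a55;
    return little_core ? "AARCH64_INT8X8X16_MK4_16X12X4"
                       : "AARCH64_INT8X8X16_MK4_4X4X8";
}
#endif
// ---------------------------------------------------------------------------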
/* ===================== Int16x16x32 K12x8x1 algo ===================== */ | /* ===================== Int16x16x32 K12x8x1 algo ===================== */ | ||||
namespace { | namespace { | ||||
void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { | void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -121,12 +122,9 @@ public: | |||||
#else | #else | ||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X32_MK4_4X4X16"; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -188,6 +186,36 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||||
} | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override; | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||||
}; | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override; | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
PackMode packmode() const override { return PackMode::DEFAULT; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||||
}; | |||||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include <cmath> | #include <cmath> | ||||
@@ -993,8 +994,8 @@ static inline void interleave_4x1_4_s(const int32_t*& inptr0, | |||||
template <typename T> | template <typename T> | ||||
static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, | static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, | |||||
T*& outptr) { | |||||
const T*& inptr2, const T*& inptr3, | |||||
T*& outptr) { | |||||
static_assert(sizeof(T) == 4, "only support size == 4"); | static_assert(sizeof(T) == 4, "only support size == 4"); | ||||
asm volatile( | asm volatile( | ||||
"ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" | "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" | ||||
@@ -1140,8 +1141,8 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, | |||||
"stp q2, q6, [%[outptr], #64]\n" | "stp q2, q6, [%[outptr], #64]\n" | ||||
"stp q3, q7, [%[outptr], #96]\n" | "stp q3, q7, [%[outptr], #96]\n" | ||||
: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), | |||||
[ outptr ] "+r"(outptr) | |||||
: | |||||
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); | ||||
} | } | ||||
@@ -1153,7 +1154,7 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) { | |||||
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | ||||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | ||||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "memory"); | : "v0", "v1", "v2", "v3", "memory"); | ||||
} | } | ||||
@@ -1550,7 +1551,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||||
"stp q2, q6, [%[outptr], #96] \n" | "stp q2, q6, [%[outptr], #96] \n" | ||||
"stp q10, q3, [%[outptr], #128] \n" | "stp q10, q3, [%[outptr], #128] \n" | ||||
"stp q7, q11, [%[outptr], #160] \n" | "stp q7, q11, [%[outptr], #160] \n" | ||||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | ||||
"v11", "memory"); | "v11", "memory"); | ||||
@@ -1564,7 +1565,7 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||||
asm volatile( | asm volatile( | ||||
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" | ||||
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n" | ||||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "memory"); | : "v0", "v1", "v2", "v3", "memory"); | ||||
} | } | ||||
@@ -1681,13 +1682,12 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, | |||||
"st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" | "st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" | ||||
"st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" | "st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" | ||||
"st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" | "st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" | ||||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||||
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | |||||
[inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), | |||||
[inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9), | |||||
[inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), | |||||
[outptr] "+r"(outptr) | |||||
: | |||||
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), | |||||
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), | |||||
[inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), | |||||
[inptr11] "+r"(inptr11), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | ||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | ||||
@@ -1972,6 +1972,135 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, | |||||
: "v0", "v1", "v2", "v3", "v4", "memory"); | : "v0", "v1", "v2", "v3", "v4", "memory"); | ||||
} | } | ||||
static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0, | |||||
const int8_t* inptr1, | |||||
const int8_t* inptr2, | |||||
const int8_t* inptr3, | |||||
int16_t* outptr) { | |||||
int8x16_t row0 = vld1q_s8(inptr0); | |||||
int16x8_t row0_01 = vmovl_low_s8(row0); | |||||
int16x8_t row0_23 = vmovl_high_s8(row0); | |||||
int16x4_t row0_0 = vget_low_s16(row0_01); | |||||
int16x4_t row0_1 = vget_high_s16(row0_01); | |||||
int16x4_t row0_2 = vget_low_s16(row0_23); | |||||
int16x4_t row0_3 = vget_high_s16(row0_23); | |||||
int8x16_t row1 = vld1q_s8(inptr1); | |||||
int16x8_t row1_01 = vmovl_low_s8(row1); | |||||
int16x8_t row1_23 = vmovl_high_s8(row1); | |||||
int16x4_t row1_0 = vget_low_s16(row1_01); | |||||
int16x4_t row1_1 = vget_high_s16(row1_01); | |||||
int16x4_t row1_2 = vget_low_s16(row1_23); | |||||
int16x4_t row1_3 = vget_high_s16(row1_23); | |||||
int8x16_t row2 = vld1q_s8(inptr2); | |||||
int16x8_t row2_01 = vmovl_low_s8(row2); | |||||
int16x8_t row2_23 = vmovl_high_s8(row2); | |||||
int16x4_t row2_0 = vget_low_s16(row2_01); | |||||
int16x4_t row2_1 = vget_high_s16(row2_01); | |||||
int16x4_t row2_2 = vget_low_s16(row2_23); | |||||
int16x4_t row2_3 = vget_high_s16(row2_23); | |||||
int8x16_t row3 = vld1q_s8(inptr3); | |||||
int16x8_t row3_01 = vmovl_low_s8(row3); | |||||
int16x8_t row3_23 = vmovl_high_s8(row3); | |||||
int16x4_t row3_0 = vget_low_s16(row3_01); | |||||
int16x4_t row3_1 = vget_high_s16(row3_01); | |||||
int16x4_t row3_2 = vget_low_s16(row3_23); | |||||
int16x4_t row3_3 = vget_high_s16(row3_23); | |||||
vst1_s16(outptr, row0_0); | |||||
vst1_s16(outptr + 1 * 4, row1_0); | |||||
vst1_s16(outptr + 2 * 4, row2_0); | |||||
vst1_s16(outptr + 3 * 4, row3_0); | |||||
vst1_s16(outptr + 4 * 4, row0_1); | |||||
vst1_s16(outptr + 5 * 4, row1_1); | |||||
vst1_s16(outptr + 6 * 4, row2_1); | |||||
vst1_s16(outptr + 7 * 4, row3_1); | |||||
vst1_s16(outptr + 8 * 4, row0_2); | |||||
vst1_s16(outptr + 9 * 4, row1_2); | |||||
vst1_s16(outptr + 10 * 4, row2_2); | |||||
vst1_s16(outptr + 11 * 4, row3_2); | |||||
vst1_s16(outptr + 12 * 4, row0_3); | |||||
vst1_s16(outptr + 13 * 4, row1_3); | |||||
vst1_s16(outptr + 14 * 4, row2_3); | |||||
vst1_s16(outptr + 15 * 4, row3_3); | |||||
}; | |||||
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, | |||||
const int8_t* inptr1, | |||||
int16_t* outptr) { | |||||
int8x16_t row0 = vld1q_s8(inptr0); | |||||
int16x8_t row0_01 = vmovl_low_s8(row0); | |||||
int16x8_t row0_23 = vmovl_high_s8(row0); | |||||
int16x4_t row0_0 = vget_low_s16(row0_01); | |||||
int16x4_t row0_1 = vget_high_s16(row0_01); | |||||
int16x4_t row0_2 = vget_low_s16(row0_23); | |||||
int16x4_t row0_3 = vget_high_s16(row0_23); | |||||
int8x16_t row1 = vld1q_s8(inptr1); | |||||
int16x8_t row1_01 = vmovl_low_s8(row1); | |||||
int16x8_t row1_23 = vmovl_high_s8(row1); | |||||
int16x4_t row1_0 = vget_low_s16(row1_01); | |||||
int16x4_t row1_1 = vget_high_s16(row1_01); | |||||
int16x4_t row1_2 = vget_low_s16(row1_23); | |||||
int16x4_t row1_3 = vget_high_s16(row1_23); | |||||
vst1_s16(outptr, row0_0); | |||||
vst1_s16(outptr + 1 * 4, row1_0); | |||||
vst1_s16(outptr + 2 * 4, row0_1); | |||||
vst1_s16(outptr + 3 * 4, row1_1); | |||||
vst1_s16(outptr + 4 * 4, row0_2); | |||||
vst1_s16(outptr + 5 * 4, row1_2); | |||||
vst1_s16(outptr + 6 * 4, row0_3); | |||||
vst1_s16(outptr + 7 * 4, row1_3); | |||||
}; | |||||
static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, | |||||
int count) { | |||||
for (; count >= 32; count -= 32) { | |||||
int8x8_t in0 = vld1_s8(inptr); | |||||
int8x8_t in1 = vld1_s8(inptr + 1 * 8); | |||||
int8x8_t in2 = vld1_s8(inptr + 2 * 8); | |||||
int8x8_t in3 = vld1_s8(inptr + 3 * 8); | |||||
vst1q_s16(outptr, vmovl_s8(in0)); | |||||
vst1q_s16(outptr + 1 * 8, vmovl_s8(in1)); | |||||
vst1q_s16(outptr + 2 * 8, vmovl_s8(in2)); | |||||
vst1q_s16(outptr + 3 * 8, vmovl_s8(in3)); | |||||
inptr += 32; | |||||
outptr += 32; | |||||
} | |||||
for (; count >= 8; count -= 8) { | |||||
int8x8_t in0 = vld1_s8(inptr); | |||||
vst1q_s16(outptr, vmovl_s8(in0)); | |||||
inptr += 8; | |||||
outptr += 8; | |||||
} | |||||
for (; count > 0; --count) { | |||||
*outptr++ = (int16_t)(*inptr++); | |||||
} | |||||
} | |||||
static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { | |||||
static const uint8_t src_idx_buffer[16] = {0, 4, 8, 12, 1, 5, 9, 13, | |||||
2, 6, 10, 14, 3, 7, 11, 15}; | |||||
static const uint8x16_t vtbl = vld1q_u8(&src_idx_buffer[0]); | |||||
int8x8x4_t input = vld4_s8(inptr0); | |||||
int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl); | |||||
vst1_s8(outptr, input.val[0]); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 8), | |||||
vreinterpretq_s32_s8(input2), 0); | |||||
vst1_s8(outptr + 1 * 12, input.val[1]); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 1 * 12 + 8), | |||||
vreinterpretq_s32_s8(input2), 1); | |||||
vst1_s8(outptr + 2 * 12, input.val[2]); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 2 * 12 + 8), | |||||
vreinterpretq_s32_s8(input2), 2); | |||||
vst1_s8(outptr + 3 * 12, input.val[3]); | |||||
vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 3 * 12 + 8), | |||||
vreinterpretq_s32_s8(input2), 3); | |||||
} | |||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -6,42 +6,55 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | ||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | |||||
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | |||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#if MGB_ENABLE_CPUINFO | |||||
#include "cpuinfo.h" | |||||
#endif | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace aarch64; | using namespace aarch64; | ||||
using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | ||||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, bool transpose_A) const { | |||||
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
int k0, int kmax, bool transpose_A) const { | |||||
if (transpose_A) { | if (transpose_A) { | ||||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
} else { | } else { | ||||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | ||||
int k0, int kmax, bool transpose_B) const { | int k0, int kmax, bool transpose_B) const { | ||||
if (transpose_B) { | if (transpose_B) { | ||||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
} else { | } else { | ||||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
} | } | ||||
} | } | ||||
void sgemm_4x16::kern(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | ||||
A_dtype.enumv() == C_dtype.enumv() && | A_dtype.enumv() == C_dtype.enumv() && | ||||
A_dtype.enumv() == DTypeEnum::Float32); | A_dtype.enumv() == DTypeEnum::Float32); | ||||
@@ -61,15 +74,17 @@ void sgemm_4x16::kern(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K16; | cur_packB += K16; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
matmul_general_4x16::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -80,8 +95,8 @@ void sgemm_4x16::kern(const float* packA, const float* packB, | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | ||||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, bool transpose_A) const { | |||||
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
int k0, int kmax, bool transpose_A) const { | |||||
if (transpose_A) { | if (transpose_A) { | ||||
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | ||||
kmax); | kmax); | ||||
@@ -102,16 +117,10 @@ void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
} | } | ||||
} | } | ||||
void sgemm_8x12::kern(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
template <typename gemm_class> | |||||
static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
constexpr size_t A_INTERLEAVE4 = 4; | constexpr size_t A_INTERLEAVE4 = 4; | ||||
constexpr size_t B_INTERLEAVE = 12; | constexpr size_t B_INTERLEAVE = 12; | ||||
@@ -126,16 +135,14 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | ||||
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -146,17 +153,16 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4)); | |||||
output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_general_8x12::kern_4x4( | |||||
packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(M - m, 4), | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4; | output += 4; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -164,6 +170,33 @@ void sgemm_8x12::kern(const float* packA, const float* packB, | |||||
} | } | ||||
} | } | ||||
void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
#if !MGB_ENABLE_CPUINFO | |||||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
#else | |||||
auto arch = cpuinfo_get_current_core()->uarch; | |||||
if (arch == cpuinfo_uarch_cortex_a53) { | |||||
sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
} else if (arch == cpuinfo_uarch_cortex_a55) { | |||||
sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
} else { | |||||
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
} | |||||
#endif | |||||
} | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | ||||
void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | ||||
@@ -180,25 +213,17 @@ void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||||
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | ||||
} | } | ||||
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4"); | |||||
template <typename gemm_name> | |||||
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
size_t M, size_t N, size_t K, float* C, | |||||
size_t LDC, bool is_first_k) { | |||||
const int K12 = K * 12; | |||||
const int K8 = K * 8; | |||||
const int K4 = K * 4; | |||||
constexpr size_t PACK_C_SIZE = 4; | constexpr size_t PACK_C_SIZE = 4; | ||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
constexpr size_t A_INTERLEAVE4 = 4; | constexpr size_t A_INTERLEAVE4 = 4; | ||||
constexpr size_t B_INTERLEAVE = 12; | constexpr size_t B_INTERLEAVE = 12; | ||||
const int K12 = K * 12; | |||||
const int K8 = K * 8; | |||||
const int K4 = K * 4; | |||||
size_t m = 0; | size_t m = 0; | ||||
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { | for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) { | ||||
float* output = C + (m / PACK_C_SIZE * LDC); | float* output = C + (m / PACK_C_SIZE * LDC); | ||||
@@ -206,15 +231,14 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) { | ||||
matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE * PACK_C_SIZE; | output += B_INTERLEAVE * PACK_C_SIZE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | |||||
matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
for (; n < N; n += 4) { | |||||
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
@@ -225,19 +249,45 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB, | |||||
size_t n = 0; | size_t n = 0; | ||||
const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += B_INTERLEAVE * PACK_C_SIZE; | output += B_INTERLEAVE * PACK_C_SIZE; | ||||
cur_packB += K12; | cur_packB += K12; | ||||
} | } | ||||
for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, std::min<size_t>(N - n, 4)); | |||||
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
} | } | ||||
packA += K4; | packA += K4; | ||||
} | } | ||||
} | } | ||||
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4"); | |||||
#if !MGB_ENABLE_CPUINFO | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
#else | |||||
auto arch = cpuinfo_get_current_core()->uarch; | |||||
if (arch == cpuinfo_uarch_cortex_a53) { | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
} else if (arch == cpuinfo_uarch_cortex_a55) { | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||||
LDC, is_first_k); | |||||
} else { | |||||
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
is_first_k); | |||||
} | |||||
#endif | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -0,0 +1,387 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include <inttypes.h> | |||||
#include "src/aarch64/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul_mk4_4x4x8_a72 { | |||||
//! optimized for Cortex-A72 | |||||
// clang-format off | |||||
/** | |||||
* Overview of register layout: | |||||
* | |||||
* A 4x4x8 cell of Lhs is stored as 8-bit values in q0-q3 and q4-q7 | |||||
* A 4x4x8 cell of Rhs is stored as 8-bit values in q8-q11 and q12-q15 | |||||
* A 4x4 block of accumulators is stored as 16-bit values in q16-q31 | |||||
* | |||||
* +------------------------+ | |||||
* | q8 | q9 | q10 | q11 | | |||||
* Rhs +------------------------+ | |||||
* Lhs | | | | | | |||||
* +--------+ - - - - +------------------------+ | |||||
* | q0 | | q16 | q20 | q24 | q28 | | |||||
* | q1 | | q17 | q21 | q25 | q29 | | |||||
* | q2 | | q18 | q22 | q26 | q30 | | |||||
* | q3 | | q19 | q23 | q27 | q31 | | |||||
* +--------+ - - - - +------------------------+ | |||||
* | |||||
* Accumulator | |||||
*/ | |||||
// clang-format on | |||||
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
int16_t* output, int LDC, bool, int remain_n) { | |||||
K = div_ceil(K, 8); | |||||
int oddk = (K & 1); | |||||
K = ((K + 1) / 2) - 1; | |||||
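//! Editor's note: after this point K counts main-loop iterations, each of
//! which consumes two 8-deep k-blocks; oddk flags a trailing single block.
//! Worked example (assumed input): K = 20 -> div_ceil(20, 8) = 3 blocks,
//! oddk = 1, K = (3 + 1) / 2 - 1 = 1, i.e. one main-loop pass (2 blocks)
//! plus the odd tail (1 block), matching the zero-padded packing to 24.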
const int8_t* a_ptr = packA; | |||||
const int8_t* b_ptr = packB; | |||||
LDC = LDC * sizeof(int8_t); | |||||
// clang-format off | |||||
#define STORE_LINE(reg0) \ | |||||
"cmp w10, #0 \n" \ | |||||
"beq 101f\n" \ | |||||
"st1 {v" reg0 ".4h}, [x0], #8\n" \ | |||||
"subs w10, w10, #1\n" | |||||
#define STORE_C \ | |||||
"mov w10, %w[remain_n]\n" \ | |||||
STORE_LINE("16") \ | |||||
STORE_LINE("20") \ | |||||
STORE_LINE("24") \ | |||||
STORE_LINE("28") | |||||
// clang-format on | |||||
register int16_t* outptr asm("x0") = output; | |||||
asm volatile( | |||||
// load accumulator C | |||||
"1:\n" | |||||
"eor v16.16b, v16.16b, v16.16b\n" | |||||
"eor v17.16b, v17.16b, v17.16b\n" | |||||
"eor v18.16b, v18.16b, v18.16b\n" | |||||
"eor v19.16b, v19.16b, v19.16b\n" | |||||
"eor v20.16b, v20.16b, v20.16b\n" | |||||
"eor v21.16b, v21.16b, v21.16b\n" | |||||
"eor v22.16b, v22.16b, v22.16b\n" | |||||
"eor v23.16b, v23.16b, v23.16b\n" | |||||
"eor v24.16b, v24.16b, v24.16b\n" | |||||
"eor v25.16b, v25.16b, v25.16b\n" | |||||
"eor v26.16b, v26.16b, v26.16b\n" | |||||
"eor v27.16b, v27.16b, v27.16b\n" | |||||
"eor v28.16b, v28.16b, v28.16b\n" | |||||
"eor v29.16b, v29.16b, v29.16b\n" | |||||
"eor v30.16b, v30.16b, v30.16b\n" | |||||
"eor v31.16b, v31.16b, v31.16b\n" | |||||
"2: \n" | |||||
"ld1 {v0.8b, v1.8b}, [%[a_ptr]], #16\n" | |||||
"ld1 {v2.8b, v3.8b}, [%[a_ptr]], #16\n" | |||||
"ld1 {v8.8b, v9.8b}, [%[b_ptr]], #16\n" | |||||
"ld1 {v10.8b, v11.8b}, [%[b_ptr]], #16\n" | |||||
"cmp %w[K], #0\n" | |||||
"beq 4f\n" | |||||
"3: \n" | |||||
//! k = 0 | |||||
"smlal v16.8h, v0.8b, v8.8b\n" | |||||
"ld1 {v4.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v17.8h, v1.8b, v8.8b\n" | |||||
"smlal v18.8h, v2.8b, v8.8b\n" | |||||
"ld1 {v5.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v19.8h, v3.8b, v8.8b\n" | |||||
"smlal v20.8h, v0.8b, v9.8b\n" | |||||
"ld1 {v6.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v21.8h, v1.8b, v9.8b\n" | |||||
"smlal v22.8h, v2.8b, v9.8b\n" | |||||
"ld1 {v7.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v23.8h, v3.8b, v9.8b\n" | |||||
"smlal v24.8h, v0.8b, v10.8b\n" | |||||
"ld1 {v12.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v25.8h, v1.8b, v10.8b\n" | |||||
"smlal v26.8h, v2.8b, v10.8b\n" | |||||
"ld1 {v13.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v27.8h, v3.8b, v10.8b\n" | |||||
"smlal v28.8h, v0.8b, v11.8b\n" | |||||
"ld1 {v14.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v29.8h, v1.8b, v11.8b\n" | |||||
"smlal v30.8h, v2.8b, v11.8b\n" | |||||
"ld1 {v15.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v31.8h, v3.8b, v11.8b\n" | |||||
//! k = 8 | |||||
"smlal v16.8h, v4.8b, v12.8b\n" | |||||
"ld1 {v0.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v17.8h, v5.8b, v12.8b\n" | |||||
"smlal v18.8h, v6.8b, v12.8b\n" | |||||
"ld1 {v1.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v19.8h, v7.8b, v12.8b\n" | |||||
"smlal v20.8h, v4.8b, v13.8b\n" | |||||
"ld1 {v2.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v21.8h, v5.8b, v13.8b\n" | |||||
"smlal v22.8h, v6.8b, v13.8b\n" | |||||
"ld1 {v3.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v23.8h, v7.8b, v13.8b\n" | |||||
"smlal v24.8h, v4.8b, v14.8b\n" | |||||
"ld1 {v8.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v25.8h, v5.8b, v14.8b\n" | |||||
"smlal v26.8h, v6.8b, v14.8b\n" | |||||
"ld1 {v9.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v27.8h, v7.8b, v14.8b\n" | |||||
"smlal v28.8h, v4.8b, v15.8b\n" | |||||
"ld1 {v10.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v29.8h, v5.8b, v15.8b\n" | |||||
"smlal v30.8h, v6.8b, v15.8b\n" | |||||
"ld1 {v11.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v31.8h, v7.8b, v15.8b\n" | |||||
"subs %w[K], %w[K], #1\n" | |||||
"bne 3b\n" | |||||
"4:\n" | |||||
"cmp %w[oddk], #1\n" | |||||
"beq 5f\n" | |||||
//! even tail | |||||
//! k = 0 | |||||
"smlal v16.8h, v0.8b, v8.8b\n" | |||||
"ld1 {v4.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v17.8h, v1.8b, v8.8b\n" | |||||
"smlal v18.8h, v2.8b, v8.8b\n" | |||||
"ld1 {v5.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v19.8h, v3.8b, v8.8b\n" | |||||
"smlal v20.8h, v0.8b, v9.8b\n" | |||||
"ld1 {v6.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v21.8h, v1.8b, v9.8b\n" | |||||
"smlal v22.8h, v2.8b, v9.8b\n" | |||||
"ld1 {v7.8b}, [%[a_ptr]], #8\n" | |||||
"smlal v23.8h, v3.8b, v9.8b\n" | |||||
"smlal v24.8h, v0.8b, v10.8b\n" | |||||
"ld1 {v12.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v25.8h, v1.8b, v10.8b\n" | |||||
"smlal v26.8h, v2.8b, v10.8b\n" | |||||
"ld1 {v13.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v27.8h, v3.8b, v10.8b\n" | |||||
"smlal v28.8h, v0.8b, v11.8b\n" | |||||
"ld1 {v14.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v29.8h, v1.8b, v11.8b\n" | |||||
"smlal v30.8h, v2.8b, v11.8b\n" | |||||
"ld1 {v15.8b}, [%[b_ptr]], #8\n" | |||||
"smlal v31.8h, v3.8b, v11.8b\n" | |||||
//! k = 8 | |||||
"smlal v16.8h, v4.8b, v12.8b\n" | |||||
"smlal v17.8h, v5.8b, v12.8b\n" | |||||
"smlal v18.8h, v6.8b, v12.8b\n" | |||||
"smlal v19.8h, v7.8b, v12.8b\n" | |||||
"smlal v20.8h, v4.8b, v13.8b\n" | |||||
"smlal v21.8h, v5.8b, v13.8b\n" | |||||
"smlal v22.8h, v6.8b, v13.8b\n" | |||||
"smlal v23.8h, v7.8b, v13.8b\n" | |||||
"smlal v24.8h, v4.8b, v14.8b\n" | |||||
"smlal v25.8h, v5.8b, v14.8b\n" | |||||
"smlal v26.8h, v6.8b, v14.8b\n" | |||||
"smlal v27.8h, v7.8b, v14.8b\n" | |||||
"smlal v28.8h, v4.8b, v15.8b\n" | |||||
"smlal v29.8h, v5.8b, v15.8b\n" | |||||
"smlal v30.8h, v6.8b, v15.8b\n" | |||||
"smlal v31.8h, v7.8b, v15.8b\n" | |||||
"b 6f\n" | |||||
"5:\n" | |||||
//! odd tail | |||||
"smlal v16.8h, v0.8b, v8.8b\n" | |||||
"smlal v17.8h, v1.8b, v8.8b\n" | |||||
"smlal v18.8h, v2.8b, v8.8b\n" | |||||
"smlal v19.8h, v3.8b, v8.8b\n" | |||||
"smlal v20.8h, v0.8b, v9.8b\n" | |||||
"smlal v21.8h, v1.8b, v9.8b\n" | |||||
"smlal v22.8h, v2.8b, v9.8b\n" | |||||
"smlal v23.8h, v3.8b, v9.8b\n" | |||||
"smlal v24.8h, v0.8b, v10.8b\n" | |||||
"smlal v25.8h, v1.8b, v10.8b\n" | |||||
"smlal v26.8h, v2.8b, v10.8b\n" | |||||
"smlal v27.8h, v3.8b, v10.8b\n" | |||||
"smlal v28.8h, v0.8b, v11.8b\n" | |||||
"smlal v29.8h, v1.8b, v11.8b\n" | |||||
"smlal v30.8h, v2.8b, v11.8b\n" | |||||
"smlal v31.8h, v3.8b, v11.8b\n" | |||||
"6:\n" | |||||
//! reduce | |||||
"addp v16.8h, v16.8h, v17.8h\n" | |||||
"addp v18.8h, v18.8h, v19.8h\n" | |||||
"addp v20.8h, v20.8h, v21.8h\n" | |||||
"addp v22.8h, v22.8h, v23.8h\n" | |||||
"addp v24.8h, v24.8h, v25.8h\n" | |||||
"addp v26.8h, v26.8h, v27.8h\n" | |||||
"addp v16.8h, v16.8h, v18.8h\n" | |||||
"addp v28.8h, v28.8h, v29.8h\n" | |||||
"addp v30.8h, v30.8h, v31.8h\n" | |||||
"addp v20.8h, v20.8h, v22.8h\n" | |||||
"addp v16.8h, v16.8h, v16.8h\n" | |||||
"addp v20.8h, v20.8h, v20.8h\n" | |||||
"addp v24.8h, v24.8h, v26.8h\n" | |||||
"addp v24.8h, v24.8h, v24.8h\n" | |||||
"addp v28.8h, v28.8h, v30.8h\n" | |||||
"addp v28.8h, v28.8h, v28.8h\n" | |||||
"cmp %w[remain_n], #4\n" | |||||
"bne 7f\n" | |||||
"st1 {v16.4h}, [x0], #8\n" | |||||
"st1 {v20.4h}, [x0], #8\n" | |||||
"st1 {v24.4h}, [x0], #8\n" | |||||
"st1 {v28.4h}, [x0], #8\n" | |||||
"b 101f\n" | |||||
"7:\n" STORE_C | |||||
"101:\n" | |||||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
[oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
[remain_n] "+r"(remain_n) | |||||
: | |||||
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
"v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
"v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
"x8", "x9", "x10", "cc", "memory"); | |||||
#undef STORE_C | |||||
#undef STORE_LINE | |||||
} | |||||
static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||||
int8x8x4_t in0 = vld4_s8(inptr); | |||||
vst1_s8(outptr + 0 * 8, in0.val[0]); | |||||
vst1_s8(outptr + 1 * 8, in0.val[1]); | |||||
vst1_s8(outptr + 2 * 8, in0.val[2]); | |||||
vst1_s8(outptr + 3 * 8, in0.val[3]); | |||||
} | |||||
static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||||
dt_int8* outptr) { | |||||
int8x16_t in0 = vld1q_s8(inptr); | |||||
int8x16_t in1 = vld1q_s8(inptr2); | |||||
int32x4x2_t in_x2 = { | |||||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||||
} | |||||
static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | |||||
int8x16_t in0 = vld1q_s8(inptr); | |||||
int8x16_t in1 = vdupq_n_s8(0); | |||||
int32x4x2_t in_x2 = { | |||||
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||||
} | |||||
static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||||
int ldin, int m0, int mmax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4"); | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||||
constexpr int pack_m = 4; | |||||
constexpr int pack_k = 8; | |||||
constexpr int pack_size = 4; | |||||
const int ksize = kmax - k0; | |||||
const int remain_k = ksize % pack_k; | |||||
const int kend = kmax - remain_k; | |||||
int8_t tmpbuff[pack_m * pack_k]{0}; | |||||
for (int m_idx = m0; m_idx < mmax; m_idx += pack_m) { | |||||
const int8_t* inptr0 = in + m_idx / pack_size * ldin + k0; | |||||
for (int k_idx = k0; k_idx < kend; k_idx += pack_k) { | |||||
transpose_8x4_b(inptr0, out); | |||||
inptr0 += pack_m * pack_k; | |||||
out += pack_m * pack_k; | |||||
} | |||||
if (remain_k > 0) { | |||||
int8x16_t tmp = vld1q_s8(inptr0); | |||||
vst1q_s8(&tmpbuff[0], tmp); | |||||
transpose_8x4_b(&tmpbuff[0], out); | |||||
inptr0 += pack_m * pack_size; | |||||
out += pack_m * pack_k; | |||||
} | |||||
} | |||||
} | |||||
static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int n0, int nmax, int k0, | |||||
int kmax) { | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4"); | |||||
constexpr int pack_n = 4; | |||||
constexpr int pack_k = 8; | |||||
constexpr int pack_size = 4; | |||||
const int ksize = kmax - k0; | |||||
const int packed_ksize = round_up(ksize, pack_k); | |||||
const int remain_k = ksize % pack_k; | |||||
const int kend = kmax - remain_k; | |||||
const int nsize = nmax - n0; | |||||
const int remain_n = nsize % pack_n; | |||||
const int nend = nmax - remain_n; | |||||
const int stride_input = pack_size * nsize; | |||||
int8_t tmpbuff[pack_n * pack_k]{0}; | |||||
int8_t tmpbuff2[pack_n * pack_k]{0}; | |||||
for (int k_idx = k0; k_idx < kend; k_idx += pack_k) { | |||||
const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; | |||||
const int8_t* inptr2 = inptr + stride_input; | |||||
int8_t* outptr = out + k_idx * pack_n; | |||||
for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { | |||||
interleve_8x4_b(inptr, inptr2, outptr); | |||||
inptr += pack_n * pack_size; | |||||
inptr2 += pack_n * pack_size; | |||||
outptr += pack_n * packed_ksize; | |||||
} | |||||
if (remain_n > 0) { | |||||
memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); | |||||
memcpy(&tmpbuff2[0], inptr2, remain_n * pack_size * sizeof(int8_t)); | |||||
interleve_8x4_b(&tmpbuff[0], &tmpbuff2[0], outptr); | |||||
outptr += pack_n * packed_ksize; | |||||
} | |||||
} | |||||
if (remain_k > 0) { | |||||
const int8_t* inptr = in + kend / pack_size * ldin + n0 * pack_size; | |||||
int8_t* outptr = out + kend * pack_n; | |||||
for (int n_idx = n0; n_idx < nend; n_idx += pack_n) { | |||||
interleve_8x4_b_pad(inptr, outptr); | |||||
inptr += pack_n * pack_size; | |||||
outptr += pack_n * packed_ksize; | |||||
} | |||||
if (remain_n > 0) { | |||||
memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t)); | |||||
interleve_8x4_b_pad(&tmpbuff[0], outptr); | |||||
outptr += pack_n * packed_ksize; | |||||
} | |||||
} | |||||
} | |||||
} // namespace matmul_mk4_4x4x8_a72 | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -6,12 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -197,4 +200,161 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
packA += K4; | packA += K4; | ||||
} | } | ||||
} | } | ||||
// ===========================gemm_s8x8x16_mk4_16x12_a53================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | |||||
void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||||
int ldin, int y0, int ymax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||||
ymax, k0, kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||||
int ldin, int x0, int xmax, int k0, | |||||
int kmax, bool) const { | |||||
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||||
xmax, k0, kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
const dt_int8* packB, size_t M, size_t N, | |||||
size_t K, dt_int16* C, size_t LDC, | |||||
bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only is_first_k == true is implemented"); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4"); | |||||
constexpr size_t pack_size = 4; | |||||
constexpr size_t pack_m = 16; | |||||
constexpr size_t pack_n = 12; | |||||
const size_t remain_n = N % pack_n; | |||||
size_t remain_m = M % pack_m; | |||||
size_t m_idx = 0; | |||||
for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
size_t n_idx = 0; | |||||
const int8_t* cur_packB = packB; | |||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
packA += pack_m * K; | |||||
} | |||||
if (remain_m >= 8) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
size_t n_idx = 0; | |||||
const int8_t* cur_packB = packB; | |||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
packA += 8 * K; | |||||
m_idx += 8; | |||||
remain_m -= 8; | |||||
} | |||||
if (remain_m == 4) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
size_t n_idx = 0; | |||||
const int8_t* cur_packB = packB; | |||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
} | |||||
} | |||||
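// Editor's note: with pack_m = 16 and usable() requiring M % 4 == 0, remain_m
// can only be 0, 4, 8 or 12, so the 8-row branch plus the 4-row branch above
// cover every case.  Worked example (assumed shape): M = 28 -> one 16-row
// block, then remain_m = 12 -> one kern_8x12 pass and one kern_4x12 pass.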
// ===========================gemm_s8x8x16_mk4_4x4_a72================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | |||||
void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||||
k0, kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||||
k0, kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, | |||||
const dt_int16*, dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only is_first_k == true is supported"); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4"); | |||||
constexpr size_t pack_size = 4; | |||||
constexpr size_t pack_m = 4; | |||||
constexpr size_t pack_n = 4; | |||||
constexpr size_t pack_k = 8; | |||||
const size_t remain_n = N % pack_n; | |||||
const size_t nend = N - remain_n; | |||||
const size_t packed_k = round_up(K, pack_k); | |||||
for (size_t m_idx = 0; m_idx < M; m_idx += pack_m) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
const int8_t* cur_packB = packB; | |||||
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | |||||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * packed_k; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * packed_k; | |||||
} | |||||
packA += pack_m * packed_k; | |||||
} | |||||
} | |||||
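Note (illustration, not part of the patch): unlike the 16x12 kernel above, this variant advances cur_packB by pack_n * packed_k rather than pack_n * K, because K is rounded up to the k-unroll of 8. That implies pack_B pads the depth dimension — an assumption here, since the packing routine itself is not in this hunk. A tiny sketch for a hypothetical K = 10:

#include <cstddef>

// Stride, in int8 elements, between consecutive 4-column blocks of packed B.
static inline size_t packed_b_block_stride(size_t K) {
    constexpr size_t pack_n = 4, pack_k = 8;
    const size_t packed_k = (K + pack_k - 1) / pack_k * pack_k;  // round_up(K, 8)
    return pack_n * packed_k;  // K = 10 -> 4 * 16 = 64, not 4 * 10 = 40
}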
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -20,6 +21,11 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||||
gemm_s8x8x16_8x8); | gemm_s8x8x16_8x8); | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, | MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, | ||||
gemm_s8x8x16_4x4); | gemm_s8x8x16_4x4); | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, | |||||
gemm_s8x8x16_mk4_4x4_a72); | |||||
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, | |||||
16, 12, 4, false, false, | |||||
gemm_s8x8x16_mk4_16x12_a53); | |||||
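Note: the _WITH_PACK_A_TYPE variant registers an extra dtype for the packed-A buffer (dt_int16 here), which is consistent with gemm_s8x8x16_mk4_16x12_a53::kern() earlier in this diff taking `const dt_int16* packA`, whereas the plain macro keeps int8 packed A as used by the a72 strategy.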
} // namespace matmul | } // namespace matmul | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -6,10 +6,11 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||||
#include "src/aarch64/matrix_mul/algos.h" | #include "src/aarch64/matrix_mul/algos.h" | ||||
#include "src/aarch64/matrix_mul/opr_impl.h" | |||||
#include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
@@ -36,6 +37,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#endif | #endif | ||||
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | ||||
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | ||||
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; | |||||
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8; | |||||
AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; | ||||
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; | ||||
@@ -70,6 +73,8 @@ public: | |||||
#endif | #endif | ||||
all_algos.emplace_back(&int8x8x16_k4x4x16); | all_algos.emplace_back(&int8x8x16_k4x4x16); | ||||
all_algos.emplace_back(&int8x8x16_k8x8x8); | all_algos.emplace_back(&int8x8x16_k8x8x8); | ||||
all_algos.emplace_back(&int8x8x16_mk4_4x4x8); | |||||
all_algos.emplace_back(&int8x8x16_mk4_16x12x4); | |||||
all_algos.emplace_back(&int16x16x32_k12x8x1); | all_algos.emplace_back(&int16x16x32_k12x8x1); | ||||
all_algos.emplace_back(&int16x16x32_mk8_8x8); | all_algos.emplace_back(&int16x16x32_mk8_8x8); | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/arm_common/matrix_mul/opr_impl.h" | #include "src/arm_common/matrix_mul/opr_impl.h" | ||||
@@ -21,28 +22,30 @@ public: | |||||
SmallVector<AlgoBase*> algo_pack() override; | SmallVector<AlgoBase*> algo_pack() override; | ||||
private: | private: | ||||
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||||
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 | |||||
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 | |||||
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 | |||||
class AlgoF32Gemv; // Aarch64 F32 Gemv | |||||
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 | |||||
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1 | |||||
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1 | |||||
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4 | |||||
class AlgoF32Gemv; // Aarch64 F32 Gemv | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 | class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1 | ||||
class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||||
// 8x12x4 DotProduct | |||||
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | |||||
// 8x12x4 DotProduct | |||||
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | |||||
// 8x12x4 DotProduct | |||||
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | |||||
// 8x12x4 DotProduct | |||||
#else | #else | ||||
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | ||||
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||||
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||||
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | |||||
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | |||||
#endif | #endif | ||||
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||||
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||||
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | |||||
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | |||||
class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x4 | |||||
class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 Kernel 4x4x8 | |||||
class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 | ||||
class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 | ||||
@@ -52,7 +55,7 @@ private: | |||||
// 8x8x4 DotProduct | // 8x8x4 DotProduct | ||||
class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct | ||||
#else | #else | ||||
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||||
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 | |||||
#endif | #endif | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -214,7 +214,6 @@ void* const ConvBiasImpl::sm_arm_common_algo_type = | |||||
bool ConvBiasImpl::is_matmul_quantized_prefer( | bool ConvBiasImpl::is_matmul_quantized_prefer( | ||||
const ConvBiasImpl::NCBKernSizeParam& param) const { | const ConvBiasImpl::NCBKernSizeParam& param) const { | ||||
// fallback::ConvBiasImpl::NCBKernParam conv_ncb_param; | |||||
fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( | fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( | ||||
param, 0, param::MatrixMul::Format::DEFAULT, {}, 0, | param, 0, param::MatrixMul::Format::DEFAULT, {}, 0, | ||||
BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); | BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); | ||||
@@ -9,8 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#ifdef MGB_ENABLE_CPUINFO_CHECK | |||||
#include "src/common/utils.h" | |||||
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||||
#include "cpuinfo_arch_vendor.h" | #include "cpuinfo_arch_vendor.h" | ||||
@@ -11,8 +11,8 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#ifdef MGB_ENABLE_CPUINFO_CHECK | |||||
#include "src/common/utils.h" | |||||
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||||
#include <cpuinfo.h> | #include <cpuinfo.h> | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "test/aarch64/fixture.h" | #include "test/aarch64/fixture.h" | ||||
@@ -16,6 +17,7 @@ | |||||
#include "test/common/matrix_mul.h" | #include "test/common/matrix_mul.h" | ||||
#include "test/common/rng.h" | #include "test/common/rng.h" | ||||
#include "test/arm_common/cpuinfo_help.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace test; | using namespace test; | ||||
@@ -24,6 +26,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) { | |||||
dtype::Float32{}, handle(), | dtype::Float32{}, handle(), | ||||
"AARCH64_F32K8X12X1"); | "AARCH64_F32K8X12X1"); | ||||
} | } | ||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A53) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||||
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||||
dtype::Float32{}, handle(), | |||||
"AARCH64_F32K8X12X1"); | |||||
} | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A55) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||||
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||||
dtype::Float32{}, handle(), | |||||
"AARCH64_F32K8X12X1"); | |||||
} | |||||
#endif | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { | TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) { | ||||
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | ||||
@@ -36,6 +52,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) { | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | ||||
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | "AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | ||||
} | } | ||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A53) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||||
matrix_mul::check_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||||
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A55) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||||
matrix_mul::check_matrix_mul( | |||||
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), | |||||
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
#endif | |||||
TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) { | ||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
@@ -92,6 +122,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) { | |||||
std::move(args)); | std::move(args)); | ||||
} | } | ||||
TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) { | |||||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||||
handle(), "AARCH64_INT8X8X16_MK4_4X4X8", | |||||
param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16) { | |||||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||||
handle(), "AARCH64_INT8X8X16_MK4_16X12X4", | |||||
param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
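Note: with param::MatrixMul::Format::MK4 the checker feeds the blocked layouts rather than plain (M, K) x (K, N) matrices; the benchmark added further down builds exactly these shapes, e.g. for M = N = K = 64: A {16, 16, 4, 4}, B {16, 64, 4}, C {16, 64, 4}.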
TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) { | TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) { | ||||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | ||||
handle(), "AARCH64_INT8X8X32_K8X8X8"); | handle(), "AARCH64_INT8X8X32_K8X8X8"); | ||||
@@ -172,6 +214,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K4X16) { | |||||
}; | }; | ||||
run(256, 256, 128); | run(256, 256, 128); | ||||
run(384, 384, 384); | |||||
for (size_t k = 4; k <= 256; k *= 8) { | for (size_t k = 4; k <= 256; k *= 8) { | ||||
for (size_t m = 4; m <= 256; m *= 4) { | for (size_t m = 4; m <= 256; m *= 4) { | ||||
@@ -235,7 +278,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) { | |||||
int32_used / int_used); | int32_used / int_used); | ||||
}; | }; | ||||
run(256, 256, 128); | |||||
run(256, 256, 256); | |||||
for (size_t k = 4; k <= 256; k *= 8) { | for (size_t k = 4; k <= 256; k *= 8) { | ||||
for (size_t m = 4; m <= 256; m *= 4) { | for (size_t m = 4; m <= 256; m *= 4) { | ||||
@@ -297,6 +340,62 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) { | |||||
} | } | ||||
} | } | ||||
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) { | |||||
constexpr size_t RUNS = 50; | |||||
param::MatrixMul param; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
Benchmarker<MatrixMul> benchmarker(handle()); | |||||
Benchmarker<MatrixMul> benchmarker_mk4(handle()); | |||||
Benchmarker<MatrixMul> benchmarker_mk4_16x12(handle()); | |||||
benchmarker.set_times(RUNS) | |||||
.set_dtype(0, dtype::Int8{}) | |||||
.set_dtype(1, dtype::Int8{}) | |||||
.set_dtype(2, dtype::Int16{}) | |||||
.set_param(param) | |||||
.set_display(false); | |||||
benchmarker.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16")); | |||||
param.format = MatrixMul::Param::Format::MK4; | |||||
benchmarker_mk4.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8")); | |||||
benchmarker_mk4.set_times(RUNS) | |||||
.set_dtype(0, dtype::Int8{}) | |||||
.set_dtype(1, dtype::Int8{}) | |||||
.set_dtype(2, dtype::Int16{}) | |||||
.set_param(param) | |||||
.set_display(false); | |||||
benchmarker_mk4_16x12.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_16X12X4")); | |||||
benchmarker_mk4_16x12.set_times(RUNS) | |||||
.set_dtype(0, dtype::Int8{}) | |||||
.set_dtype(1, dtype::Int8{}) | |||||
.set_dtype(2, dtype::Int16{}) | |||||
.set_param(param) | |||||
.set_display(false); | |||||
auto run = [&](size_t M, size_t N, size_t K) { | |||||
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS; | |||||
auto mk_used = benchmarker_mk4.exec( | |||||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||||
RUNS; | |||||
auto mk4_16x12_used = | |||||
benchmarker_mk4_16x12.exec( | |||||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||||
RUNS; | |||||
float computations = 2.f * M * K * N * 1e-6; | |||||
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms " | |||||
"%f Gflops speedup: %f, mk4_16x12 %f Gflops speedup: %f\n", | |||||
M, K, N, default_used, computations / default_used, mk_used, | |||||
computations / mk_used, default_used / mk_used, | |||||
computations / mk4_16x12_used, default_used / mk4_16x12_used); | |||||
}; | |||||
run(384, 384, 384); | |||||
} | |||||
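Note, a worked instance of the printout above: computations = 2 * M * K * N * 1e-6, so for the single run(384, 384, 384) case it is 2 * 384^3 * 1e-6 ≈ 113.25; a kernel measured at, say, 5 ms per iteration would report 113.25 / 5 ≈ 22.6 Gflops, and each "speedup" value is simply default_used divided by the corresponding MK4 time, so values above 1 mean the MK4 kernel beats AARCH64_INT8X8X16_K4X4X16 on the same shape.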
TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | ||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::MatrixMul param; | param::MatrixMul param; | ||||
@@ -350,9 +449,11 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) { | |||||
run(256, 256, 128); | run(256, 256, 128); | ||||
for (size_t k = 4; k <= 16; k *= 2) { | |||||
for (size_t m = 4; m <= 64; m *= 2) { | |||||
for (size_t n = 4; n <= 64; n *= 2) { | |||||
run(256, 256, 256); | |||||
for (size_t k = 4; k <= 256; k *= 4) { | |||||
for (size_t m = 4; m <= 256; m *= 4) { | |||||
for (size_t n = 4; n <= 256; n *= 4) { | |||||
run(m, n, k); | run(m, n, k); | ||||
} | } | ||||
} | } | ||||
@@ -736,15 +736,21 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | |||||
} | } | ||||
#endif | #endif | ||||
#if MEGDNN_ARMV7 | |||||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { | TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { | ||||
#if MEGDNN_ARMV7 | |||||
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; | const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; | ||||
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; | const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; | ||||
printf("compare %s vs %s \n", default_algo, mk4_algo); | printf("compare %s vs %s \n", default_algo, mk4_algo); | ||||
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | ||||
dtype::Int8(), dtype::Int16()); | dtype::Int8(), dtype::Int16()); | ||||
} | |||||
#else | |||||
const char* default_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"; | |||||
const char* mk4_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"; | |||||
printf("compare %s vs %s \n", default_algo, mk4_algo); | |||||
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | |||||
dtype::Int8(), dtype::Int16()); | |||||
#endif | #endif | ||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) { | TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) { | ||||
BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44", | BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44", | ||||
@@ -14,6 +14,8 @@ | |||||
#include "test/common/benchmarker.h" | #include "test/common/benchmarker.h" | ||||
#include "test/common/conv_bias.h" | #include "test/common/conv_bias.h" | ||||
#include "test/arm_common/cpuinfo_help.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace test; | using namespace test; | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -487,11 +489,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
handle(), "S8_CHAN_WISE_STRD2_NCHW44"); | handle(), "S8_CHAN_WISE_STRD2_NCHW44"); | ||||
} | } | ||||
TEST_F(ARM_COMMON, | |||||
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { | |||||
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) { | |||||
Checker<ConvBias> checker(handle()); | Checker<ConvBias> checker(handle()); | ||||
checker.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||||
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||||
checker.set_dtype(0, dtype::Int8()); | checker.set_dtype(0, dtype::Int8()); | ||||
checker.set_dtype(1, dtype::Int8()); | checker.set_dtype(1, dtype::Int8()); | ||||
checker.set_dtype(2, dtype::Int16()); | checker.set_dtype(2, dtype::Int16()); | ||||
@@ -505,8 +506,8 @@ TEST_F(ARM_COMMON, | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) { | CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) { | ||||
Checker<ConvBias> checker(handle()); | Checker<ConvBias> checker(handle()); | ||||
checker.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||||
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>( | |||||
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44")); | |||||
checker.set_dtype(0, dtype::Int8()); | checker.set_dtype(0, dtype::Int8()); | ||||
checker.set_dtype(1, dtype::Int8()); | checker.set_dtype(1, dtype::Int8()); | ||||
checker.set_dtype(2, dtype::Int16()); | checker.set_dtype(2, dtype::Int16()); | ||||
@@ -1803,8 +1804,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2_PREPROCESS) { | |||||
handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \ | ||||
dtype::Float32(), dtype::Float32(), name); | dtype::Float32(), dtype::Float32(), name); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1") | |||||
cb("IM2COLMATMUL:AARCH64_F32K4X16X1") | |||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1") cb("IM2COLMATMUL:AARCH64_F32K4X16X1") | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
cb("IM2COLMATMUL:ARMV7_F32") | cb("IM2COLMATMUL:ARMV7_F32") | ||||
#endif | #endif | ||||
@@ -1858,6 +1858,94 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
//! CPUINFO related tests | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A55) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||||
#define cb(name,stride) \ | |||||
check_conv_bias( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ | |||||
handle(), name); | |||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) | |||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) | |||||
#undef cb | |||||
} | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A53) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||||
#define cb(name,stride) \ | |||||
check_conv_bias( \ | |||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \ | |||||
handle(), name); | |||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1) | |||||
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2) | |||||
#undef cb | |||||
} | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A55) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
{2, 3, 7}, 1, false, false, false, false, false, true, true); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
args = get_nchw44_conv_bias_args( | |||||
{2, 3, 7}, 2, false, false, false, false, false, true, true); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
} | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A53) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | |||||
{2, 3, 7}, 1, false, false, false, false, false, true, true); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
args = get_nchw44_conv_bias_args( | |||||
{2, 3, 7}, 2, false, false, false, false, false, true, true); | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | |||||
} | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A55) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55); | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||||
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||||
} | |||||
#endif | |||||
#endif | |||||
#if MEGDNN_AARCH64 | |||||
#if MGB_ENABLE_CPUINFO | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A53) { | |||||
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53); | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||||
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | |||||
} | |||||
#endif | |||||
#endif | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
@@ -2216,7 +2304,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | #define cb(name) \ | ||||
@@ -2247,7 +2336,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPRO | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
@@ -2276,19 +2364,21 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | ||||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | ||||
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||||
cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8"); | |||||
cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_16X12X4"); | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | ||||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | ||||
cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); | cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); | ||||
#endif | #endif | ||||
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||||
#undef cb | #undef cb | ||||
#undef cb_nchw44 | #undef cb_nchw44 | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | #define cb(name) \ | ||||
@@ -2311,7 +2401,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | #define cb(name) \ | ||||
@@ -2415,8 +2506,9 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args, | |||||
} | } | ||||
} | } | ||||
void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args, | |||||
Handle* handle, const char* algo_name) { | |||||
void checker_conv_bias_int8x8x32_preprocess( | |||||
std::vector<conv_bias::TestArg> args, Handle* handle, | |||||
const char* algo_name) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | ||||
@@ -2461,7 +2553,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true); | ||||
@@ -2490,7 +2583,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true); | ||||
@@ -2541,7 +2635,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, | TEST_F(ARM_COMMON_MULTI_THREADS, | ||||
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) { | CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
@@ -2678,7 +2771,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); | get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true); | ||||
@@ -2722,7 +2816,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_PREPROCESS) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | ||||
{2, 4, 7}, 1, false, false, false, false, false, true,true); | |||||
{2, 4, 7}, 1, false, false, false, false, false, true, true); | |||||
#define cb(name) \ | #define cb(name) \ | ||||
check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \ | check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \ | ||||
dtype::Float32(), dtype::Float32(), \ | dtype::Float32(), dtype::Float32(), \ | ||||
@@ -2748,7 +2842,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args( | ||||
{3}, 2, false, false, false, false, false, true, true, false); | {3}, 2, false, false, false, false, false, true, true, false); | ||||
@@ -2884,12 +2979,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) { | |||||
NormalRNG rng(1); | NormalRNG rng(1); | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | ||||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||||
"CONV1x1:AARCH64_F16_K8X24X1:48"); | |||||
dtype::Float16{}, dtype::Float16{}, | |||||
dtype::Float16{}, | |||||
"CONV1x1:AARCH64_F16_K8X24X1:48"); | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{}, | ||||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||||
"CONV1x1:AARCH32_F16_K4X16X1:24"); | |||||
dtype::Float16{}, dtype::Float16{}, | |||||
dtype::Float16{}, | |||||
"CONV1x1:AARCH32_F16_K4X16X1:24"); | |||||
#endif | #endif | ||||
} | } | ||||
@@ -2951,7 +3048,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
@@ -3074,7 +3170,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) { | |||||
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | ||||
#endif | #endif | ||||
#undef cb | #undef cb | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | ||||
@@ -3095,6 +3190,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | ||||
cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_4X4X8:48"); | |||||
cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_16X12X4:48"); | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | ||||
@@ -3128,11 +3225,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) { | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | ||||
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test | |||||
cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | ||||
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test | |||||
cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test | |||||
#endif | #endif | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -3245,11 +3342,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
#define cb(name) \ | |||||
check_conv_bias_preprocess(get_nchw44_conv_bias_args({1}, 1, true, false, false), \ | |||||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||||
dtype::QuantizedS8(60.25f), name); | |||||
#define cb(name) \ | |||||
check_conv_bias_preprocess( \ | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false), handle(), \ | |||||
&rng, epsilon, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), \ | |||||
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), name); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); | cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24"); | ||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
@@ -9,7 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#ifdef MGB_ENABLE_CPUINFO_CHECK | |||||
#include "src/common/utils.h" | |||||
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||||
#include <cpuinfo.h> | #include <cpuinfo.h> | ||||
#include <inttypes.h> | #include <inttypes.h> | ||||
#include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
@@ -18,7 +19,6 @@ namespace megdnn { | |||||
namespace test { | namespace test { | ||||
TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { | TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { | ||||
ASSERT_TRUE(cpuinfo_initialize()); | ASSERT_TRUE(cpuinfo_initialize()); | ||||
int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980"); | int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980"); | ||||
@@ -68,7 +68,6 @@ TEST(ARM_RUNTIME, CPUINFO_KIRIN980) { | |||||
} | } | ||||
TEST(ARM_RUNTIME, CPUINFO_SDM8150) { | TEST(ARM_RUNTIME, CPUINFO_SDM8150) { | ||||
ASSERT_TRUE(cpuinfo_initialize()); | ASSERT_TRUE(cpuinfo_initialize()); | ||||
int right_soc = | int right_soc = | ||||
@@ -119,7 +118,6 @@ TEST(ARM_RUNTIME, CPUINFO_SDM8150) { | |||||
} | } | ||||
TEST(ARM_RUNTIME, CPUINFO_SDM660) { | TEST(ARM_RUNTIME, CPUINFO_SDM660) { | ||||
ASSERT_TRUE(cpuinfo_initialize()); | ASSERT_TRUE(cpuinfo_initialize()); | ||||
int right_soc = | int right_soc = | ||||
@@ -173,4 +171,3 @@ TEST(ARM_RUNTIME, CPUINFO_SDM660) { | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -0,0 +1,17 @@ | |||||
/** | |||||
* \file dnn/test/arm_common/cpuinfo_help.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/common/utils.h" | |||||
#include "test/arm_common/cpuinfo_help.h" | |||||
#if MGB_ENABLE_CPUINFO | |||||
std::mutex CpuInfoTmpReplace::m_cpuinfo_lock; | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,47 @@ | |||||
/** | |||||
* \file dnn/test/arm_common/cpuinfo_help.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <mutex> | |||||
#include <vector> | |||||
#include "src/common/utils.h" | |||||
#if MGB_ENABLE_CPUINFO | |||||
#include "cpuinfo.h" | |||||
extern const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map; | |||||
class CpuInfoTmpReplace { | |||||
public: | |||||
CpuInfoTmpReplace(enum cpuinfo_uarch arch) { | |||||
m_cpuinfo_lock.lock(); | |||||
for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { | |||||
m_arch_bak_vec.push_back(cpuinfo_linux_cpu_to_core_map[i]->uarch); | |||||
((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i]->uarch = | |||||
arch; | |||||
} | |||||
} | |||||
~CpuInfoTmpReplace() { | |||||
if (m_arch_bak_vec.size() > 0) { | |||||
for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) { | |||||
((struct cpuinfo_core**)cpuinfo_linux_cpu_to_core_map)[i] | |||||
->uarch = m_arch_bak_vec[i]; | |||||
} | |||||
} | |||||
m_cpuinfo_lock.unlock(); | |||||
} | |||||
private: | |||||
static std::mutex m_cpuinfo_lock; | |||||
std::vector<cpuinfo_uarch> m_arch_bak_vec; | |||||
}; | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
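Note: a hedged usage sketch of the new guard; the test name below is hypothetical, while the pattern matches the A53/A55 matmul tests earlier in this diff. The static mutex serializes tests that patch the process-wide cpuinfo core table, and the destructor restores the saved uarch values when the guard leaves scope.

#if MGB_ENABLE_CPUINFO
TEST_F(AARCH64, SOME_A53_SPECIFIC_CHECK) {  // hypothetical test
    CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
    // ... any checker run here sees cortex-a53 cores, so uarch-gated
    //     algorithms (e.g. the new preferred() hooks) take their A53 path ...
}  // guard restores the real uarch values and releases the lock here
#endif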
@@ -9,7 +9,8 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
* implied. | * implied. | ||||
*/ | */ | ||||
#ifdef MGB_ENABLE_CPUINFO_CHECK | |||||
#include "src/common/utils.h" | |||||
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO | |||||
#include <cpuinfo.h> | #include <cpuinfo.h> | ||||
#include <inttypes.h> | #include <inttypes.h> | ||||
#include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
@@ -18,14 +19,12 @@ namespace megdnn { | |||||
namespace test { | namespace test { | ||||
TEST(X86_RUNTIME, CPUINFO_XEON6130) { | TEST(X86_RUNTIME, CPUINFO_XEON6130) { | ||||
ASSERT_TRUE(cpuinfo_initialize()); | ASSERT_TRUE(cpuinfo_initialize()); | ||||
int right_cpu = | int right_cpu = | ||||
strcmp(cpuinfo_get_package(0)->name, "Intel Xeon Gold 6130"); | strcmp(cpuinfo_get_package(0)->name, "Intel Xeon Gold 6130"); | ||||
if (!right_cpu) { | if (!right_cpu) { | ||||
ASSERT_TRUE(cpuinfo_get_processors()); | ASSERT_TRUE(cpuinfo_get_processors()); | ||||
ASSERT_TRUE(cpuinfo_has_x86_avx2()); | ASSERT_TRUE(cpuinfo_has_x86_avx2()); | ||||
@@ -44,4 +43,3 @@ TEST(X86_RUNTIME, CPUINFO_XEON6130) { | |||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||