GitOrigin-RevId: d2f8290a8d
tags/v1.0.0-rc1
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/aarch64/matrix_mul/algos.h" | #include "src/aarch64/matrix_mul/algos.h" | ||||
@@ -733,7 +734,9 @@ void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable( | bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable( | ||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
return can_be_treated_as_int8x8x16(kern_size_param); | |||||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; | |||||
} | } | ||||
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( | bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( | ||||
@@ -796,7 +799,9 @@ void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable( | bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable( | ||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
return can_be_treated_as_int8x8x16(kern_size_param); | |||||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; | |||||
} | } | ||||
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( | bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/armv7/matrix_mul/algos.h" | #include "src/armv7/matrix_mul/algos.h" | ||||
@@ -526,6 +527,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, | |||||
"AlgoInt8x8x16K4x8x8"_hash, | "AlgoInt8x8x16K4x8x8"_hash, | ||||
armv7::matmul::gemm_s8x8x16_4x8, int8_t, | armv7::matmul::gemm_s8x8x16_4x8, int8_t, | ||||
int16_t); | int16_t); | ||||
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ | |||||
namespace { | |||||
//! Kernel entry for the int8x8x16 MK4 8x8x4 matmul: unpacks the kernel
//! parameters and drives GemmInterleaved with the mk4_8x8 strategy.
void kern_int8x8x16_mk4_k8x8x4(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_armv7_matmul_kern,
                 midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) {
        const auto m = kern_param.M;
        const auto n = kern_param.N;
        const auto k = kern_param.K;
        auto a_ptr = kern_param.A<dt_int8>();
        auto b_ptr = kern_param.B<dt_int8>();
        auto c_ptr = kern_param.C<dt_int16>();
        //! the strategy carries operand dtypes; sizes select the blocking
        armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy(
                m, n, k, kern_param.A_type, kern_param.B_type,
                kern_param.C_type);
        megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_mk4_8x8>(
                m, n, k, kern_param.trA, kern_param.trB, strategy)
                .execute(a_ptr, kern_param.LDA, b_ptr, kern_param.LDB, c_ptr,
                         kern_param.LDC, kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
bool type_ok = can_be_treated_as_int8x8x16(kern_size_param); | |||||
return type_ok && kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||||
} | |||||
size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | |||||
midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | |||||
K = kern_size_param.K; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, A_type, B_type, C_type); | |||||
return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_8x8>( | |||||
M, N, K, trA, trB, strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
//! The same kernel serves every admissible shape; no dispatch on size.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern(
        const KernSizeParam&) const {
    MatrixMulImpl::kern_t kern = kern_int8x8x16_mk4_k8x8x4;
    return kern;
}
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return kern_size_param.K >= 4; | |||||
} | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, | |||||
megdnn_armv7_matmul_kern, | |||||
"AlgoInt8x8x16MK4_8x8x4"_hash, | |||||
armv7::matmul::gemm_s8x8x16_mk4_8x8, | |||||
int8_t, int16_t, int16_t); | |||||
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | /* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -937,11 +1006,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( | |||||
Bptr = kern_param.B<dt_float16>(); | Bptr = kern_param.B<dt_float16>(); | ||||
auto Cptr = kern_param.C<dt_float16>(); | auto Cptr = kern_param.C<dt_float16>(); | ||||
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, | |||||
C_type); | |||||
megdnn::matmul::GemmInterleaved< | |||||
armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, | |||||
trB, strategy) | |||||
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); | |||||
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_nopack_f16_4x8, | |||||
false>(M, N, K, trA, trB, strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | ||||
kern_param.workspace_ptr); | kern_param.workspace_ptr); | ||||
} | } | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
@@ -171,6 +172,18 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
//! int8x8x16 matmul on the MK4 layout with 8x8x4 blocking.
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
    const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
    bool is_reproducible() const override { return true; }
    void* type() const override { return sm_arm_common_algo_type; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
@@ -6,13 +6,15 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include <arm_neon.h> | #include <arm_neon.h> | ||||
#include <cmath> | #include <cmath> | ||||
#include <cstdint> | #include <cstdint> | ||||
#include <type_traits> | #include <type_traits> | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
@@ -172,7 +174,6 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, | |||||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) | [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) | ||||
: | : | ||||
: "q0", "q1", "q2", "q3", "memory"); | : "q0", "q1", "q2", "q3", "memory"); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
@@ -183,12 +184,12 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, | |||||
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | ||||
"interleave_4x4_4_b only support uint8_t and int8_t"); | "interleave_4x4_4_b only support uint8_t and int8_t"); | ||||
asm volatile( | asm volatile( | ||||
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 | |||||
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 | |||||
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 | |||||
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 | |||||
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 | |||||
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 | |||||
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 | |||||
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 | |||||
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 | |||||
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 | |||||
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 | |||||
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 | |||||
"vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 | "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 | ||||
"vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 | "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 | ||||
"vst1.32 {d0-d1},[%[outptr]]!\n" | "vst1.32 {d0-d1},[%[outptr]]!\n" | ||||
@@ -323,10 +324,10 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, | |||||
"vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17 | "vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17 | ||||
"vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37 | "vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37 | ||||
"vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57 | "vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57 | ||||
"vst1.32 {d0-d1}, [%[outptr]]! \n" | |||||
"vst1.32 {d16}, [%[outptr]]! \n" | |||||
"vst1.32 {d0-d1}, [%[outptr]]! \n" | |||||
"vst1.32 {d16}, [%[outptr]]! \n" | |||||
"vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33 | "vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33 | ||||
"vst1.32 {d8-d9}, [%[outptr]]! \n" | |||||
"vst1.32 {d8-d9}, [%[outptr]]! \n" | |||||
"vst1.32 {d17}, [%[outptr]]! \n" | "vst1.32 {d17}, [%[outptr]]! \n" | ||||
"vst1.32 {d2-d3}, [%[outptr]]!\n" | "vst1.32 {d2-d3}, [%[outptr]]!\n" | ||||
"vst1.32 {d18}, [%[outptr]]!\n" | "vst1.32 {d18}, [%[outptr]]!\n" | ||||
@@ -810,15 +811,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, | |||||
"interleave_12x4_1_h only support uint16_t and int16_t"); | "interleave_12x4_1_h only support uint16_t and int16_t"); | ||||
auto ldin_asm = ldin << 1; | auto ldin_asm = ldin << 1; | ||||
asm volatile( | asm volatile( | ||||
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 | |||||
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 | |||||
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 | |||||
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 | |||||
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 | |||||
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 | |||||
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 | |||||
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 | |||||
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 | |||||
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 | |||||
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 | |||||
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 | |||||
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 | |||||
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 | |||||
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 | |||||
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 | |||||
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 | |||||
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 | |||||
"vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3 | "vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3 | ||||
"add %[inptr9], %[inptr9], %[ldin_asm]\n" | "add %[inptr9], %[inptr9], %[ldin_asm]\n" | ||||
"vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3 | "vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3 | ||||
@@ -854,17 +855,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, | |||||
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | ||||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), | [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), | ||||
[inptr9] "+r"(inptr9), [outptr] "+r"(outptr) | [inptr9] "+r"(inptr9), [outptr] "+r"(outptr) | ||||
:[ldin_asm] "r"(ldin_asm) | |||||
: [ldin_asm] "r"(ldin_asm) | |||||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | ||||
"d11", "memory"); | "d11", "memory"); | ||||
inptr9 -= ldin_asm; | |||||
inptr9 += 4; | |||||
inptr9 -= ldin_asm; | |||||
inptr9 += 4; | |||||
inptr10 += 4; | inptr10 += 4; | ||||
inptr11 += 4; | inptr11 += 4; | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, | static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, | const T*& inptr2, const T*& inptr3, | ||||
@@ -1038,7 +1037,7 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, | |||||
"vst1.32 {d7}, [%[outptr]], %[stride]\n" | "vst1.32 {d7}, [%[outptr]], %[stride]\n" | ||||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | ||||
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | ||||
[outptr] "+r"(outptr), [stride] "+r" (stride) | |||||
[outptr] "+r"(outptr), [stride] "+r"(stride) | |||||
: | : | ||||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); | : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); | ||||
} | } | ||||
@@ -1069,7 +1068,6 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1, | |||||
: "d0", "d1", "d2", "d3", "memory"); | : "d0", "d1", "d2", "d3", "memory"); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, | static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, | const T*& inptr2, const T*& inptr3, | ||||
@@ -1082,9 +1080,9 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, | |||||
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | ||||
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | ||||
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | ||||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||||
"add %[inptr0],%[inptr0],#6 \n" | "add %[inptr0],%[inptr0],#6 \n" | ||||
"add %[inptr1],%[inptr1],#6 \n" | "add %[inptr1],%[inptr1],#6 \n" | ||||
"add %[inptr2],%[inptr2],#6 \n" | "add %[inptr2],%[inptr2],#6 \n" | ||||
@@ -1121,9 +1119,9 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, | |||||
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | "vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | ||||
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | "vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | ||||
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | "vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | ||||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||||
"add %[inptr0],%[inptr0],#4 \n" | "add %[inptr0],%[inptr0],#4 \n" | ||||
"add %[inptr1],%[inptr1],#4 \n" | "add %[inptr1],%[inptr1],#4 \n" | ||||
"add %[inptr2],%[inptr2],#4 \n" | "add %[inptr2],%[inptr2],#4 \n" | ||||
@@ -1176,7 +1174,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||||
"vst1.32 {d6-d7}, [%[outptr]]! \n" | "vst1.32 {d6-d7}, [%[outptr]]! \n" | ||||
"vst1.32 {d14-d15}, [%[outptr]]! \n" | "vst1.32 {d14-d15}, [%[outptr]]! \n" | ||||
"vst1.32 {d22-d23}, [%[outptr]]! \n" | "vst1.32 {d22-d23}, [%[outptr]]! \n" | ||||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", | : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", | ||||
"q11", "memory"); | "q11", "memory"); | ||||
@@ -1195,12 +1193,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||||
"vst1.32 {d4-d5}, [%[outptr]]! \n" | "vst1.32 {d4-d5}, [%[outptr]]! \n" | ||||
"vst1.32 {d2-d3}, [%[outptr]]! \n" | "vst1.32 {d2-d3}, [%[outptr]]! \n" | ||||
"vst1.32 {d6-d7}, [%[outptr]]! \n" | "vst1.32 {d6-d7}, [%[outptr]]! \n" | ||||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||||
: | : | ||||
: "q0", "q1", "q2", "q3", "memory"); | : "q0", "q1", "q2", "q3", "memory"); | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
static inline void transpose_4(const T*& inptr0, const T*& inptr1, | static inline void transpose_4(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, T* outptr, | const T*& inptr2, const T*& inptr3, T* outptr, | ||||
@@ -1251,7 +1248,6 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1, | |||||
} | } | ||||
} | } | ||||
template <typename T> | template <typename T> | ||||
static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, | static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, | const T*& inptr2, const T*& inptr3, | ||||
@@ -1375,7 +1371,68 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, | |||||
: "q0", "q1", "q2", "q3", "memory"); | : "q0", "q1", "q2", "q3", "memory"); | ||||
} | } | ||||
} // armv7 | |||||
//! Interleave two rows of 16 int8 values each into int16 output,
//! alternating 4-element groups: out = r0[0:4] r1[0:4] r0[4:8] r1[4:8] ...
//! Used to pack MK4 data while widening s8 -> s16 for the 8x8x4 kernel.
//! Fix: removed the stray ';' after the closing brace (empty declaration,
//! rejected under -pedantic and warned by -Wextra-semi).
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0,
                                             const int8_t* inptr1,
                                             int16_t* outptr) {
    int8x16_t row0 = vld1q_s8(inptr0);
    int16x8_t row0_01 = vmovl_low_s8(row0);   //! widen lanes 0-7
    int16x8_t row0_23 = vmovl_high_s8(row0);  //! widen lanes 8-15
    int16x4_t row0_0 = vget_low_s16(row0_01);
    int16x4_t row0_1 = vget_high_s16(row0_01);
    int16x4_t row0_2 = vget_low_s16(row0_23);
    int16x4_t row0_3 = vget_high_s16(row0_23);

    int8x16_t row1 = vld1q_s8(inptr1);
    int16x8_t row1_01 = vmovl_low_s8(row1);
    int16x8_t row1_23 = vmovl_high_s8(row1);
    int16x4_t row1_0 = vget_low_s16(row1_01);
    int16x4_t row1_1 = vget_high_s16(row1_01);
    int16x4_t row1_2 = vget_low_s16(row1_23);
    int16x4_t row1_3 = vget_high_s16(row1_23);

    vst1_s16(outptr, row0_0);
    vst1_s16(outptr + 1 * 4, row1_0);
    vst1_s16(outptr + 2 * 4, row0_1);
    vst1_s16(outptr + 3 * 4, row1_1);
    vst1_s16(outptr + 4 * 4, row0_2);
    vst1_s16(outptr + 5 * 4, row1_2);
    vst1_s16(outptr + 6 * 4, row0_3);
    vst1_s16(outptr + 7 * 4, row1_3);
}
//! De-interleave an 8x4 tile of int8: vld4 splits the 32 input bytes into
//! four 8-lane vectors which are stored back contiguously (transpose
//! effect). Name keeps the historical "transpos" spelling used by callers.
static inline void transpos_8x4_int8(const int8_t* inptr0, int8_t* outptr) {
    const int8x8x4_t cols = vld4_s8(inptr0);
    vst1_s8(outptr + 0 * 8, cols.val[0]);
    vst1_s8(outptr + 1 * 8, cols.val[1]);
    vst1_s8(outptr + 2 * 8, cols.val[2]);
    vst1_s8(outptr + 3 * 8, cols.val[3]);
}
//! Copy `count` int8 values to int16 with sign extension (vmovl.s8).
//! Works through the data 32 elements at a time, then 8, then a scalar
//! tail for the remainder.
static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr,
                                 int count) {
    while (count >= 32) {
        vst1q_s16(outptr + 0 * 8, vmovl_s8(vld1_s8(inptr + 0 * 8)));
        vst1q_s16(outptr + 1 * 8, vmovl_s8(vld1_s8(inptr + 1 * 8)));
        vst1q_s16(outptr + 2 * 8, vmovl_s8(vld1_s8(inptr + 2 * 8)));
        vst1q_s16(outptr + 3 * 8, vmovl_s8(vld1_s8(inptr + 3 * 8)));
        inptr += 32;
        outptr += 32;
        count -= 32;
    }
    while (count >= 8) {
        vst1q_s16(outptr, vmovl_s8(vld1_s8(inptr)));
        inptr += 8;
        outptr += 8;
        count -= 8;
    }
    while (count > 0) {
        *outptr++ = static_cast<int16_t>(*inptr++);
        --count;
    }
}
} // namespace armv7 | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -102,60 +102,60 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
"vld1.8 {d2}, [%[a_ptr]]!\n" | "vld1.8 {d2}, [%[a_ptr]]!\n" | ||||
"vld1.8 {d4}, [%[a_ptr]]!\n" | "vld1.8 {d4}, [%[a_ptr]]!\n" | ||||
"vld1.8 {d6}, [%[a_ptr]]!\n" | "vld1.8 {d6}, [%[a_ptr]]!\n" | ||||
"vld1.8 {d18}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q8, d16\n" | "vmovl.s8 q8, d16\n" | ||||
"vmovl.s8 q0, d0\n" | "vmovl.s8 q0, d0\n" | ||||
"vmovl.s8 q1, d2\n" | "vmovl.s8 q1, d2\n" | ||||
"vmovl.s8 q2, d4\n" | "vmovl.s8 q2, d4\n" | ||||
"vmovl.s8 q3, d6\n" | "vmovl.s8 q3, d6\n" | ||||
"vld1.8 {d18}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q9, d18\n" | |||||
"vld1.8 {d20}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q8, d0[0]\n" | "vmla.s16 q4, q8, d0[0]\n" | ||||
"vmla.s16 q5, q8, d2[0]\n" | "vmla.s16 q5, q8, d2[0]\n" | ||||
"vmla.s16 q6, q8, d4[0]\n" | "vmla.s16 q6, q8, d4[0]\n" | ||||
"vmla.s16 q7, q8, d6[0]\n" | "vmla.s16 q7, q8, d6[0]\n" | ||||
"vmovl.s8 q9, d18\n" | |||||
"vld1.8 {d20}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q10, d20\n" | |||||
"vld1.8 {d22}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q9, d0[1]\n" | "vmla.s16 q4, q9, d0[1]\n" | ||||
"vmla.s16 q5, q9, d2[1]\n" | "vmla.s16 q5, q9, d2[1]\n" | ||||
"vmla.s16 q6, q9, d4[1]\n" | "vmla.s16 q6, q9, d4[1]\n" | ||||
"vmla.s16 q7, q9, d6[1]\n" | "vmla.s16 q7, q9, d6[1]\n" | ||||
"vmovl.s8 q10, d20\n" | |||||
"vld1.8 {d22}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q11, d22\n" | |||||
"vld1.8 {d24}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q10, d0[2]\n" | "vmla.s16 q4, q10, d0[2]\n" | ||||
"vmla.s16 q5, q10, d2[2]\n" | "vmla.s16 q5, q10, d2[2]\n" | ||||
"vmla.s16 q6, q10, d4[2]\n" | "vmla.s16 q6, q10, d4[2]\n" | ||||
"vmla.s16 q7, q10, d6[2]\n" | "vmla.s16 q7, q10, d6[2]\n" | ||||
"vmovl.s8 q11, d22\n" | |||||
"vld1.8 {d24}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q12, d24\n" | |||||
"vld1.8 {d26}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q11, d0[3]\n" | "vmla.s16 q4, q11, d0[3]\n" | ||||
"vmla.s16 q5, q11, d2[3]\n" | "vmla.s16 q5, q11, d2[3]\n" | ||||
"vmla.s16 q6, q11, d4[3]\n" | "vmla.s16 q6, q11, d4[3]\n" | ||||
"vmla.s16 q7, q11, d6[3]\n" | "vmla.s16 q7, q11, d6[3]\n" | ||||
"vmovl.s8 q12, d24\n" | |||||
"vld1.8 {d26}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q13, d26\n" | |||||
"vld1.8 {d28}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q12, d1[0]\n" | "vmla.s16 q4, q12, d1[0]\n" | ||||
"vmla.s16 q5, q12, d3[0]\n" | "vmla.s16 q5, q12, d3[0]\n" | ||||
"vmla.s16 q6, q12, d5[0]\n" | "vmla.s16 q6, q12, d5[0]\n" | ||||
"vmla.s16 q7, q12, d7[0]\n" | "vmla.s16 q7, q12, d7[0]\n" | ||||
"vmovl.s8 q13, d26\n" | |||||
"vld1.8 {d28}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q14, d28\n" | |||||
"vld1.8 {d30}, [%[b_ptr]]!\n" | |||||
"vmla.s16 q4, q13, d1[1]\n" | "vmla.s16 q4, q13, d1[1]\n" | ||||
"vmla.s16 q5, q13, d3[1]\n" | "vmla.s16 q5, q13, d3[1]\n" | ||||
"vmla.s16 q6, q13, d5[1]\n" | "vmla.s16 q6, q13, d5[1]\n" | ||||
"vmla.s16 q7, q13, d7[1]\n" | "vmla.s16 q7, q13, d7[1]\n" | ||||
"vmovl.s8 q14, d28\n" | |||||
"vld1.8 {d30}, [%[b_ptr]]!\n" | |||||
"vmovl.s8 q15, d30\n" | |||||
"vmla.s16 q4, q14, d1[2]\n" | "vmla.s16 q4, q14, d1[2]\n" | ||||
"vmla.s16 q5, q14, d3[2]\n" | "vmla.s16 q5, q14, d3[2]\n" | ||||
"vmla.s16 q6, q14, d5[2]\n" | "vmla.s16 q6, q14, d5[2]\n" | ||||
"vmla.s16 q7, q14, d7[2]\n" | "vmla.s16 q7, q14, d7[2]\n" | ||||
"vmovl.s8 q15, d30\n" | |||||
"vmla.s16 q4, q15, d1[3]\n" | "vmla.s16 q4, q15, d1[3]\n" | ||||
"vmla.s16 q5, q15, d3[3]\n" | "vmla.s16 q5, q15, d3[3]\n" | ||||
@@ -0,0 +1,406 @@ | |||||
/** | |||||
* \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/armv7/matrix_mul/asm/common.h" | |||||
namespace megdnn { | |||||
namespace armv7 { | |||||
namespace matmul_mk4_8x8x4 { | |||||
//! optimize for A7 | |||||
/** | |||||
* Overview of register layout: | |||||
* | |||||
* A 8x8x8 cell of Lhs is stored in 16bit in q0, q1 | |||||
* A 8x8x8 cell of Rhs is stored in 8bit in q2, q3 | |||||
* A 8x8 block of accumulators is stored in 16bit in q8-q15 | |||||
* | |||||
* +--------+ | |||||
* | q4[0-8]| | |||||
* Rhs +--------+ | |||||
* Lhs | | | |||||
* | |||||
* +--------+ - - - - +--------- | |||||
* |q0[0]| | q8 [0-8]| | |||||
* |q0[1]| | q9 [0-8]| | |||||
* |q0[2]| | q10[0-8]| | |||||
* |q0[3]| | q11[0-8]| | |||||
* |q0[4]| | q12[0-8]| | |||||
* |q0[5]| | q13[0-8]| | |||||
* |q0[6]| | q14[0-8]| | |||||
* |q0[7]| | q15[0-8]| | |||||
* +--------+ - - - - +--------- | |||||
* | |||||
* Accumulator | |||||
*/ | |||||
/**
 * int8x8x16 MK4 matmul kernel: one 8-row x 8-column output tile, with the
 * reduction unrolled by 4 (the packer guarantees K % 4 == 0).
 *
 * packA holds the LHS already widened to int16 (two q-registers per
 * k-step); packB holds raw int8 RHS that is widened in flight with
 * vmovl.s8. Accumulators live in q8-q15 as int16. remain_n == 8 takes a
 * contiguous fast-path store; smaller remain_n stores column by column.
 */
static void kern_8x8(const int16_t* packA, const int8_t* packB, int K,
                     int16_t* output, int LDC, bool is_first_k, int remain_n) {
    K /= 4;  //! loop count: 4 k-steps consumed per iteration
    const int16_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    LDC = LDC * sizeof(int16_t);  //! row stride in bytes for the asm
    int x0 = 0;                   //! column countdown for the tail store

    // clang-format off
//! store one column (a d-register to each output row pointer),
//! decrementing the remain_n countdown and bailing out at zero
#define STORE_ROW(reg_index1, reg_index2)  \
    "cmp %[x0], #0 \n"                     \
    "beq 101f\n"                           \
    "vst1.16 {d" reg_index1 "}, [r0]!\n"   \
    "vst1.16 {d" reg_index2 "}, [r1]!\n"   \
    "subs %[x0], %[x0], #1\n"

#define STORE_RESULT             \
    "mov %[x0], %[remain_n]\n"   \
    STORE_ROW("16", "17")        \
    STORE_ROW("18", "19")        \
    STORE_ROW("20", "21")        \
    STORE_ROW("22", "23")        \
    STORE_ROW("24", "25")        \
    STORE_ROW("26", "27")        \
    STORE_ROW("28", "29")        \
    STORE_ROW("30", "31")        \
    "101:\n"
    // clang-format on

    register int16_t* outptr asm("r0") = output;
    asm volatile(
            //! r1 = second output row pointer (r0 + LDC bytes)
            "add r1, r0, %[LDC]\n"
            //! zero the accumulators only on the first k-slice
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q8 , q8 , q8 \n"
            "veor.s32 q9 , q9 , q9 \n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            //! main loop over K/4 iterations; loads are interleaved with
            //! the multiply-accumulates to hide latency
            "2: \n"
            "vld1.8 {d4}, [%[b_ptr]]!\n"
            "vld1.16 {d0, d1}, [%[a_ptr]]!\n"
            "vmovl.s8 q2, d4\n"
            "vld1.16 {d2, d3}, [%[a_ptr]]!\n"
            "vld1.8 {d6}, [%[b_ptr]]!\n"

            //! k0
            "vmla.s16 q8, q0, d4[0]\n"
            "vmla.s16 q9, q0, d4[1]\n"
            "vmla.s16 q10, q0, d4[2]\n"
            "vmla.s16 q11, q0, d4[3]\n"
            "vmovl.s8 q3, d6\n"
            "vmla.s16 q12, q0, d5[0]\n"
            "vmla.s16 q13, q0, d5[1]\n"
            "vmla.s16 q14, q0, d5[2]\n"
            "vmla.s16 q15, q0, d5[3]\n"

            //! k1
            "vld1.16 {d0, d1}, [%[a_ptr]]!\n"
            "vld1.8 {d4}, [%[b_ptr]]!\n"
            "vmla.s16 q8, q1, d6[0]\n"
            "vmla.s16 q9, q1, d6[1]\n"
            "vmla.s16 q10, q1, d6[2]\n"
            "vmla.s16 q11, q1, d6[3]\n"
            "vmovl.s8 q2, d4\n"
            "vmla.s16 q12, q1, d7[0]\n"
            "vmla.s16 q13, q1, d7[1]\n"
            "vmla.s16 q14, q1, d7[2]\n"
            "vmla.s16 q15, q1, d7[3]\n"

            //! k2
            "vld1.16 {d2, d3}, [%[a_ptr]]!\n"
            "vld1.8 {d6}, [%[b_ptr]]!\n"
            "vmla.s16 q8, q0, d4[0]\n"
            "vmla.s16 q9, q0, d4[1]\n"
            "vmla.s16 q10, q0, d4[2]\n"
            "vmla.s16 q11, q0, d4[3]\n"
            "vmovl.s8 q3, d6\n"
            "vmla.s16 q12, q0, d5[0]\n"
            "vmla.s16 q13, q0, d5[1]\n"
            "vmla.s16 q14, q0, d5[2]\n"
            "vmla.s16 q15, q0, d5[3]\n"

            //! k3
            "vmla.s16 q8, q1, d6[0]\n"
            "vmla.s16 q9, q1, d6[1]\n"
            "vmla.s16 q10, q1, d6[2]\n"
            "vmla.s16 q11, q1, d6[3]\n"
            "vmla.s16 q12, q1, d7[0]\n"
            "vmla.s16 q13, q1, d7[1]\n"
            "vmla.s16 q14, q1, d7[2]\n"
            "vmla.s16 q15, q1, d7[3]\n"

            "subs %[K], %[K], #1\n"
            "bne 2b\n"

            //! full-tile fast path: contiguous stores to both output rows
            "3:\n"
            "cmp %[remain_n], #8\n"
            "bne 4f\n"
            "vstr d16, [r0]\n"
            "vstr d18, [r0, #8]\n"
            "vstr d20, [r0, #16]\n"
            "vstr d22, [r0, #24]\n"
            "vstr d24, [r0, #32]\n"
            "vstr d26, [r0, #40]\n"
            "vstr d28, [r0, #48]\n"
            "vstr d30, [r0, #56]\n"
            "vstr d17, [r1]\n"
            "vstr d19, [r1, #8]\n"
            "vstr d21, [r1, #16]\n"
            "vstr d23, [r1, #24]\n"
            "vstr d25, [r1, #32]\n"
            "vstr d27, [r1, #40]\n"
            "vstr d29, [r1, #48]\n"
            "vstr d31, [r1, #56]\n"
            "b 101f\n"

            //! tail path: store only remain_n columns
            "4:\n " STORE_RESULT
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
              [outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
            :
            : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
              "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18",
              "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27",
              "d28", "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory");

#undef STORE_RESULT
#undef STORE_ROW
}
/**
 * Overview of register layout (kern_4x8):
 *
 * A 4x4 cell of Lhs (already widened to 16bit at pack time) is held in
 * d0, with d2 holding the next k step.
 * An 8x4 cell of Rhs is loaded as 8bit into d4/d6 and widened to 16bit
 * in q2, q3.
 * A 4x8 block of accumulators is held in 16bit in the even d registers
 * d16, d18, ..., d30 — one d register per output column of 4 channels.
 *
 *                 Rhs +---------------------+
 *                     | d4[0-3]   d5[0-3]   |
 *                     +---------------------+
 *    Lhs
 *  +-------+ - - - -  +----------------------+
 *  | d0[0] |          | d16[0-3] .. d30[0-3] |
 *  | d0[1] |          |                      |
 *  | d0[2] |          |                      |
 *  | d0[3] |          |                      |
 *  +-------+ - - - -  +----------------------+
 *
 *                          Accumulator
 */
//! Micro kernel of the MK4 int8x8x16 strategy: computes a 4(M) x 8(N)
//! int16 tile of C from a packed A panel (int16, widened from int8 at
//! pack time) and a packed B panel (int8, widened in-register).
//! @param packA      packed A panel; one group of 4 int16 channel values
//!                   per k step
//! @param packB      packed B panel; 8 consecutive int8 N values per k step
//! @param K          depth; the inner loop consumes 4 k steps per iteration
//! @param output     C tile base in MK4 layout (4 channels interleaved)
//! @param LDC        C row-block stride, in elements on entry
//! @param is_first_k when true the accumulators are zeroed before the loop;
//!                   the strategy asserts is_first_k == true — the non-first
//!                   path does NOT reload C from memory
//! @param remain_n   number of valid output columns; 8 selects the
//!                   contiguous fast store path
static void kern_4x8(const int16_t* packA, const int8_t* packB, int K,
                     int16_t* output, int LDC, bool is_first_k, int remain_n) {
    //! the unrolled loop body handles 4 k steps per iteration
    K /= 4;
    const int16_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    //! convert the element stride to a byte stride for address arithmetic
    LDC = LDC * sizeof(int16_t);
    int x0 = 0;
    // clang-format off
//! store one d register (4 int16 = one MK4 output column) and stop as soon
//! as x0 (the remaining column count) reaches zero
#define STORE_LINE(reg_index1)           \
    "cmp %[x0], #0 \n"                   \
    "beq 101f\n"                         \
    "vst1.16 {d" reg_index1 "}, [r0]!\n" \
    "subs %[x0], %[x0], #1\n"
//! store the first remain_n accumulator columns, one d register each
#define STORE_C                    \
    "mov %[x0], %[remain_n]\n"     \
    STORE_LINE("16")               \
    STORE_LINE("18")               \
    STORE_LINE("20")               \
    STORE_LINE("22")               \
    STORE_LINE("24")               \
    STORE_LINE("26")               \
    STORE_LINE("28")               \
    STORE_LINE("30")               \
    "101:\n"
    // clang-format on
    //! pin the output pointer in r0 so the store macros can post-increment it
    register int16_t* outptr asm("r0") = output;
    asm volatile(
            //! load accumulator C
            //! NOTE(review): r1 is computed but never read by this 4-row
            //! kernel's store paths — it appears to be carried over from the
            //! 8x8 variant; confirm it can be dropped from the clobbers
            "add r1, r0, %[LDC]\n"
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "b 2f\n"
            //! first k block: clear q8-q15 (only the even d halves d16..d30
            //! accumulate below, but the full q registers are zeroed)
            "1:\n"
            "veor.s32 q8 , q8 , q8 \n"
            "veor.s32 q9 , q9 , q9 \n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"
            //! main loop: 4 k steps per iteration; loads and vmovl widening
            //! for the next step are interleaved with the multiply-accumulates
            "2: \n"
            "vld1.8 {d4}, [%[b_ptr]]!\n"
            "vld1.16 {d0}, [%[a_ptr]]!\n"
            "vmovl.s8 q2, d4\n"
            "vld1.16 {d2}, [%[a_ptr]]!\n"
            "vld1.8 {d6}, [%[b_ptr]]!\n"
            //! k0
            "vmla.s16 d16, d0, d4[0]\n"
            "vmla.s16 d18, d0, d4[1]\n"
            "vmla.s16 d20, d0, d4[2]\n"
            "vmla.s16 d22, d0, d4[3]\n"
            "vmovl.s8 q3, d6\n"
            "vmla.s16 d24, d0, d5[0]\n"
            "vmla.s16 d26, d0, d5[1]\n"
            "vmla.s16 d28, d0, d5[2]\n"
            "vmla.s16 d30, d0, d5[3]\n"
            //! k1
            "vld1.16 {d0}, [%[a_ptr]]!\n"
            "vld1.8 {d4}, [%[b_ptr]]!\n"
            "vmla.s16 d16, d2, d6[0]\n"
            "vmla.s16 d18, d2, d6[1]\n"
            "vmla.s16 d20, d2, d6[2]\n"
            "vmla.s16 d22, d2, d6[3]\n"
            "vmovl.s8 q2, d4\n"
            "vmla.s16 d24, d2, d7[0]\n"
            "vmla.s16 d26, d2, d7[1]\n"
            "vmla.s16 d28, d2, d7[2]\n"
            "vmla.s16 d30, d2, d7[3]\n"
            //! k2
            "vld1.16 {d2}, [%[a_ptr]]!\n"
            "vld1.8 {d6}, [%[b_ptr]]!\n"
            "vmla.s16 d16, d0, d4[0]\n"
            "vmla.s16 d18, d0, d4[1]\n"
            "vmla.s16 d20, d0, d4[2]\n"
            "vmla.s16 d22, d0, d4[3]\n"
            "vmovl.s8 q3, d6\n"
            "vmla.s16 d24, d0, d5[0]\n"
            "vmla.s16 d26, d0, d5[1]\n"
            "vmla.s16 d28, d0, d5[2]\n"
            "vmla.s16 d30, d0, d5[3]\n"
            //! k3
            "vmla.s16 d16, d2, d6[0]\n"
            "vmla.s16 d18, d2, d6[1]\n"
            "vmla.s16 d20, d2, d6[2]\n"
            "vmla.s16 d22, d2, d6[3]\n"
            "vmla.s16 d24, d2, d7[0]\n"
            "vmla.s16 d26, d2, d7[1]\n"
            "vmla.s16 d28, d2, d7[2]\n"
            "vmla.s16 d30, d2, d7[3]\n"
            "subs %[K], %[K], #1\n"
            "bne 2b\n"
            "3:\n"
            //! full 8-column tile: store all 8 accumulator d registers
            //! contiguously (8 bytes per column, 64 bytes total)
            "cmp %[remain_n], #8\n"
            "bne 4f\n"
            "vstr d16, [r0]\n"
            "vstr d18, [r0, #8]\n"
            "vstr d20, [r0, #16]\n"
            "vstr d22, [r0, #24]\n"
            "vstr d24, [r0, #32]\n"
            "vstr d26, [r0, #40]\n"
            "vstr d28, [r0, #48]\n"
            "vstr d30, [r0, #56]\n"
            "b 101f\n"
            //! partial tile: fall back to the column-counting STORE_C path
            "4:\n " STORE_C
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k),
              [outptr] "+r"(outptr), [remain_n] "+r"(remain_n)
            :
            : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
              "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19",
              "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28",
              "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory");
#undef STORE_C
#undef STORE_LINE
}
//! Pack A (MK4 layout, int8) into int16, interleaving two adjacent
//! 4-channel row blocks so the 8x8 kernel can consume an 8(M) x 4(K)
//! cell per step. A leftover group of 4 rows is widened without
//! interleaving for the 4x8 kernel.
//! @param outptr output panel (int16)
//! @param inptr  source matrix in MK4 layout (int8)
//! @param ldin   source stride between 4-channel row blocks, in int8
//! @param m0,mmax row range to pack; both must be multiples of 4
//! @param k0,kmax depth range to pack; both must be multiples of 4
static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr,
                                          const dt_int8* inptr, int ldin,
                                          int m0, int mmax, int k0, int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 8;
    constexpr int pack_k = 4;
    constexpr int pack_size = 4;
    //! last row index covered by full 8-row groups
    const int full_m_end = (mmax - m0) / pack_m * pack_m + m0;
    for (int row = m0; row < full_m_end; row += pack_m) {
        //! NOTE(review): k0 is added unscaled here while pack_B scales its
        //! column origin by pack_size — confirm callers only use k0 == 0 or
        //! that this offset is intentional
        const int8_t* row_ptr0 = inptr + row / pack_size * ldin + k0;
        const int8_t* row_ptr1 = row_ptr0 + ldin;
        prefetch_2x(row_ptr0);
        prefetch_2x(row_ptr1);
        int col = k0;
        while (col < kmax) {
            //! interleave one 4-k cell of each of the two row blocks,
            //! widening int8 -> int16 on the way out
            interleave_4x4_8x4_s8_s16(row_ptr0, row_ptr1, outptr);
            row_ptr0 += pack_size * pack_size;
            row_ptr1 += pack_size * pack_size;
            outptr += pack_m * pack_k;
            col += pack_size;
        }
    }
    //! tail: the remaining 4-row block (mmax - full_m_end can only be 4
    //! given the asserts) is copied linearly with int8 -> int16 widening
    if (full_m_end < mmax) {
        const int8_t* tail_ptr = inptr + full_m_end / pack_size * ldin + k0;
        memcpy_s8_s16(tail_ptr, outptr, (kmax - k0) * pack_size);
    }
}
//! Pack B (MK4 layout, int8) by transposing 8(N) x 4(K) cells so the
//! kernel reads 8 consecutive N values for each k step. Ragged column
//! tails are staged through a zero-padded buffer so a full cell can be
//! transposed (padding contributes zeros).
//! @param out  output panel (int8)
//! @param in   source matrix in MK4 layout (int8)
//! @param ldin source stride between 4-channel k blocks, in int8
//! @param n0,nmax column range to pack
//! @param k0,kmax depth range to pack; both must be multiples of 4
static void gemm_s8x8x16_mk4_8x8_pack_B_n(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_n = 8;
    constexpr int pack_size = 4;
    //! staging buffer for a partial (< 8 column) cell; bytes beyond the
    //! copied tail stay zero for the whole routine
    int8_t tmpbuff[pack_n * pack_size] = {0};
    const int total_k = kmax - k0;
    const int total_n = nmax - n0;
    const int full_n_end = total_n / pack_n * pack_n + n0;
    const int tail_n = total_n % pack_n;
    //! distance between consecutive 8-column output blocks
    const int block_stride = total_k * pack_n;
    int8_t* out_base = out;
    for (int k_blk = k0; k_blk < kmax; k_blk += pack_size) {
        const int8_t* src = in + k_blk / pack_size * ldin + n0 * pack_size;
        prefetch_3x(src);
        int8_t* dst = out_base;
        for (int n_blk = n0; n_blk < full_n_end; n_blk += pack_n) {
            transpos_8x4_int8(src, dst);
            src += pack_n * pack_size;
            dst += block_stride;
        }
        if (tail_n > 0) {
            //! copy the ragged columns into the zeroed buffer, then
            //! transpose a full cell
            memcpy(tmpbuff, src, sizeof(int8_t) * tail_n * pack_size);
            transpos_8x4_int8(tmpbuff, dst);
            dst += block_stride;
        }
        out_base += pack_n * pack_size;
    }
}
} // namespace matmul_mk4_8x8x4 | |||||
} // namespace armv7 | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -6,14 +6,16 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/armv7/matrix_mul/int8x8x16/strategy.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/armv7/matrix_mul/asm/common.h" | #include "src/armv7/matrix_mul/asm/common.h" | ||||
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" | #include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" | ||||
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" | #include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" | ||||
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h" | |||||
#include "src/armv7/matrix_mul/int8x8x16/strategy.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -108,7 +110,7 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, | |||||
} | } | ||||
} | } | ||||
// ===========================gemm_s8x8x16_4x4================================== | |||||
// ===========================gemm_s8x8x16_4x8================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); | ||||
void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | ||||
@@ -179,4 +181,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
} | } | ||||
} | } | ||||
// ===========================gemm_s8x8x16_mk4_8x8================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8); | |||||
void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax, | |||||
bool) const { | |||||
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
kmax); | |||||
} | |||||
void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int16* C, | |||||
size_t LDC, bool is_first_k, const dt_int16*, | |||||
dt_int16*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
C_dtype.enumv() == DTypeEnum::Int16 && | |||||
A_dtype.enumv() == DTypeEnum::Int8); | |||||
megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||||
MEGDNN_MARK_USED_VAR(A_dtype); | |||||
MEGDNN_MARK_USED_VAR(B_dtype); | |||||
MEGDNN_MARK_USED_VAR(C_dtype); | |||||
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||||
constexpr size_t pack_size = 4; | |||||
constexpr size_t pack_m = 8; | |||||
constexpr size_t pack_n = 8; | |||||
const size_t remain_n = N % pack_n; | |||||
const size_t remain_m = M % pack_m; | |||||
size_t m_idx = 0; | |||||
for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
size_t n_idx = 0; | |||||
const int8_t* cur_packB = packB; | |||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||||
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
packA += pack_m * K; | |||||
} | |||||
if (remain_m > 0) { | |||||
int16_t* output = C + (m_idx / pack_size * LDC); | |||||
size_t n_idx = 0; | |||||
const int8_t* cur_packB = packB; | |||||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||||
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, pack_n); | |||||
output += pack_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
if (remain_n > 0) { | |||||
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
is_first_k, remain_n); | |||||
output += remain_n * pack_size; | |||||
cur_packB += pack_n * K; | |||||
} | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -21,6 +22,10 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true, | |||||
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, | MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, | ||||
gemm_s8x8x16_4x8); | gemm_s8x8x16_4x8); | ||||
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8, | |||||
8, 4, false, false, | |||||
gemm_s8x8x16_mk4_8x8); | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace armv7 | } // namespace armv7 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -6,10 +6,11 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "src/armv7/matrix_mul/opr_impl.h" | |||||
#include "src/armv7/matrix_mul/algos.h" | #include "src/armv7/matrix_mul/algos.h" | ||||
#include "src/armv7/matrix_mul/opr_impl.h" | |||||
#include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_impl.h" | #include "src/fallback/matrix_mul/gemm_impl.h" | ||||
@@ -21,7 +22,7 @@ using namespace armv7; | |||||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | class MatrixMulImpl::AlgoPack : NonCopyableObj { | ||||
AlgoF32 f32; | AlgoF32 f32; | ||||
AlgoF32MK4Pack4x12 f32_mk4_pack_4x12; | AlgoF32MK4Pack4x12 f32_mk4_pack_4x12; | ||||
AlgoF32MK4_4x8 f32_mk4_4x8; | |||||
AlgoF32MK4_4x8 f32_mk4_4x8; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16K4x16x1 f16_k4x16x1; | AlgoF16K4x16x1 f16_k4x16x1; | ||||
AlgoF16MK8_4x8 f16_mk8_4x8; | AlgoF16MK8_4x8 f16_mk8_4x8; | ||||
@@ -38,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoQuint8K4x8x8 quint8_k4x8x8; | AlgoQuint8K4x8x8 quint8_k4x8x8; | ||||
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | ||||
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | ||||
AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4; | |||||
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | ||||
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | ||||
@@ -62,8 +64,10 @@ public: | |||||
all_algos.emplace_back(&int8x8x32_k4x2x16); | all_algos.emplace_back(&int8x8x32_k4x2x16); | ||||
all_algos.emplace_back(&int8x8x32_k4x8x8); | all_algos.emplace_back(&int8x8x32_k4x8x8); | ||||
all_algos.emplace_back(&quint8_k4x8x8); | all_algos.emplace_back(&quint8_k4x8x8); | ||||
all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||||
all_algos.emplace_back(&int8x8x16_k4x2x16); | all_algos.emplace_back(&int8x8x16_k4x2x16); | ||||
all_algos.emplace_back(&int8x8x16_k4x8x8); | all_algos.emplace_back(&int8x8x16_k4x8x8); | ||||
all_algos.emplace_back(&int16x16x32_k12x4x1); | all_algos.emplace_back(&int16x16x32_k12x4x1); | ||||
all_algos.emplace_back(&int16x16x32_mk8_4x8); | all_algos.emplace_back(&int16x16x32_mk8_4x8); | ||||
} | } | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/arm_common/matrix_mul/opr_impl.h" | #include "src/arm_common/matrix_mul/opr_impl.h" | ||||
@@ -19,26 +20,28 @@ public: | |||||
using arm_common::MatrixMulImpl::MatrixMulImpl; | using arm_common::MatrixMulImpl::MatrixMulImpl; | ||||
SmallVector<AlgoBase*> algo_pack() override; | SmallVector<AlgoBase*> algo_pack() override; | ||||
private: | private: | ||||
class AlgoF32; // Armv7 F32 | |||||
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack | |||||
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | |||||
class AlgoF32Gemv; // Armv7 F32 Gemv | |||||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | |||||
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | |||||
class AlgoF32; // Armv7 F32 | |||||
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack | |||||
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | |||||
class AlgoF32Gemv; // Armv7 F32 Gemv | |||||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | |||||
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | |||||
class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | ||||
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | |||||
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | |||||
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | |||||
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 | |||||
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 | |||||
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | |||||
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | |||||
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | |||||
class AlgoInt8x8x16MK4_8x8x4; // Armv7 Int8x8x16 Kernel 8x8x8 | |||||
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 | |||||
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 | class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 | ||||
class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 | class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | |||||
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 | |||||
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | |||||
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 | |||||
class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | ||||
// DotProduct | // DotProduct | ||||
#endif | #endif | ||||
@@ -10,9 +10,9 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||||
#include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
#include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||||
#include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h" | #include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h" | ||||
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h" | #include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h" | ||||
#include "src/fallback/conv_bias/opr_impl.h" | #include "src/fallback/conv_bias/opr_impl.h" | ||||
@@ -194,10 +194,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||||
PW = param.filter_meta.padding[1]; | PW = param.filter_meta.padding[1]; | ||||
size_t SH = param.filter_meta.stride[0], | size_t SH = param.filter_meta.stride[0], | ||||
SW = param.filter_meta.stride[1]; | SW = param.filter_meta.stride[1]; | ||||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | ||||
return false; | return false; | ||||
if (param.src_type.enumv() != param.filter_type.enumv()) { | if (param.src_type.enumv() != param.filter_type.enumv()) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -216,6 +214,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||||
//! is identity otherwise return false mean that 8x8x32 and 8x8x16 | //! is identity otherwise return false mean that 8x8x32 and 8x8x16 | ||||
//! not support PostProcess | //! not support PostProcess | ||||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | if (param.dst_type.enumv() == DTypeEnum::Int16 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | |||||
param.dst_type.enumv() == DTypeEnum::Int32 || | param.dst_type.enumv() == DTypeEnum::Int32 || | ||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | ||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include <unordered_map> | #include <unordered_map> | ||||
@@ -226,10 +227,10 @@ public: | |||||
PostprocessMode::FLOAT, | PostprocessMode::FLOAT, | ||||
"DefaultStrategyType::FLOAT"_hash); | "DefaultStrategyType::FLOAT"_hash); | ||||
} else if (format == param::ConvBias::Format::NCHW44) { | } else if (format == param::ConvBias::Format::NCHW44) { | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
auto matmul_block = matmul_algo->get_inner_block_size(); | auto matmul_block = matmul_algo->get_inner_block_size(); | ||||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse | |||||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 | |||||
//! im2col+pack fuse | |||||
if ((matmul_block.m == 8 || matmul_block.m == 4) && | if ((matmul_block.m == 8 || matmul_block.m == 4) && | ||||
matmul_block.n == 12 && matmul_block.k == 1 && | matmul_block.n == 12 && matmul_block.k == 1 && | ||||
param.filter_meta.spatial[0] == 3 && | param.filter_meta.spatial[0] == 3 && | ||||
@@ -297,9 +298,21 @@ public: | |||||
break; | break; | ||||
case StrategyType::INT8x8x16: | case StrategyType::INT8x8x16: | ||||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | |||||
if (format == param::ConvBias::Format::NCHW) { | |||||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | |||||
} else if (format == param::ConvBias::Format::NCHW44) { | |||||
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||||
"DefaultStrategyType::INT8x8x16"_hash); | |||||
} else { | |||||
megdnn_throw( | |||||
ssprintf("Current only support layout " | |||||
"NCHW44/NCHW for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | |||||
break; | break; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | #if MEGDNN_AARCH64 || MEGDNN_ARMV7 | ||||
case StrategyType::QUINT8x8x32: | case StrategyType::QUINT8x8x32: | ||||
@@ -421,10 +434,11 @@ public: | |||||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | dt_int32, dt_int8, PostprocessMode::QUANTIZED, | ||||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | ||||
} else { | } else { | ||||
megdnn_throw(ssprintf("Current only support layout " | |||||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
megdnn_throw( | |||||
ssprintf("Current only support layout " | |||||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||||
"algo, but got %d\n", | |||||
uint32_t(format))); | |||||
} | } | ||||
break; | break; | ||||
} | } | ||||
@@ -6,11 +6,12 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/naive/matrix_mul/opr_impl.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/matrix_mul/opr_impl.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace fallback { | namespace fallback { | ||||
@@ -66,7 +67,8 @@ public: | |||||
}; | }; | ||||
typedef void (*kern_t)(const KernParam&); | typedef void (*kern_t)(const KernParam&); | ||||
typedef void (*kern_naked_t)(const KernParam& , const void* a_panel, const void *b_panel); | |||||
typedef void (*kern_naked_t)(const KernParam&, const void* a_panel, | |||||
const void* b_panel); | |||||
class AlgoBase : public Algorithm { | class AlgoBase : public Algorithm { | ||||
protected: | protected: | ||||
virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
@@ -83,18 +85,19 @@ public: | |||||
bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const { | bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const { | ||||
return param.A_type.enumv() == param.B_type.enumv() && | return param.A_type.enumv() == param.B_type.enumv() && | ||||
param.A_type.enumv() == DTypeEnum::Int8 && | |||||
param.C_type.enumv() == DTypeEnum::Int16 && | |||||
param.format == param::MatrixMul::Format::DEFAULT && | |||||
param.compute_mode == Param::ComputeMode::DEFAULT; | |||||
(param.A_type.enumv() == DTypeEnum::Int8 || | |||||
param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
(param.C_type.enumv() == DTypeEnum::Int16 || | |||||
param.C_type.enumv() == DTypeEnum::QuantizedS16); | |||||
} | } | ||||
public: | public: | ||||
enum class AlgoSet:uint32_t { | |||||
enum class AlgoSet : uint32_t { | |||||
ALGO_TYPE_GEMM = 0, | ALGO_TYPE_GEMM = 0, | ||||
ALGO_TYPE_GEMV = 1, | ALGO_TYPE_GEMV = 1, | ||||
}; | }; | ||||
enum class PackMode:uint32_t { | |||||
enum class PackMode : uint32_t { | |||||
DEFAULT = 0, | DEFAULT = 0, | ||||
NO_PACK = 1, | NO_PACK = 1, | ||||
ONLY_PACKA = 2, | ONLY_PACKA = 2, | ||||
@@ -489,25 +489,26 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, | |||||
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | ||||
const char* im2col_name, Handle* handle, | const char* im2col_name, Handle* handle, | ||||
size_t kernel, size_t pack_size = 1) { | |||||
auto&& args = get_winograd_benchmark_args(kernel, pack_size); | |||||
size_t kernel, DType src_type, | |||||
DType dst_type) { | |||||
auto&& args = get_winograd_benchmark_args(kernel, 4); | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
constexpr size_t RUN = 10; | constexpr size_t RUN = 10; | ||||
Benchmarker<ConvBias> benchmark(handle); | Benchmarker<ConvBias> benchmark(handle); | ||||
benchmark.set_display(false); | benchmark.set_display(false); | ||||
benchmark.set_times(RUN); | benchmark.set_times(RUN); | ||||
benchmark.set_dtype(0, dtype::Int8()); | |||||
benchmark.set_dtype(1, dtype::Int8()); | |||||
benchmark.set_dtype(2, dtype::Int32()); | |||||
benchmark.set_dtype(4, dtype::Int32()); | |||||
benchmark.set_dtype(0, src_type); | |||||
benchmark.set_dtype(1, src_type); | |||||
benchmark.set_dtype(2, dst_type); | |||||
benchmark.set_dtype(4, dst_type); | |||||
Benchmarker<ConvBias> benchmark_im2col(handle); | Benchmarker<ConvBias> benchmark_im2col(handle); | ||||
benchmark_im2col.set_display(false); | benchmark_im2col.set_display(false); | ||||
benchmark_im2col.set_times(RUN); | benchmark_im2col.set_times(RUN); | ||||
benchmark_im2col.set_dtype(0, dtype::Int8()); | |||||
benchmark_im2col.set_dtype(1, dtype::Int8()); | |||||
benchmark_im2col.set_dtype(2, dtype::Int32()); | |||||
benchmark_im2col.set_dtype(4, dtype::Int32()); | |||||
benchmark_im2col.set_dtype(0, src_type); | |||||
benchmark_im2col.set_dtype(1, src_type); | |||||
benchmark_im2col.set_dtype(2, dst_type); | |||||
benchmark_im2col.set_dtype(4, dst_type); | |||||
for (auto&& arg : args) { | for (auto&& arg : args) { | ||||
TensorLayout dst_layout; | TensorLayout dst_layout; | ||||
@@ -556,6 +557,7 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | |||||
computations / used_im2col, used / used_im2col); | computations / used_im2col, used / used_im2col); | ||||
} | } | ||||
} | } | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | ||||
printf("=========================compare " | printf("=========================compare " | ||||
@@ -563,7 +565,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | |||||
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n"); | "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n"); | ||||
BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16", | BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16", | ||||
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16", | "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16", | ||||
handle(), 3, 4); | |||||
handle(), 3, dtype::Int8(), dtype::Int32()); | |||||
} | |||||
#endif | |||||
#if MEGDNN_ARMV7 | |||||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { | |||||
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; | |||||
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; | |||||
printf("compare %s vs %s \n", default_algo, mk4_algo); | |||||
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | |||||
dtype::Int8(), dtype::Int16()); | |||||
} | } | ||||
#endif | #endif | ||||
@@ -1860,15 +1872,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||||
param.format = param::ConvBias::Format::NCHW44_DOT; | param.format = param::ConvBias::Format::NCHW44_DOT; | ||||
//! channel bias | //! channel bias | ||||
args.emplace_back(param, TensorShape{1, ic/4, h, w, 4}, | |||||
TensorShape{oc/4, ic/4, kernel, kernel, 4, 4}, | |||||
TensorShape{1, oc/4, 1, 1, 4}); | |||||
args.emplace_back(param, TensorShape{1, ic / 4, h, w, 4}, | |||||
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, | |||||
TensorShape{1, oc / 4, 1, 1, 4}); | |||||
}; | }; | ||||
for (size_t stride : {1, 2}) | for (size_t stride : {1, 2}) | ||||
for (size_t kernel : {2, 3, 5, 7}) | for (size_t kernel : {2, 3, 5, 7}) | ||||
for(size_t oc : {64}) | |||||
for (size_t oc : {64}) | |||||
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { | for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { | ||||
run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode); | |||||
run(oc, oc, 56, 56, kernel, kernel / 2, stride, | |||||
nonline_mode); | |||||
} | } | ||||
constexpr size_t RUN = 50; | constexpr size_t RUN = 50; | ||||
@@ -1880,7 +1893,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||||
benchmark0.set_display(false); | benchmark0.set_display(false); | ||||
benchmark0.set_times(RUN); | benchmark0.set_times(RUN); | ||||
benchmark0.set_before_exec_callback( | benchmark0.set_before_exec_callback( | ||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8DIRECT_NCHW44")); | |||||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||||
"ARMDOTS8DIRECT_NCHW44")); | |||||
Benchmarker<ConvBias> benchmark1(handle()); | Benchmarker<ConvBias> benchmark1(handle()); | ||||
benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | ||||
@@ -2002,15 +2016,20 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args( | |||||
void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | ||||
DType stype, DType matmul_dtype, DType bias_type, | DType stype, DType matmul_dtype, DType bias_type, | ||||
DType conv_dtype) { | |||||
DType conv_dtype, bool is_mk4 = false) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
int pack_size = is_mk4 ? 4 : 1; | |||||
std::vector<TestArg> conv_bias_1x1_args = | std::vector<TestArg> conv_bias_1x1_args = | ||||
get_conv_bias_1x1_benchmark_args(); | |||||
get_conv_bias_1x1_benchmark_args(pack_size); | |||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::MatrixMul param; | param::MatrixMul param; | ||||
param.transposeA = false; | param.transposeA = false; | ||||
param.transposeB = false; | param.transposeB = false; | ||||
if (is_mk4) { | |||||
param.format = MatrixMul::Param::Format::MK4; | |||||
} | |||||
Benchmarker<MatrixMul> benchmark_matmul(handle); | Benchmarker<MatrixMul> benchmark_matmul(handle); | ||||
benchmark_matmul.set_before_exec_callback( | benchmark_matmul.set_before_exec_callback( | ||||
AlgoChecker<MatrixMul>(matmul_algo_name)); | AlgoChecker<MatrixMul>(matmul_algo_name)); | ||||
@@ -2038,8 +2057,8 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||||
size_t OH = arg.src[2]; | size_t OH = arg.src[2]; | ||||
size_t OW = arg.src[3]; | size_t OW = arg.src[3]; | ||||
size_t OC = arg.filter[0]; | size_t OC = arg.filter[0]; | ||||
size_t M = OC; | |||||
size_t K = IC; | |||||
size_t M = OC * pack_size; | |||||
size_t K = IC * pack_size; | |||||
size_t N = OH * OW; | size_t N = OH * OW; | ||||
float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; | float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; | ||||
@@ -2047,6 +2066,10 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||||
TensorShape A, B; | TensorShape A, B; | ||||
A = TensorShape{M, K}; | A = TensorShape{M, K}; | ||||
B = TensorShape{K, N}; | B = TensorShape{K, N}; | ||||
if (is_mk4) { | |||||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
B = TensorShape{K / 4, N, 4}; | |||||
} | |||||
auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( | auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( | ||||
{arg.src, arg.filter, arg.bias, {}, {}}) / | {arg.src, arg.filter, arg.bias, {}, {}}) / | ||||
@@ -2133,6 +2156,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) { | |||||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | ||||
benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{}, | benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{}, | ||||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | ||||
benchmark_conv1x1("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), dtype::Int8{}, | |||||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}, true); | |||||
#endif | #endif | ||||
} | } | ||||
@@ -2145,13 +2170,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) { | |||||
conv_param.pad_h = 0; | conv_param.pad_h = 0; | ||||
conv_param.pad_w = 0; | conv_param.pad_w = 0; | ||||
conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | ||||
auto run = [&](size_t M, size_t K){ | |||||
auto run = [&](size_t M, size_t K) { | |||||
args.emplace_back(conv_param, TensorShape{1, K, 1, 1}, | args.emplace_back(conv_param, TensorShape{1, K, 1, 1}, | ||||
TensorShape{M, K, 1, 1}, TensorShape{}); | TensorShape{M, K, 1, 1}, TensorShape{}); | ||||
}; | }; | ||||
for (size_t M : {4, 64, 1024, 4096}) | for (size_t M : {4, 64, 1024, 4096}) | ||||
for (size_t K : {128, 256, 1024, 4096}) | for (size_t K : {128, 256, 1024, 4096}) | ||||
run(M, K); | |||||
run(M, K); | |||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::MatrixMul param; | param::MatrixMul param; | ||||
@@ -850,7 +850,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | |||||
param::ConvBias::Format::NCHW44); | param::ConvBias::Format::NCHW44); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1); | std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1); | ||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | ||||
@@ -1131,7 +1132,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) { | |||||
1e-3f); | 1e-3f); | ||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||||
CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { | |||||
using namespace conv_bias; | using namespace conv_bias; | ||||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | ||||
@@ -2089,6 +2091,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | ||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args_nchw44 = | |||||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||||
false, false, false, false, true); | |||||
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | |||||
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||||
false, false, true); | |||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias( \ | checker_conv_bias( \ | ||||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | ||||
@@ -2098,6 +2106,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||||
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | ||||
dtype::Int16{}, dtype::Int16{}, name); | dtype::Int16{}, dtype::Int16{}, name); | ||||
#define cb_nchw44(name) \ | |||||
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); \ | |||||
checker_conv_bias(args_nchw44_1x1s2, handle(), &rng, epsilon, \ | |||||
dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \ | |||||
dtype::Int16{}, name); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | ||||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | ||||
@@ -2106,8 +2121,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||||
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | ||||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | ||||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | ||||
cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); | |||||
#endif | #endif | ||||
#undef cb | #undef cb | ||||
#undef cb_nchw44 | |||||
} | } | ||||
#endif | #endif | ||||
@@ -2516,19 +2534,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||||
UniformIntRNG rng{-50, 50}; | UniformIntRNG rng{-50, 50}; | ||||
float epsilon = 0.001; | float epsilon = 0.001; | ||||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | ||||
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | |||||
{1}, 1, true, true, true, false, false, false, false, true); | |||||
#define cb(name) \ | #define cb(name) \ | ||||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | ||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | ||||
#define cb_nchw44(name) \ | |||||
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ | |||||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | ||||
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | ||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | ||||
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | ||||
cb_nchw44("CONV1x1:ARMV7_INT8X8X16_MK4_K8X8X4:48"); | |||||
#endif | #endif | ||||
cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); | cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); | ||||
#undef cb | #undef cb | ||||
#undef cb_nchw44 | |||||
std::vector<conv_bias::TestArg> gemv_args; | std::vector<conv_bias::TestArg> gemv_args; | ||||
for (auto&& arg : args) | for (auto&& arg : args) | ||||
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "test/armv7/fixture.h" | #include "test/armv7/fixture.h" | ||||
#include "test/common/benchmarker.h" | #include "test/common/benchmarker.h" | ||||
@@ -51,9 +52,15 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) { | |||||
handle(), "ARMV7_INT8X8X16_K4X8X8"); | handle(), "ARMV7_INT8X8X16_K4X8X8"); | ||||
} | } | ||||
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) { | |||||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||||
handle(), "ARMV7_INT8X8X16_MK4_K8X8X4", | |||||
param::MatrixMul::Format::MK4, 1); | |||||
} | |||||
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { | TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { | ||||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||||
handle(),"ARMV7_INT16X16X32_K12X4X1"); | |||||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||||
handle(), "ARMV7_INT16X16X32_K12X4X1"); | |||||
} | } | ||||
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | ||||
@@ -83,7 +90,8 @@ TEST_F(ARMV7, MATRIX_MUL_SDOT) { | |||||
TEST_F(ARMV7, MATRIX_MUL_UDOT) { | TEST_F(ARMV7, MATRIX_MUL_UDOT) { | ||||
matrix_mul::check_matrix_mul( | matrix_mul::check_matrix_mul( | ||||
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)), | |||||
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), | |||||
dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)), | |||||
dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4"); | dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4"); | ||||
} | } | ||||
@@ -103,7 +111,9 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) { | |||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
namespace { | namespace { | ||||
void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||||
void run_8x8x16_benchmark( | |||||
const char* algo, Handle* handle, | |||||
MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) { | |||||
constexpr size_t RUNS = 50; | constexpr size_t RUNS = 50; | ||||
param::MatrixMul param; | param::MatrixMul param; | ||||
Benchmarker<MatrixMul> benchmarker_int(handle); | Benchmarker<MatrixMul> benchmarker_int(handle); | ||||
@@ -116,21 +126,31 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||||
.set_dtype(2, dtype::Int16{}) | .set_dtype(2, dtype::Int16{}) | ||||
.set_param(param) | .set_param(param) | ||||
.set_display(false); | .set_display(false); | ||||
param::MatrixMul target_param; | |||||
target_param.format = format; | |||||
benchmarker_int_kern_4x2x16.set_before_exec_callback( | benchmarker_int_kern_4x2x16.set_before_exec_callback( | ||||
AlgoChecker<MatrixMul>(algo)); | AlgoChecker<MatrixMul>(algo)); | ||||
benchmarker_int_kern_4x2x16.set_times(RUNS) | benchmarker_int_kern_4x2x16.set_times(RUNS) | ||||
.set_dtype(0, dtype::Int8{}) | .set_dtype(0, dtype::Int8{}) | ||||
.set_dtype(1, dtype::Int8{}) | .set_dtype(1, dtype::Int8{}) | ||||
.set_dtype(2, dtype::Int16{}) | .set_dtype(2, dtype::Int16{}) | ||||
.set_param(param) | |||||
.set_param(target_param) | |||||
.set_display(false); | .set_display(false); | ||||
Benchmarker<MatrixMul> benchmarker_float(handle); | Benchmarker<MatrixMul> benchmarker_float(handle); | ||||
benchmarker_float.set_display(false).set_times(RUNS); | benchmarker_float.set_display(false).set_times(RUNS); | ||||
auto run = [&](size_t M, size_t N, size_t K) { | auto run = [&](size_t M, size_t N, size_t K) { | ||||
auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; | auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; | ||||
auto int_kern_used = | |||||
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS; | |||||
auto int_kern_used = 1e10; | |||||
if (format == MatrixMul::Param::Format::MK4) { | |||||
int_kern_used = benchmarker_int_kern_4x2x16.exec( | |||||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||||
RUNS; | |||||
} else { | |||||
int_kern_used = | |||||
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / | |||||
RUNS; | |||||
} | |||||
auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; | auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; | ||||
float computations = 2.f * M * K * N * 1e-6; | float computations = 2.f * M * K * N * 1e-6; | ||||
printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f " | printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f " | ||||
@@ -145,6 +165,7 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||||
}; | }; | ||||
run(256, 12 * 24, 256); | run(256, 12 * 24, 256); | ||||
run(256, 256, 256); | |||||
//////////////////////// gemv ////////////////////////// | //////////////////////// gemv ////////////////////////// | ||||
for (size_t M : {8, 64, 112, 256}) { | for (size_t M : {8, 64, 112, 256}) { | ||||
@@ -185,7 +206,8 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) { | |||||
"int: %f ms %f Gflops %s: \n" | "int: %f ms %f Gflops %s: \n" | ||||
"speedup(%s/arm_common, %s/float): %f\n", | "speedup(%s/arm_common, %s/float): %f\n", | ||||
M, K, N, float_used, computations / float_used, int_used, | M, K, N, float_used, computations / float_used, int_used, | ||||
computations / int_used,algo,algo,algo,float_used / int_used); | |||||
computations / int_used, algo, algo, algo, | |||||
float_used / int_used); | |||||
}; | }; | ||||
run(256, 12 * 24, 256); | run(256, 12 * 24, 256); | ||||
@@ -231,7 +253,8 @@ void run_8x8x32_benchmark(const char* algo, Handle* handle) { | |||||
"int: %f ms %f Gflops %s: \n" | "int: %f ms %f Gflops %s: \n" | ||||
"speedup(%s/arm_common, %s/float): %f\n", | "speedup(%s/arm_common, %s/float): %f\n", | ||||
M, K, N, float_used, computations / float_used, int_used, | M, K, N, float_used, computations / float_used, int_used, | ||||
computations / int_used,algo,algo,algo,float_used / int_used); | |||||
computations / int_used, algo, algo, algo, | |||||
float_used / int_used); | |||||
}; | }; | ||||
run(256, 12 * 24, 256); | run(256, 12 * 24, 256); | ||||
@@ -252,9 +275,11 @@ void run_8x8x32_quint_benchmark(Handle* handle) { | |||||
benchmarker_quint8_dot.set_before_exec_callback( | benchmarker_quint8_dot.set_before_exec_callback( | ||||
AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4")); | AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4")); | ||||
benchmarker_quint8_dot.set_times(RUNS) | benchmarker_quint8_dot.set_times(RUNS) | ||||
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||||
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||||
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) | |||||
.set_dtype(0, | |||||
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||||
.set_dtype(1, | |||||
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||||
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) | |||||
.set_param(param) | .set_param(param) | ||||
.set_display(false); | .set_display(false); | ||||
@@ -262,14 +287,17 @@ void run_8x8x32_quint_benchmark(Handle* handle) { | |||||
benchmarker_quint8.set_before_exec_callback( | benchmarker_quint8.set_before_exec_callback( | ||||
AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8")); | AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8")); | ||||
benchmarker_quint8.set_times(RUNS) | benchmarker_quint8.set_times(RUNS) | ||||
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||||
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||||
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) | |||||
.set_dtype(0, | |||||
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||||
.set_dtype(1, | |||||
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||||
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) | |||||
.set_param(param) | .set_param(param) | ||||
.set_display(false); | .set_display(false); | ||||
auto run = [&](size_t M, size_t N, size_t K) { | auto run = [&](size_t M, size_t N, size_t K) { | ||||
auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; | |||||
auto dot_used = | |||||
benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; | |||||
auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS; | auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS; | ||||
float computations = 2.f * M * K * N * 1e-6; | float computations = 2.f * M * K * N * 1e-6; | ||||
printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n" | printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n" | ||||
@@ -351,11 +379,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) { | |||||
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle()); | run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle()); | ||||
} | } | ||||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) { | TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) { | ||||
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle()); | run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle()); | ||||
} | } | ||||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K4x8x8) { | |||||
run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), | |||||
MatrixMul::Param::Format::MK4); | |||||
} | |||||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) { | TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) { | ||||
run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle()); | run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle()); | ||||
} | } | ||||
@@ -6,12 +6,13 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "test/common/matrix_mul.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "test/common/benchmarker.h" | #include "test/common/benchmarker.h" | ||||
#include "test/common/checker.h" | #include "test/common/checker.h" | ||||
#include "test/common/matrix_mul.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace test; | using namespace test; | ||||
@@ -39,9 +40,9 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() { | |||||
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args( | std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args( | ||||
size_t nbase) { | size_t nbase) { | ||||
std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
for (size_t m : {1, 2, 3, 4, 5}) | |||||
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11}) | |||||
for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24}) | for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24}) | ||||
for (size_t k : {1, 2, 3, 4, 5, 9, 10}) | |||||
for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11}) | |||||
args.emplace_back(m, n * nbase, k, 0); | args.emplace_back(m, n * nbase, k, 0); | ||||
return args; | return args; | ||||
} | } | ||||