GitOrigin-RevId: d2f8290a8d
tags/v1.0.0-rc1
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/aarch64/matrix_mul/algos.h" | |||
@@ -733,7 +734,9 @@ void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return can_be_treated_as_int8x8x16(kern_size_param); | |||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( | |||
@@ -796,7 +799,9 @@ void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return can_be_treated_as_int8x8x16(kern_size_param); | |||
return can_be_treated_as_int8x8x16(kern_size_param) && | |||
kern_size_param.format == param::MatrixMul::Format::DEFAULT && | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/armv7/matrix_mul/algos.h" | |||
@@ -526,6 +527,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, | |||
"AlgoInt8x8x16K4x8x8"_hash, | |||
armv7::matmul::gemm_s8x8x16_4x8, int8_t, | |||
int16_t); | |||
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ | |||
namespace { | |||
void kern_int8x8x16_mk4_k8x8x4(const MatrixMulImpl::KernParam& kern_param) { | |||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | |||
midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||
auto Cptr = kern_param.C<dt_int16>(); | |||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||
auto trA = kern_param.trA, trB = kern_param.trB; | |||
armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, kern_param.A_type, | |||
kern_param.B_type, | |||
kern_param.C_type); | |||
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_s8x8x16_mk4_8x8>( | |||
M, N, K, trA, trB, strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
kern_param.workspace_ptr); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
bool type_ok = can_be_treated_as_int8x8x16(kern_size_param); | |||
return type_ok && kern_size_param.format == param::MatrixMul::Format::MK4 && | |||
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||
!kern_size_param.trA && !kern_size_param.trB && | |||
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; | |||
} | |||
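//! Editor's note: an illustrative scalar reference (not part of this patch;
//! the function name is hypothetical) of the MK4 blocked layout that the
//! checks above guard. A is stored as (M/4, K/4, 4, 4) with the inner dims
//! ordered (k, m), B as (K/4, N, 4) and C as (M/4, N, 4), hence the
//! M % 4 == 0 and K % 4 == 0 requirements. Assumes C is zero-initialized
//! (the is_first_k case):
static inline void mk4_s8s8s16_ref_sketch(const int8_t* A, const int8_t* B,
                                          int16_t* C, size_t M, size_t N,
                                          size_t K) {
    for (size_t mb = 0; mb < M / 4; ++mb)
        for (size_t n = 0; n < N; ++n)
            for (size_t kb = 0; kb < K / 4; ++kb)
                for (size_t ki = 0; ki < 4; ++ki)
                    for (size_t mi = 0; mi < 4; ++mi)
                        C[(mb * N + n) * 4 + mi] += static_cast<int16_t>(
                                A[((mb * (K / 4) + kb) * 4 + ki) * 4 + mi] *
                                B[(kb * N + n) * 4 + ki]);
}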
size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( | |||
const KernSizeParam& kern_size_param) const { | |||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | |||
midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { | |||
auto M = kern_size_param.M, N = kern_size_param.N, | |||
K = kern_size_param.K; | |||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||
C_type = kern_size_param.C_type; | |||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||
matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, A_type, B_type, C_type); | |||
return megdnn::matmul::GemmInterleaved<matmul::gemm_s8x8x16_mk4_8x8>( | |||
M, N, K, trA, trB, strategy) | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern( | |||
const KernSizeParam&) const { | |||
return kern_int8x8x16_mk4_k8x8x4; | |||
} | |||
bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::preferred( | |||
const KernSizeParam& kern_size_param) const { | |||
return kern_size_param.K >= 4; | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x16MK4_8x8x4"_hash, | |||
armv7::matmul::gemm_s8x8x16_mk4_8x8, | |||
int8_t, int16_t, int16_t); | |||
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | |||
namespace { | |||
@@ -937,11 +1006,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( | |||
Bptr = kern_param.B<dt_float16>(); | |||
auto Cptr = kern_param.C<dt_float16>(); | |||
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, | |||
C_type); | |||
megdnn::matmul::GemmInterleaved< | |||
armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, | |||
trB, strategy) | |||
armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); | |||
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_nopack_f16_4x8, | |||
false>(M, N, K, trA, trB, strategy) | |||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||
kern_param.workspace_ptr); | |||
} | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
@@ -171,6 +172,18 @@ public: | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
bool preferred(const KernSizeParam&) const override; | |||
size_t get_workspace(const KernSizeParam&) const override; | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | |||
}; | |||
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | |||
public: | |||
bool is_reproducible() const override { return true; } | |||
@@ -6,13 +6,15 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <arm_neon.h> | |||
#include <cmath> | |||
#include <cstdint> | |||
#include <type_traits> | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
@@ -172,7 +174,6 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, | |||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) | |||
: | |||
: "q0", "q1", "q2", "q3", "memory"); | |||
} | |||
template <typename T> | |||
@@ -183,12 +184,12 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, | |||
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value, | |||
"interleave_4x4_4_b only support uint8_t and int8_t"); | |||
asm volatile( | |||
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 | |||
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 | |||
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 | |||
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 | |||
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 | |||
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 | |||
"vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 | |||
"vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1B2B3 | |||
"vld1.32 {d4, d5}, [%[inptr2]]!\n" // C0C1C2C3 | |||
"vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 | |||
"vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 | |||
"vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 | |||
"vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 | |||
"vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 | |||
"vst1.32 {d0-d1},[%[outptr]]!\n" | |||
@@ -323,10 +324,10 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, | |||
"vtrn.32 q1, q3 \n" // q1=r02,r12,r03,r13 q3=r06,r16,r07,r17 | |||
"vtrn.32 q5, q7 \n" // q5=r22,r32,r23,r33 q7=r26,r36,r27,r37 | |||
"vtrn.32 q9, q11 \n" // q9=r42,r52,r43,r53 q11=r46,r56,r47,r57 | |||
"vst1.32 {d0-d1}, [%[outptr]]! \n" | |||
"vst1.32 {d16}, [%[outptr]]! \n" | |||
"vst1.32 {d0-d1}, [%[outptr]]! \n" | |||
"vst1.32 {d16}, [%[outptr]]! \n" | |||
"vswp d3, d10 \n" // q1=r02,r12,r22,r32 q5=r03,r13,r23,r33 | |||
"vst1.32 {d8-d9}, [%[outptr]]! \n" | |||
"vst1.32 {d8-d9}, [%[outptr]]! \n" | |||
"vst1.32 {d17}, [%[outptr]]! \n" | |||
"vst1.32 {d2-d3}, [%[outptr]]!\n" | |||
"vst1.32 {d18}, [%[outptr]]!\n" | |||
@@ -810,15 +811,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, | |||
"interleave_12x4_1_h only support uint16_t and int16_t"); | |||
auto ldin_asm = ldin << 1; | |||
asm volatile( | |||
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 | |||
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 | |||
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 | |||
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 | |||
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 | |||
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 | |||
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 | |||
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 | |||
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 | |||
"vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 | |||
"vld1.16 {d1}, [%[inptr1]]!\n" // B0B1B2B3 | |||
"vld1.16 {d2}, [%[inptr2]]!\n" // C0C1C2C3 | |||
"vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 | |||
"vld1.16 {d4}, [%[inptr4]]!\n" // E0E1E2E3 | |||
"vld1.16 {d5}, [%[inptr5]]!\n" // F0F1F2F3 | |||
"vld1.16 {d6}, [%[inptr6]]!\n" // G0G1G2G3 | |||
"vld1.16 {d7}, [%[inptr7]]!\n" // H0H1H2H3 | |||
"vld1.16 {d8}, [%[inptr8]]!\n" // I0I1I2I3 | |||
"vld1.16 {d9}, [%[inptr9]]\n" // J0J1J2J3 | |||
"add %[inptr9], %[inptr9], %[ldin_asm]\n" | |||
"vld1.16 {d10}, [%[inptr9]]\n" // K0K1K2K3 | |||
@@ -854,17 +855,15 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, | |||
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), | |||
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), | |||
[inptr9] "+r"(inptr9), [outptr] "+r"(outptr) | |||
:[ldin_asm] "r"(ldin_asm) | |||
: [ldin_asm] "r"(ldin_asm) | |||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | |||
"d11", "memory"); | |||
inptr9 -= ldin_asm; | |||
inptr9 += 4; | |||
inptr9 -= ldin_asm; | |||
inptr9 += 4; | |||
inptr10 += 4; | |||
inptr11 += 4; | |||
} | |||
template <typename T> | |||
static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
@@ -1038,7 +1037,7 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, | |||
"vst1.32 {d7}, [%[outptr]], %[stride]\n" | |||
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), | |||
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), | |||
[outptr] "+r"(outptr), [stride] "+r" (stride) | |||
[outptr] "+r"(outptr), [stride] "+r"(stride) | |||
: | |||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); | |||
} | |||
@@ -1069,7 +1068,6 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1, | |||
: "d0", "d1", "d2", "d3", "memory"); | |||
} | |||
template <typename T> | |||
static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
@@ -1082,9 +1080,9 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, | |||
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | |||
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | |||
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | |||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||
"add %[inptr0],%[inptr0],#6 \n" | |||
"add %[inptr1],%[inptr1],#6 \n" | |||
"add %[inptr2],%[inptr2],#6 \n" | |||
@@ -1121,9 +1119,9 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, | |||
"vld1.8 {d1}, [%[inptr1]]\n" // B0B1B2B3B4B5 B6B7 | |||
"vld1.8 {d2}, [%[inptr2]]\n" // C0C1C2C3C4C5 C6C7 | |||
"vld1.8 {d3}, [%[inptr3]]\n" // D0D1D2D3D4D5 D6D7 | |||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||
"vtrn.8 d0, d1\n" // A0B0A2B2A4B4A6B6 A1B1A3B3A5B5A7B7 | |||
"vtrn.8 d2, d3\n" // C0D0C2D2C4D4C6D6 C1D1C3D3C5D5C7D7 | |||
"add %[inptr0],%[inptr0],#4 \n" | |||
"add %[inptr1],%[inptr1],#4 \n" | |||
"add %[inptr2],%[inptr2],#4 \n" | |||
@@ -1176,7 +1174,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { | |||
"vst1.32 {d6-d7}, [%[outptr]]! \n" | |||
"vst1.32 {d14-d15}, [%[outptr]]! \n" | |||
"vst1.32 {d22-d23}, [%[outptr]]! \n" | |||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||
: | |||
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", | |||
"q11", "memory"); | |||
@@ -1195,12 +1193,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { | |||
"vst1.32 {d4-d5}, [%[outptr]]! \n" | |||
"vst1.32 {d2-d3}, [%[outptr]]! \n" | |||
"vst1.32 {d6-d7}, [%[outptr]]! \n" | |||
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr) | |||
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) | |||
: | |||
: "q0", "q1", "q2", "q3", "memory"); | |||
} | |||
template <typename T> | |||
static inline void transpose_4(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, T* outptr, | |||
@@ -1251,7 +1248,6 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1, | |||
} | |||
} | |||
template <typename T> | |||
static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, | |||
const T*& inptr2, const T*& inptr3, | |||
@@ -1375,7 +1371,68 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, | |||
: "q0", "q1", "q2", "q3", "memory"); | |||
} | |||
} // armv7 | |||
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, | |||
const int8_t* inptr1, | |||
int16_t* outptr) { | |||
int8x16_t row0 = vld1q_s8(inptr0); | |||
int16x8_t row0_01 = vmovl_low_s8(row0); | |||
int16x8_t row0_23 = vmovl_high_s8(row0); | |||
int16x4_t row0_0 = vget_low_s16(row0_01); | |||
int16x4_t row0_1 = vget_high_s16(row0_01); | |||
int16x4_t row0_2 = vget_low_s16(row0_23); | |||
int16x4_t row0_3 = vget_high_s16(row0_23); | |||
int8x16_t row1 = vld1q_s8(inptr1); | |||
int16x8_t row1_01 = vmovl_low_s8(row1); | |||
int16x8_t row1_23 = vmovl_high_s8(row1); | |||
int16x4_t row1_0 = vget_low_s16(row1_01); | |||
int16x4_t row1_1 = vget_high_s16(row1_01); | |||
int16x4_t row1_2 = vget_low_s16(row1_23); | |||
int16x4_t row1_3 = vget_high_s16(row1_23); | |||
vst1_s16(outptr, row0_0); | |||
vst1_s16(outptr + 1 * 4, row1_0); | |||
vst1_s16(outptr + 2 * 4, row0_1); | |||
vst1_s16(outptr + 3 * 4, row1_1); | |||
vst1_s16(outptr + 4 * 4, row0_2); | |||
vst1_s16(outptr + 5 * 4, row1_2); | |||
vst1_s16(outptr + 6 * 4, row0_3); | |||
vst1_s16(outptr + 7 * 4, row1_3); | |||
}
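//! Editor's note: a scalar restatement (sketch only, not part of the patch)
//! of interleave_4x4_8x4_s8_s16 above: two 4x4 int8 blocks are widened to
//! int16 and their 4-element k-rows are interleaved block0/block1.
static inline void interleave_4x4_8x4_s8_s16_ref(const int8_t* inptr0,
                                                 const int8_t* inptr1,
                                                 int16_t* outptr) {
    for (int k = 0; k < 4; ++k) {
        for (int m = 0; m < 4; ++m)
            outptr[k * 8 + m] = static_cast<int16_t>(inptr0[k * 4 + m]);
        for (int m = 0; m < 4; ++m)
            outptr[k * 8 + 4 + m] = static_cast<int16_t>(inptr1[k * 4 + m]);
    }
}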
static inline void transpos_8x4_int8(const int8_t* inptr0, int8_t* outptr) { | |||
int8x8x4_t input = vld4_s8(inptr0); | |||
vst1_s8(outptr, input.val[0]); | |||
vst1_s8(outptr + 1 * 8, input.val[1]); | |||
vst1_s8(outptr + 2 * 8, input.val[2]); | |||
vst1_s8(outptr + 3 * 8, input.val[3]); | |||
} | |||
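//! Editor's note: vld4_s8 deinterleaves eight groups of four bytes, so the
//! function above transposes a row-major 8x4 int8 tile into 4 rows of 8; a
//! scalar sketch:
static inline void transpos_8x4_int8_ref(const int8_t* inptr0,
                                         int8_t* outptr) {
    for (int j = 0; j < 4; ++j)
        for (int i = 0; i < 8; ++i)
            outptr[j * 8 + i] = inptr0[i * 4 + j];
}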
static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, | |||
int count) { | |||
for (; count >= 32; count -= 32) { | |||
int8x8_t in0 = vld1_s8(inptr); | |||
int8x8_t in1 = vld1_s8(inptr + 1 * 8); | |||
int8x8_t in2 = vld1_s8(inptr + 2 * 8); | |||
int8x8_t in3 = vld1_s8(inptr + 3 * 8); | |||
vst1q_s16(outptr, vmovl_s8(in0)); | |||
vst1q_s16(outptr + 1 * 8, vmovl_s8(in1)); | |||
vst1q_s16(outptr + 2 * 8, vmovl_s8(in2)); | |||
vst1q_s16(outptr + 3 * 8, vmovl_s8(in3)); | |||
inptr += 32; | |||
outptr += 32; | |||
} | |||
for (; count >= 8; count -= 8) { | |||
int8x8_t in0 = vld1_s8(inptr); | |||
vst1q_s16(outptr, vmovl_s8(in0)); | |||
inptr += 8; | |||
outptr += 8; | |||
} | |||
for (; count > 0; --count) { | |||
*outptr++ = (int16_t)(*inptr++); | |||
} | |||
} | |||
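//! Editor's note: the NEON paths above vectorize this plain widening copy
//! (reference sketch only):
static inline void memcpy_s8_s16_ref(const int8_t* inptr, int16_t* outptr,
                                     int count) {
    for (int i = 0; i < count; ++i)
        outptr[i] = static_cast<int16_t>(inptr[i]);
}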
} // namespace armv7 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen
@@ -102,60 +102,60 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
"vld1.8 {d2}, [%[a_ptr]]!\n" | |||
"vld1.8 {d4}, [%[a_ptr]]!\n" | |||
"vld1.8 {d6}, [%[a_ptr]]!\n" | |||
"vld1.8 {d18}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q8, d16\n" | |||
"vmovl.s8 q0, d0\n" | |||
"vmovl.s8 q1, d2\n" | |||
"vmovl.s8 q2, d4\n" | |||
"vmovl.s8 q3, d6\n" | |||
"vld1.8 {d18}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q9, d18\n" | |||
"vld1.8 {d20}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q8, d0[0]\n" | |||
"vmla.s16 q5, q8, d2[0]\n" | |||
"vmla.s16 q6, q8, d4[0]\n" | |||
"vmla.s16 q7, q8, d6[0]\n" | |||
"vmovl.s8 q9, d18\n" | |||
"vld1.8 {d20}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q10, d20\n" | |||
"vld1.8 {d22}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q9, d0[1]\n" | |||
"vmla.s16 q5, q9, d2[1]\n" | |||
"vmla.s16 q6, q9, d4[1]\n" | |||
"vmla.s16 q7, q9, d6[1]\n" | |||
"vmovl.s8 q10, d20\n" | |||
"vld1.8 {d22}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q11, d22\n" | |||
"vld1.8 {d24}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q10, d0[2]\n" | |||
"vmla.s16 q5, q10, d2[2]\n" | |||
"vmla.s16 q6, q10, d4[2]\n" | |||
"vmla.s16 q7, q10, d6[2]\n" | |||
"vmovl.s8 q11, d22\n" | |||
"vld1.8 {d24}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q12, d24\n" | |||
"vld1.8 {d26}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q11, d0[3]\n" | |||
"vmla.s16 q5, q11, d2[3]\n" | |||
"vmla.s16 q6, q11, d4[3]\n" | |||
"vmla.s16 q7, q11, d6[3]\n" | |||
"vmovl.s8 q12, d24\n" | |||
"vld1.8 {d26}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q13, d26\n" | |||
"vld1.8 {d28}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q12, d1[0]\n" | |||
"vmla.s16 q5, q12, d3[0]\n" | |||
"vmla.s16 q6, q12, d5[0]\n" | |||
"vmla.s16 q7, q12, d7[0]\n" | |||
"vmovl.s8 q13, d26\n" | |||
"vld1.8 {d28}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q14, d28\n" | |||
"vld1.8 {d30}, [%[b_ptr]]!\n" | |||
"vmla.s16 q4, q13, d1[1]\n" | |||
"vmla.s16 q5, q13, d3[1]\n" | |||
"vmla.s16 q6, q13, d5[1]\n" | |||
"vmla.s16 q7, q13, d7[1]\n" | |||
"vmovl.s8 q14, d28\n" | |||
"vld1.8 {d30}, [%[b_ptr]]!\n" | |||
"vmovl.s8 q15, d30\n" | |||
"vmla.s16 q4, q14, d1[2]\n" | |||
"vmla.s16 q5, q14, d3[2]\n" | |||
"vmla.s16 q6, q14, d5[2]\n" | |||
"vmla.s16 q7, q14, d7[2]\n" | |||
"vmovl.s8 q15, d30\n" | |||
"vmla.s16 q4, q15, d1[3]\n" | |||
"vmla.s16 q5, q15, d3[3]\n" | |||
@@ -0,0 +1,406 @@ | |||
/** | |||
* \file dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/armv7/matrix_mul/asm/common.h" | |||
namespace megdnn { | |||
namespace armv7 { | |||
namespace matmul_mk4_8x8x4 { | |||
//! optimized for Cortex-A7
/**
 * Overview of register layout:
 *
 * An 8x4 cell of Lhs is stored in 16 bit in q0, q1
 * A 4x8 cell of Rhs is stored in 8 bit in d4, d6 and widened into q2, q3
 * An 8x8 block of accumulators is stored in 16 bit in q8-q15, one register
 * per Rhs column
 *
 *                 Rhs  +-----------------+
 *                      | q2[0] ... q2[7] |
 *                      +-----------------+
 *    Lhs
 *
 *  +-----+ - - - - +-------------------------+
 *  |q0[0]|         | q8[0]  q9[0] ... q15[0] |
 *  |q0[1]|         | q8[1]  q9[1] ... q15[1] |
 *  |q0[2]|         | q8[2]  q9[2] ... q15[2] |
 *  |q0[3]|         | q8[3]  q9[3] ... q15[3] |
 *  |q0[4]|         | q8[4]  q9[4] ... q15[4] |
 *  |q0[5]|         | q8[5]  q9[5] ... q15[5] |
 *  |q0[6]|         | q8[6]  q9[6] ... q15[6] |
 *  |q0[7]|         | q8[7]  q9[7] ... q15[7] |
 *  +-----+ - - - - +-------------------------+
 *
 *                        Accumulator
 */
static void kern_8x8(const int16_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int remain_n) { | |||
K /= 4; | |||
const int16_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
LDC = LDC * sizeof(int16_t); | |||
int x0 = 0; | |||
// clang-format off | |||
#define STORE_LINE(reg_index1, reg_index2) \ | |||
"cmp %[x0], #0 \n" \ | |||
"beq 101f\n" \ | |||
"vst1.16 {d" reg_index1 "}, [r0]!\n" \ | |||
"vst1.16 {d" reg_index2 "}, [r1]!\n" \ | |||
"subs %[x0], %[x0], #1\n" | |||
#define STORE_C \ | |||
"mov %[x0], %[remain_n]\n" \ | |||
STORE_LINE("16", "17") \ | |||
STORE_LINE("18", "19") \ | |||
STORE_LINE("20", "21") \ | |||
STORE_LINE("22", "23") \ | |||
STORE_LINE("24", "25") \ | |||
STORE_LINE("26", "27") \ | |||
STORE_LINE("28", "29") \ | |||
STORE_LINE("30", "31") \ | |||
"101:\n" | |||
// clang-format on | |||
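    //! Editor's note: STORE_C above writes back remain_n (at most 8) output
    //! columns, splitting each between the two MK4 row-blocks addressed by
    //! r0 and r1; an equivalent C sketch, with acc[n] mirroring q8+n:
    //
    //   for (int n = 0; n < remain_n; ++n) {
    //       for (int m = 0; m < 4; ++m) *r0++ = acc[n][m];      // rows 0-3
    //       for (int m = 0; m < 4; ++m) *r1++ = acc[n][m + 4];  // rows 4-7
    //   }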
register int16_t* outptr asm("r0") = output; | |||
asm volatile( | |||
// load accumulator C | |||
"add r1, r0, %[LDC]\n" | |||
"cmp %[is_first_k], #1\n" | |||
"beq 1f\n" | |||
"b 2f\n" | |||
"1:\n" | |||
"veor.s32 q8 , q8 , q8 \n" | |||
"veor.s32 q9 , q9 , q9 \n" | |||
"veor.s32 q10, q10, q10\n" | |||
"veor.s32 q11, q11, q11\n" | |||
"veor.s32 q12, q12, q12\n" | |||
"veor.s32 q13, q13, q13\n" | |||
"veor.s32 q14, q14, q14\n" | |||
"veor.s32 q15, q15, q15\n" | |||
"2: \n" | |||
"vld1.8 {d4}, [%[b_ptr]]!\n" | |||
"vld1.16 {d0, d1}, [%[a_ptr]]!\n" | |||
"vmovl.s8 q2, d4\n" | |||
"vld1.16 {d2, d3}, [%[a_ptr]]!\n" | |||
"vld1.8 {d6}, [%[b_ptr]]!\n" | |||
//! k0 | |||
"vmla.s16 q8, q0, d4[0]\n" | |||
"vmla.s16 q9, q0, d4[1]\n" | |||
"vmla.s16 q10, q0, d4[2]\n" | |||
"vmla.s16 q11, q0, d4[3]\n" | |||
"vmovl.s8 q3, d6\n" | |||
"vmla.s16 q12, q0, d5[0]\n" | |||
"vmla.s16 q13, q0, d5[1]\n" | |||
"vmla.s16 q14, q0, d5[2]\n" | |||
"vmla.s16 q15, q0, d5[3]\n" | |||
//! k1 | |||
"vld1.16 {d0, d1}, [%[a_ptr]]!\n" | |||
"vld1.8 {d4}, [%[b_ptr]]!\n" | |||
"vmla.s16 q8, q1, d6[0]\n" | |||
"vmla.s16 q9, q1, d6[1]\n" | |||
"vmla.s16 q10, q1, d6[2]\n" | |||
"vmla.s16 q11, q1, d6[3]\n" | |||
"vmovl.s8 q2, d4\n" | |||
"vmla.s16 q12, q1, d7[0]\n" | |||
"vmla.s16 q13, q1, d7[1]\n" | |||
"vmla.s16 q14, q1, d7[2]\n" | |||
"vmla.s16 q15, q1, d7[3]\n" | |||
//! k2 | |||
"vld1.16 {d2, d3}, [%[a_ptr]]!\n" | |||
"vld1.8 {d6}, [%[b_ptr]]!\n" | |||
"vmla.s16 q8, q0, d4[0]\n" | |||
"vmla.s16 q9, q0, d4[1]\n" | |||
"vmla.s16 q10, q0, d4[2]\n" | |||
"vmla.s16 q11, q0, d4[3]\n" | |||
"vmovl.s8 q3, d6\n" | |||
"vmla.s16 q12, q0, d5[0]\n" | |||
"vmla.s16 q13, q0, d5[1]\n" | |||
"vmla.s16 q14, q0, d5[2]\n" | |||
"vmla.s16 q15, q0, d5[3]\n" | |||
//! k3 | |||
"vmla.s16 q8, q1, d6[0]\n" | |||
"vmla.s16 q9, q1, d6[1]\n" | |||
"vmla.s16 q10, q1, d6[2]\n" | |||
"vmla.s16 q11, q1, d6[3]\n" | |||
"vmla.s16 q12, q1, d7[0]\n" | |||
"vmla.s16 q13, q1, d7[1]\n" | |||
"vmla.s16 q14, q1, d7[2]\n" | |||
"vmla.s16 q15, q1, d7[3]\n" | |||
"subs %[K], %[K], #1\n" | |||
"bne 2b\n" | |||
"3:\n" | |||
"cmp %[remain_n], #8\n" | |||
"bne 4f\n" | |||
"vstr d16, [r0]\n" | |||
"vstr d18, [r0, #8]\n" | |||
"vstr d20, [r0, #16]\n" | |||
"vstr d22, [r0, #24]\n" | |||
"vstr d24, [r0, #32]\n" | |||
"vstr d26, [r0, #40]\n" | |||
"vstr d28, [r0, #48]\n" | |||
"vstr d30, [r0, #56]\n" | |||
"vstr d17, [r1]\n" | |||
"vstr d19, [r1, #8]\n" | |||
"vstr d21, [r1, #16]\n" | |||
"vstr d23, [r1, #24]\n" | |||
"vstr d25, [r1, #32]\n" | |||
"vstr d27, [r1, #40]\n" | |||
"vstr d29, [r1, #48]\n" | |||
"vstr d31, [r1, #56]\n" | |||
"b 101f\n" | |||
"4:\n " STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: | |||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | |||
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", | |||
"d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", | |||
"d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
} | |||
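//! Editor's note: an illustrative restatement (not part of the patch) of one
//! pass of the asm main loop in kern_8x8: a is the packed int16 Lhs (8 values
//! per k), b the packed int8 Rhs (8 values per k), acc[n] mirrors q8+n.
static inline void kern_8x8_step_sketch(const int16_t* a, const int8_t* b,
                                        int16_t acc[8][8]) {
    for (int ki = 0; ki < 4; ++ki)  // k0..k3 in the asm body
        for (int n = 0; n < 8; ++n)
            for (int m = 0; m < 8; ++m)
                acc[n][m] +=
                        a[ki * 8 + m] * static_cast<int16_t>(b[ki * 8 + n]);
}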
/**
 * Overview of register layout:
 *
 * A 4x4 cell of Lhs is stored in 16 bit in d0, d2
 * A 4x8 cell of Rhs is stored in 8 bit in d4, d6 and widened into q2, q3
 * A 4x8 block of accumulators is stored in 16 bit in the even registers
 * d16, d18, ..., d30, one register per Rhs column
 *
 *                 Rhs  +-----------------+
 *                      | q2[0] ... q2[7] |
 *                      +-----------------+
 *    Lhs
 *
 *  +-----+ - - - - +---------------------------+
 *  |d0[0]|         | d16[0]  d18[0] ... d30[0] |
 *  |d0[1]|         | d16[1]  d18[1] ... d30[1] |
 *  |d0[2]|         | d16[2]  d18[2] ... d30[2] |
 *  |d0[3]|         | d16[3]  d18[3] ... d30[3] |
 *  +-----+ - - - - +---------------------------+
 *
 *                         Accumulator
 */
static void kern_4x8(const int16_t* packA, const int8_t* packB, int K, | |||
int16_t* output, int LDC, bool is_first_k, int remain_n) { | |||
K /= 4; | |||
const int16_t* a_ptr = packA; | |||
const int8_t* b_ptr = packB; | |||
LDC = LDC * sizeof(int16_t); | |||
int x0 = 0; | |||
// clang-format off | |||
#define STORE_LINE(reg_index1) \ | |||
"cmp %[x0], #0 \n" \ | |||
"beq 101f\n" \ | |||
"vst1.16 {d" reg_index1 "}, [r0]!\n" \ | |||
"subs %[x0], %[x0], #1\n" | |||
#define STORE_C \ | |||
"mov %[x0], %[remain_n]\n" \ | |||
STORE_LINE("16") \ | |||
STORE_LINE("18") \ | |||
STORE_LINE("20") \ | |||
STORE_LINE("22") \ | |||
STORE_LINE("24") \ | |||
STORE_LINE("26") \ | |||
STORE_LINE("28") \ | |||
STORE_LINE("30") \ | |||
"101:\n" | |||
// clang-format on | |||
register int16_t* outptr asm("r0") = output; | |||
asm volatile( | |||
//! load accumulator C | |||
"add r1, r0, %[LDC]\n" | |||
"cmp %[is_first_k], #1\n" | |||
"beq 1f\n" | |||
"b 2f\n" | |||
"1:\n" | |||
"veor.s32 q8 , q8 , q8 \n" | |||
"veor.s32 q9 , q9 , q9 \n" | |||
"veor.s32 q10, q10, q10\n" | |||
"veor.s32 q11, q11, q11\n" | |||
"veor.s32 q12, q12, q12\n" | |||
"veor.s32 q13, q13, q13\n" | |||
"veor.s32 q14, q14, q14\n" | |||
"veor.s32 q15, q15, q15\n" | |||
"2: \n" | |||
"vld1.8 {d4}, [%[b_ptr]]!\n" | |||
"vld1.16 {d0}, [%[a_ptr]]!\n" | |||
"vmovl.s8 q2, d4\n" | |||
"vld1.16 {d2}, [%[a_ptr]]!\n" | |||
"vld1.8 {d6}, [%[b_ptr]]!\n" | |||
//! k0 | |||
"vmla.s16 d16, d0, d4[0]\n" | |||
"vmla.s16 d18, d0, d4[1]\n" | |||
"vmla.s16 d20, d0, d4[2]\n" | |||
"vmla.s16 d22, d0, d4[3]\n" | |||
"vmovl.s8 q3, d6\n" | |||
"vmla.s16 d24, d0, d5[0]\n" | |||
"vmla.s16 d26, d0, d5[1]\n" | |||
"vmla.s16 d28, d0, d5[2]\n" | |||
"vmla.s16 d30, d0, d5[3]\n" | |||
//! k1 | |||
"vld1.16 {d0}, [%[a_ptr]]!\n" | |||
"vld1.8 {d4}, [%[b_ptr]]!\n" | |||
"vmla.s16 d16, d2, d6[0]\n" | |||
"vmla.s16 d18, d2, d6[1]\n" | |||
"vmla.s16 d20, d2, d6[2]\n" | |||
"vmla.s16 d22, d2, d6[3]\n" | |||
"vmovl.s8 q2, d4\n" | |||
"vmla.s16 d24, d2, d7[0]\n" | |||
"vmla.s16 d26, d2, d7[1]\n" | |||
"vmla.s16 d28, d2, d7[2]\n" | |||
"vmla.s16 d30, d2, d7[3]\n" | |||
//! k2 | |||
"vld1.16 {d2}, [%[a_ptr]]!\n" | |||
"vld1.8 {d6}, [%[b_ptr]]!\n" | |||
"vmla.s16 d16, d0, d4[0]\n" | |||
"vmla.s16 d18, d0, d4[1]\n" | |||
"vmla.s16 d20, d0, d4[2]\n" | |||
"vmla.s16 d22, d0, d4[3]\n" | |||
"vmovl.s8 q3, d6\n" | |||
"vmla.s16 d24, d0, d5[0]\n" | |||
"vmla.s16 d26, d0, d5[1]\n" | |||
"vmla.s16 d28, d0, d5[2]\n" | |||
"vmla.s16 d30, d0, d5[3]\n" | |||
//! k3 | |||
"vmla.s16 d16, d2, d6[0]\n" | |||
"vmla.s16 d18, d2, d6[1]\n" | |||
"vmla.s16 d20, d2, d6[2]\n" | |||
"vmla.s16 d22, d2, d6[3]\n" | |||
"vmla.s16 d24, d2, d7[0]\n" | |||
"vmla.s16 d26, d2, d7[1]\n" | |||
"vmla.s16 d28, d2, d7[2]\n" | |||
"vmla.s16 d30, d2, d7[3]\n" | |||
"subs %[K], %[K], #1\n" | |||
"bne 2b\n" | |||
"3:\n" | |||
"cmp %[remain_n], #8\n" | |||
"bne 4f\n" | |||
"vstr d16, [r0]\n" | |||
"vstr d18, [r0, #8]\n" | |||
"vstr d20, [r0, #16]\n" | |||
"vstr d22, [r0, #24]\n" | |||
"vstr d24, [r0, #32]\n" | |||
"vstr d26, [r0, #40]\n" | |||
"vstr d28, [r0, #48]\n" | |||
"vstr d30, [r0, #56]\n" | |||
"b 101f\n" | |||
"4:\n " STORE_C | |||
: [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
[x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
[outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
: | |||
: "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", | |||
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", | |||
"d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", | |||
"d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); | |||
#undef STORE_C | |||
#undef STORE_LINE | |||
} | |||
static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr, | |||
const dt_int8* inptr, int ldin, | |||
int m0, int mmax, int k0, int kmax) { | |||
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be a multiple of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
constexpr int pack_m = 8; | |||
constexpr int pack_k = 4; | |||
constexpr int pack_size = 4; | |||
const int m_size = mmax - m0; | |||
const int m_end = m_size / pack_m * pack_m + m0; | |||
const int remain_m = mmax - m_end; | |||
for (int m_idx = m0; m_idx < m_end; m_idx += pack_m) { | |||
const int8_t* inptr0 = inptr + m_idx / pack_size * ldin + k0; | |||
const int8_t* inptr1 = inptr0 + ldin; | |||
prefetch_2x(inptr0); | |||
prefetch_2x(inptr1); | |||
for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { | |||
interleave_4x4_8x4_s8_s16(inptr0, inptr1, outptr); | |||
inptr0 += pack_size * pack_size; | |||
inptr1 += pack_size * pack_size; | |||
outptr += pack_m * pack_k; | |||
} | |||
} | |||
if (remain_m > 0) { | |||
const int8_t* inptr0 = inptr + m_end / pack_size * ldin + k0; | |||
const int k_size = kmax - k0; | |||
memcpy_s8_s16(inptr0, outptr, k_size * pack_size); | |||
} | |||
} | |||
static void gemm_s8x8x16_mk4_8x8_pack_B_n(dt_int8* out, const dt_int8* in, | |||
int ldin, int n0, int nmax, int k0, | |||
int kmax) { | |||
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be a multiple of 4");
int8_t tmpbuff[32] = {0}; | |||
constexpr int pack_n = 8; | |||
constexpr int pack_size = 4; | |||
const int ksize = kmax - k0; | |||
const int nsize = nmax - n0; | |||
const int n_end = nsize / pack_n * pack_n + n0; | |||
const int remain_n = nsize % pack_n; | |||
int output_stride = ksize * pack_n; | |||
int8_t* outptr_base = out; | |||
for (int k_idx = k0; k_idx < kmax; k_idx += pack_size) { | |||
const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
prefetch_3x(inptr); | |||
auto outptr = outptr_base; | |||
for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | |||
transpos_8x4_int8(inptr, outptr); | |||
inptr += pack_n * pack_size; | |||
outptr += output_stride; | |||
} | |||
if (remain_n > 0) { | |||
memcpy(tmpbuff, inptr, sizeof(int8_t) * remain_n * pack_size); | |||
transpos_8x4_int8(tmpbuff, outptr); | |||
outptr += output_stride; | |||
} | |||
outptr_base += pack_n * pack_size; | |||
} | |||
} | |||
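//! Editor's note: for the n tail the loop above first memcpy()s the remaining
//! columns into the zero-initialized tmpbuff so the transpose always reads a
//! full 8x4 tile and the padding columns come out as zeros; a sketch of that
//! guard in isolation (name hypothetical):
static inline void pack_b_tail_sketch(const int8_t* in, int8_t* out,
                                      int remain_n) {
    int8_t tmp[32] = {0};  // zero padding for columns >= remain_n
    memcpy(tmp, in, sizeof(int8_t) * remain_n * 4);
    transpos_8x4_int8(tmp, out);
}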
} // namespace matmul_mk4_8x8x4 | |||
} // namespace armv7 | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen
@@ -6,14 +6,16 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/armv7/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/armv7/matrix_mul/asm/common.h" | |||
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h" | |||
#include "src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h" | |||
#include "src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h" | |||
#include "src/armv7/matrix_mul/int8x8x16/strategy.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
@@ -108,7 +110,7 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, | |||
} | |||
} | |||
// ===========================gemm_s8x8x16_4x4================================== | |||
// ===========================gemm_s8x8x16_4x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); | |||
void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
@@ -179,4 +181,79 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
} | |||
} | |||
// ===========================gemm_s8x8x16_mk4_8x8================================== | |||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8); | |||
void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin, | |||
int y0, int ymax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
int x0, int xmax, int k0, int kmax, | |||
bool) const { | |||
matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
kmax); | |||
} | |||
void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, | |||
size_t M, size_t N, size_t K, dt_int16* C, | |||
size_t LDC, bool is_first_k, const dt_int16*, | |||
dt_int16*) const { | |||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
C_dtype.enumv() == DTypeEnum::Int16 && | |||
A_dtype.enumv() == DTypeEnum::Int8); | |||
    megdnn_assert(is_first_k == true, "only is_first_k == true is implemented");
MEGDNN_MARK_USED_VAR(A_dtype); | |||
MEGDNN_MARK_USED_VAR(B_dtype); | |||
MEGDNN_MARK_USED_VAR(C_dtype); | |||
    megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be multiples of 4");
constexpr size_t pack_size = 4; | |||
constexpr size_t pack_m = 8; | |||
constexpr size_t pack_n = 8; | |||
const size_t remain_n = N % pack_n; | |||
const size_t remain_m = M % pack_m; | |||
size_t m_idx = 0; | |||
for (; m_idx + pack_m <= M; m_idx += pack_m) { | |||
int16_t* output = C + (m_idx / pack_size * LDC); | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
packA += pack_m * K; | |||
} | |||
if (remain_m > 0) { | |||
int16_t* output = C + (m_idx / pack_size * LDC); | |||
size_t n_idx = 0; | |||
const int8_t* cur_packB = packB; | |||
for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, pack_n); | |||
output += pack_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
if (remain_n > 0) { | |||
matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, | |||
is_first_k, remain_n); | |||
output += remain_n * pack_size; | |||
cur_packB += pack_n * K; | |||
} | |||
} | |||
} | |||
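//! Editor's note: in MK4 layout C is (M/4, N, 4), so one LDC step covers a
//! 4-row stripe; the "C + (m_idx / pack_size * LDC)" addressing above is
//! equivalent to this sketch (m_idx stays a multiple of pack_size == 4):
static inline int16_t* mk4_c_stripe_sketch(int16_t* C, size_t m_idx,
                                           size_t LDC) {
    return C + (m_idx / 4) * LDC;
}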
// vim: syntax=cpp.doxygen
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
@@ -21,6 +22,10 @@ MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true, | |||
MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, | |||
gemm_s8x8x16_4x8); | |||
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8, | |||
8, 4, false, false, | |||
gemm_s8x8x16_mk4_8x8); | |||
} // namespace matmul | |||
} // namespace armv7 | |||
} // namespace megdnn | |||
@@ -6,10 +6,11 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "src/armv7/matrix_mul/opr_impl.h" | |||
#include "src/armv7/matrix_mul/algos.h" | |||
#include "src/armv7/matrix_mul/opr_impl.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||
@@ -21,7 +22,7 @@ using namespace armv7; | |||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoF32 f32; | |||
AlgoF32MK4Pack4x12 f32_mk4_pack_4x12; | |||
AlgoF32MK4_4x8 f32_mk4_4x8; | |||
AlgoF32MK4_4x8 f32_mk4_4x8; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16K4x16x1 f16_k4x16x1; | |||
AlgoF16MK8_4x8 f16_mk8_4x8; | |||
@@ -38,6 +39,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoQuint8K4x8x8 quint8_k4x8x8; | |||
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | |||
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | |||
AlgoInt8x8x16MK4_8x8x4 int8x8x16_mk4_8x8x4; | |||
AlgoInt16x16x32K12x4x1 int16x16x32_k12x4x1; | |||
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | |||
@@ -62,8 +64,10 @@ public: | |||
all_algos.emplace_back(&int8x8x32_k4x2x16); | |||
all_algos.emplace_back(&int8x8x32_k4x8x8); | |||
all_algos.emplace_back(&quint8_k4x8x8); | |||
all_algos.emplace_back(&int8x8x16_mk4_8x8x4); | |||
all_algos.emplace_back(&int8x8x16_k4x2x16); | |||
all_algos.emplace_back(&int8x8x16_k4x8x8); | |||
all_algos.emplace_back(&int16x16x32_k12x4x1); | |||
all_algos.emplace_back(&int16x16x32_mk8_4x8); | |||
} | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/arm_common/matrix_mul/opr_impl.h" | |||
@@ -19,26 +20,28 @@ public: | |||
using arm_common::MatrixMulImpl::MatrixMulImpl; | |||
SmallVector<AlgoBase*> algo_pack() override; | |||
private: | |||
class AlgoF32; // Armv7 F32 | |||
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack | |||
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | |||
class AlgoF32Gemv; // Armv7 F32 Gemv | |||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | |||
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | |||
class AlgoF32; // Armv7 F32 | |||
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack | |||
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | |||
class AlgoF32Gemv; // Armv7 F32 Gemv | |||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | |||
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | |||
class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | |||
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | |||
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | |||
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | |||
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 | |||
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 | |||
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | |||
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | |||
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | |||
    class AlgoInt8x8x16MK4_8x8x4;   // Armv7 Int8x8x16 Kernel MK4 8x8x4
class AlgoInt16x16x32K12x4x1; // Armv7 Int16x16x32 Kernel 12x4x1 | |||
class AlgoInt16x16x32MK8_4x8; // Armv7 Int16x16x32 MK8 Format block 4x8 | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 | |||
class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | |||
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 | |||
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | |||
    class AlgoQuint8DotK4x8x4;            // Armv7 Quint8 Kernel 4x8x4
class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | |||
// DotProduct | |||
#endif | |||
@@ -10,9 +10,9 @@ | |||
* implied. | |||
*/ | |||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/fallback/conv_bias/common.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/conv1x1_dispatcher.h" | |||
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
@@ -194,10 +194,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||
PW = param.filter_meta.padding[1]; | |||
size_t SH = param.filter_meta.stride[0], | |||
SW = param.filter_meta.stride[1]; | |||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | |||
return false; | |||
if (param.src_type.enumv() != param.filter_type.enumv()) { | |||
return false; | |||
} | |||
@@ -216,6 +214,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||
    //! is identity; otherwise return false, meaning that 8x8x32 and 8x8x16
    //! do not support PostProcess
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include <unordered_map> | |||
@@ -226,10 +227,10 @@ public: | |||
PostprocessMode::FLOAT, | |||
"DefaultStrategyType::FLOAT"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
auto matmul_block = matmul_algo->get_inner_block_size(); | |||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 im2col+pack fuse | |||
//! Optimize NCHW44 3x3s2 aarch64 8X12X1 and armv7 4x12x1 | |||
//! im2col+pack fuse | |||
if ((matmul_block.m == 8 || matmul_block.m == 4) && | |||
matmul_block.n == 12 && matmul_block.k == 1 && | |||
param.filter_meta.spatial[0] == 3 && | |||
@@ -297,9 +298,21 @@ public: | |||
break; | |||
case StrategyType::INT8x8x16: | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
if (format == param::ConvBias::Format::NCHW) { | |||
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
} else if (format == param::ConvBias::Format::NCHW44) { | |||
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8, | |||
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, | |||
"DefaultStrategyType::INT8x8x16"_hash); | |||
} else { | |||
                    megdnn_throw(
                            ssprintf("Currently only support layout "
                                     "NCHW44/NCHW for im2col "
                                     "algo, but got %d\n",
                                     uint32_t(format)));
} | |||
break; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
case StrategyType::QUINT8x8x32: | |||
@@ -421,10 +434,11 @@ public: | |||
dt_int32, dt_int8, PostprocessMode::QUANTIZED, | |||
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); | |||
} else { | |||
megdnn_throw(ssprintf("Current only support layout " | |||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
megdnn_throw( | |||
ssprintf("Current only support layout " | |||
"NCHW44/NCHW/NCHW_DOT for im2col " | |||
"algo, but got %d\n", | |||
uint32_t(format))); | |||
} | |||
break; | |||
} | |||
@@ -6,11 +6,12 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace fallback { | |||
@@ -66,7 +67,8 @@ public: | |||
}; | |||
typedef void (*kern_t)(const KernParam&); | |||
typedef void (*kern_naked_t)(const KernParam& , const void* a_panel, const void *b_panel); | |||
typedef void (*kern_naked_t)(const KernParam&, const void* a_panel, | |||
const void* b_panel); | |||
class AlgoBase : public Algorithm { | |||
protected: | |||
virtual ~AlgoBase() = default; | |||
@@ -83,18 +85,19 @@ public: | |||
bool can_be_treated_as_int8x8x16(const KernSizeParam& param) const { | |||
return param.A_type.enumv() == param.B_type.enumv() && | |||
param.A_type.enumv() == DTypeEnum::Int8 && | |||
param.C_type.enumv() == DTypeEnum::Int16 && | |||
param.format == param::MatrixMul::Format::DEFAULT && | |||
param.compute_mode == Param::ComputeMode::DEFAULT; | |||
(param.A_type.enumv() == DTypeEnum::Int8 || | |||
param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||
(param.C_type.enumv() == DTypeEnum::Int16 || | |||
param.C_type.enumv() == DTypeEnum::QuantizedS16); | |||
} | |||
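    //! Editor's note: the helper above no longer pins format/compute_mode,
    //! so each backend's usable() adds its own layout test, as in the
    //! armv7/aarch64 hunks earlier in this patch (sketch):
    //
    //   bool usable(const KernSizeParam& p) const {
    //       return can_be_treated_as_int8x8x16(p) &&
    //              p.format == param::MatrixMul::Format::DEFAULT &&
    //              p.compute_mode == Param::ComputeMode::DEFAULT;
    //   }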
public: | |||
enum class AlgoSet:uint32_t { | |||
enum class AlgoSet : uint32_t { | |||
ALGO_TYPE_GEMM = 0, | |||
ALGO_TYPE_GEMV = 1, | |||
}; | |||
enum class PackMode:uint32_t { | |||
enum class PackMode : uint32_t { | |||
DEFAULT = 0, | |||
NO_PACK = 1, | |||
ONLY_PACKA = 2, | |||
@@ -489,25 +489,26 @@ void benchmark_im2col_single_algo(const char* im2col_name, Handle* handle, | |||
void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | |||
const char* im2col_name, Handle* handle, | |||
size_t kernel, size_t pack_size = 1) { | |||
auto&& args = get_winograd_benchmark_args(kernel, pack_size); | |||
size_t kernel, DType src_type, | |||
DType dst_type) { | |||
auto&& args = get_winograd_benchmark_args(kernel, 4); | |||
using namespace conv_bias; | |||
constexpr size_t RUN = 10; | |||
Benchmarker<ConvBias> benchmark(handle); | |||
benchmark.set_display(false); | |||
benchmark.set_times(RUN); | |||
benchmark.set_dtype(0, dtype::Int8()); | |||
benchmark.set_dtype(1, dtype::Int8()); | |||
benchmark.set_dtype(2, dtype::Int32()); | |||
benchmark.set_dtype(4, dtype::Int32()); | |||
benchmark.set_dtype(0, src_type); | |||
benchmark.set_dtype(1, src_type); | |||
benchmark.set_dtype(2, dst_type); | |||
benchmark.set_dtype(4, dst_type); | |||
Benchmarker<ConvBias> benchmark_im2col(handle); | |||
benchmark_im2col.set_display(false); | |||
benchmark_im2col.set_times(RUN); | |||
benchmark_im2col.set_dtype(0, dtype::Int8()); | |||
benchmark_im2col.set_dtype(1, dtype::Int8()); | |||
benchmark_im2col.set_dtype(2, dtype::Int32()); | |||
benchmark_im2col.set_dtype(4, dtype::Int32()); | |||
benchmark_im2col.set_dtype(0, src_type); | |||
benchmark_im2col.set_dtype(1, src_type); | |||
benchmark_im2col.set_dtype(2, dst_type); | |||
benchmark_im2col.set_dtype(4, dst_type); | |||
for (auto&& arg : args) { | |||
TensorLayout dst_layout; | |||
@@ -556,6 +557,7 @@ void BENCHMARK_IM2COL_NCHW44_VS_NCHW(const char* algo_name, | |||
computations / used_im2col, used / used_im2col); | |||
} | |||
} | |||
#if MEGDNN_AARCH64 | |||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | |||
printf("=========================compare " | |||
@@ -563,7 +565,17 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) { | |||
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16 \n"); | |||
BENCHMARK_IM2COL_NCHW44_VS_NCHW("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16", | |||
"IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16", | |||
handle(), 3, 4); | |||
handle(), 3, dtype::Int8(), dtype::Int32()); | |||
} | |||
#endif | |||
#if MEGDNN_ARMV7 | |||
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) { | |||
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"; | |||
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"; | |||
printf("compare %s vs %s \n", default_algo, mk4_algo); | |||
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3, | |||
dtype::Int8(), dtype::Int16()); | |||
} | |||
#endif | |||
@@ -1860,15 +1872,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||
param.format = param::ConvBias::Format::NCHW44_DOT; | |||
//! channel bias | |||
args.emplace_back(param, TensorShape{1, ic/4, h, w, 4}, | |||
TensorShape{oc/4, ic/4, kernel, kernel, 4, 4}, | |||
TensorShape{1, oc/4, 1, 1, 4}); | |||
args.emplace_back(param, TensorShape{1, ic / 4, h, w, 4}, | |||
TensorShape{oc / 4, ic / 4, kernel, kernel, 4, 4}, | |||
TensorShape{1, oc / 4, 1, 1, 4}); | |||
}; | |||
for (size_t stride : {1, 2}) | |||
for (size_t kernel : {2, 3, 5, 7}) | |||
for(size_t oc : {64}) | |||
for (size_t oc : {64}) | |||
for (NonlineMode nonline_mode : {NonlineMode::IDENTITY}) { | |||
run(oc, oc, 56, 56, kernel, kernel / 2, stride, nonline_mode); | |||
run(oc, oc, 56, 56, kernel, kernel / 2, stride, | |||
nonline_mode); | |||
} | |||
constexpr size_t RUN = 50; | |||
@@ -1880,7 +1893,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_NCHW44_DOT) { | |||
benchmark0.set_display(false); | |||
benchmark0.set_times(RUN); | |||
benchmark0.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8DIRECT_NCHW44")); | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
"ARMDOTS8DIRECT_NCHW44")); | |||
Benchmarker<ConvBias> benchmark1(handle()); | |||
benchmark1.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
@@ -2002,15 +2016,20 @@ std::vector<conv_bias::TestArg> get_conv_bias_1x1_benchmark_args( | |||
void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||
DType stype, DType matmul_dtype, DType bias_type, | |||
DType conv_dtype) { | |||
DType conv_dtype, bool is_mk4 = false) { | |||
using namespace conv_bias; | |||
int pack_size = is_mk4 ? 4 : 1; | |||
std::vector<TestArg> conv_bias_1x1_args = | |||
get_conv_bias_1x1_benchmark_args(); | |||
get_conv_bias_1x1_benchmark_args(pack_size); | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
param.transposeA = false; | |||
param.transposeB = false; | |||
if (is_mk4) { | |||
param.format = MatrixMul::Param::Format::MK4; | |||
} | |||
Benchmarker<MatrixMul> benchmark_matmul(handle); | |||
benchmark_matmul.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>(matmul_algo_name)); | |||
@@ -2038,8 +2057,8 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||
size_t OH = arg.src[2]; | |||
size_t OW = arg.src[3]; | |||
size_t OC = arg.filter[0]; | |||
size_t M = OC; | |||
size_t K = IC; | |||
size_t M = OC * pack_size; | |||
size_t K = IC * pack_size; | |||
size_t N = OH * OW; | |||
float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; | |||
@@ -2047,6 +2066,10 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||
TensorShape A, B; | |||
A = TensorShape{M, K}; | |||
B = TensorShape{K, N}; | |||
if (is_mk4) { | |||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||
B = TensorShape{K / 4, N, 4}; | |||
} | |||
auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||
@@ -2133,6 +2156,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) { | |||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | |||
benchmark_conv1x1("ARMV7_INT8X8X16_K4X2X16", handle(), dtype::Int8{}, | |||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}); | |||
benchmark_conv1x1("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), dtype::Int8{}, | |||
dtype::Int16{}, dtype::Int16{}, dtype::Int16{}, true); | |||
#endif | |||
} | |||
@@ -2145,13 +2170,13 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) { | |||
conv_param.pad_h = 0; | |||
conv_param.pad_w = 0; | |||
conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | |||
auto run = [&](size_t M, size_t K){ | |||
auto run = [&](size_t M, size_t K) { | |||
args.emplace_back(conv_param, TensorShape{1, K, 1, 1}, | |||
TensorShape{M, K, 1, 1}, TensorShape{}); | |||
}; | |||
for (size_t M : {4, 64, 1024, 4096}) | |||
for (size_t K : {128, 256, 1024, 4096}) | |||
run(M, K); | |||
run(M, K); | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
@@ -850,7 +850,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { | |||
param::ConvBias::Format::NCHW44); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_WINOGRAD_F63_4_NCHW44_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
std::vector<TestArg> args = get_nchw44_conv_bias_args({3}, 1); | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
@@ -1131,7 +1132,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) { | |||
1e-3f); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, | |||
CONV_BIAS_WINOGRAD_MK_PACKED_F32_2_WEIGHT_PREPROCESS) { | |||
using namespace conv_bias; | |||
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker( | |||
@@ -2089,6 +2091,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) { | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args_nchw44 = | |||
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true, | |||
false, false, false, false, true); | |||
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 = | |||
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false, | |||
false, false, true); | |||
#define cb(name) \ | |||
checker_conv_bias( \ | |||
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \ | |||
@@ -2098,6 +2106,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \ | |||
dtype::Int16{}, dtype::Int16{}, name); | |||
#define cb_nchw44(name) \ | |||
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); \ | |||
checker_conv_bias(args_nchw44_1x1s2, handle(), &rng, epsilon, \ | |||
dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \ | |||
dtype::Int16{}, name); | |||
#if MEGDNN_AARCH64 | |||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8"); | |||
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16"); | |||
@@ -2106,8 +2121,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) { | |||
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16"); | |||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8"); | |||
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16"); | |||
cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4"); | |||
#endif | |||
#undef cb | |||
#undef cb_nchw44 | |||
} | |||
#endif | |||
@@ -2516,19 +2534,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args( | |||
{1}, 1, true, true, true, false, false, false, false, true); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||
#define cb_nchw44(name) \ | |||
checker_conv_bias(args_nchw44, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||
#if MEGDNN_AARCH64 | |||
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | |||
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24"); | |||
#elif MEGDNN_ARMV7 | |||
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24"); | |||
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48"); | |||
cb_nchw44("CONV1x1:ARMV7_INT8X8X16_MK4_K8X8X4:48"); | |||
#endif | |||
cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); | |||
#undef cb | |||
#undef cb_nchw44 | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
@@ -6,7 +6,8 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/armv7/fixture.h" | |||
#include "test/common/benchmarker.h" | |||
@@ -51,9 +52,15 @@ TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_K4x8x8) { | |||
handle(), "ARMV7_INT8X8X16_K4X8X8"); | |||
} | |||
TEST_F(ARMV7, MATRIX_MUL_INT8x8x16_MK4_K8x8x4) { | |||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, | |||
handle(), "ARMV7_INT8X8X16_MK4_K8X8X4", | |||
param::MatrixMul::Format::MK4, 1); | |||
} | |||
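In reference terms, what check_matrix_mul verifies for the 8x8x16 family is int8 inputs producing an int16 result; the narrow accumulator is the family's defining trait. A minimal sketch for the DEFAULT format (not the framework's actual reference path):

#include <cstddef>
#include <cstdint>

// Reference int8 x int8 -> int16 GEMM. Accumulation wraps at int16 range,
// which is why the conv_bias tests above bound inputs with
// UniformIntRNG{-50, 50} rather than using the full int8 range.
void gemm_s8s8s16_ref(const int8_t* A, const int8_t* B, int16_t* C,
                      size_t M, size_t N, size_t K) {
    for (size_t m = 0; m < M; ++m)
        for (size_t n = 0; n < N; ++n) {
            int16_t acc = 0;
            for (size_t k = 0; k < K; ++k)
                acc = static_cast<int16_t>(acc + A[m * K + k] * B[k * N + n]);
            C[m * N + n] = acc;
        }
}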
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32) { | |||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||
handle(),"ARMV7_INT16X16X32_K12X4X1"); | |||
matrix_mul::check_matrix_mul(dtype::Int16{}, dtype::Int16{}, dtype::Int32{}, | |||
handle(), "ARMV7_INT16X16X32_K12X4X1"); | |||
} | |||
TEST_F(ARMV7, MATRIX_MUL_INT16x16x32_MK8) { | |||
@@ -83,7 +90,8 @@ TEST_F(ARMV7, MATRIX_MUL_SDOT) { | |||
TEST_F(ARMV7, MATRIX_MUL_UDOT) { | |||
matrix_mul::check_matrix_mul( | |||
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)), | |||
dtype::Quantized8Asymm(4.0f, static_cast<uint8_t>(10)), | |||
dtype::Quantized8Asymm(3.0f, static_cast<uint8_t>(54)), | |||
dtype::QuantizedS32(12.0f), handle(), "AARCH32_QUINT8_K4X8X4"); | |||
} | |||
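A note on the dtype choices above: under the asymmetric scheme real = scale * (q - zero_point), a product of two quantized operands carries the product of their scales, hence the output dtype QuantizedS32(12.0f) == QuantizedS32(4.0f * 3.0f) with no zero point of its own. A hypothetical single-element sketch:

#include <cstdint>

// One multiply under real = scale * (q - zp), zero points taken from the
// MATRIX_MUL_UDOT test above. The int32 result carries the combined
// scale 4.0f * 3.0f == 12.0f.
int32_t quant_mul(uint8_t qa, uint8_t qb) {
    const int32_t zp_a = 10, zp_b = 54;
    return (static_cast<int32_t>(qa) - zp_a) *
           (static_cast<int32_t>(qb) - zp_b);
}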
@@ -103,7 +111,9 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) { | |||
#if MEGDNN_WITH_BENCHMARK | |||
namespace { | |||
void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||
void run_8x8x16_benchmark( | |||
const char* algo, Handle* handle, | |||
MatrixMul::Param::Format format = MatrixMul::Param::Format::DEFAULT) { | |||
constexpr size_t RUNS = 50; | |||
param::MatrixMul param; | |||
Benchmarker<MatrixMul> benchmarker_int(handle); | |||
@@ -116,21 +126,31 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_display(false); | |||
param::MatrixMul target_param; | |||
target_param.format = format; | |||
benchmarker_int_kern_4x2x16.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>(algo)); | |||
benchmarker_int_kern_4x2x16.set_times(RUNS) | |||
.set_dtype(0, dtype::Int8{}) | |||
.set_dtype(1, dtype::Int8{}) | |||
.set_dtype(2, dtype::Int16{}) | |||
.set_param(param) | |||
.set_param(target_param) | |||
.set_display(false); | |||
Benchmarker<MatrixMul> benchmarker_float(handle); | |||
benchmarker_float.set_display(false).set_times(RUNS); | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
auto int_used = benchmarker_int.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto int_kern_used = | |||
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto int_kern_used = 1e10; | |||
if (format == MatrixMul::Param::Format::MK4) { | |||
int_kern_used = benchmarker_int_kern_4x2x16.exec( | |||
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) / | |||
RUNS; | |||
} else { | |||
int_kern_used = | |||
benchmarker_int_kern_4x2x16.exec({{M, K}, {K, N}, {}}) / | |||
RUNS; | |||
} | |||
auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
float computations = 2.f * M * K * N * 1e-6; | |||
printf("run: {%zu{M} %zu{K} %zu{N}} float: %f ms %f Gflops int: %f " | |||
@@ -145,6 +165,7 @@ void run_8x8x16_benchmark(const char* algo, Handle* handle) { | |||
}; | |||
run(256, 12 * 24, 256); | |||
run(256, 256, 256); | |||
//////////////////////// gemv ////////////////////////// | |||
for (size_t M : {8, 64, 112, 256}) { | |||
@@ -185,7 +206,8 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) { | |||
"int: %f ms %f Gflops %s: \n" | |||
"speedup(%s/arm_common, %s/float): %f\n", | |||
M, K, N, float_used, computations / float_used, int_used, | |||
computations / int_used,algo,algo,algo,float_used / int_used); | |||
computations / int_used, algo, algo, algo, | |||
float_used / int_used); | |||
}; | |||
run(256, 12 * 24, 256); | |||
@@ -231,7 +253,8 @@ void run_8x8x32_benchmark(const char* algo, Handle* handle) { | |||
"int: %f ms %f Gflops %s: \n" | |||
"speedup(%s/arm_common, %s/float): %f\n", | |||
M, K, N, float_used, computations / float_used, int_used, | |||
computations / int_used,algo,algo,algo,float_used / int_used); | |||
computations / int_used, algo, algo, algo, | |||
float_used / int_used); | |||
}; | |||
run(256, 12 * 24, 256); | |||
@@ -252,9 +275,11 @@ void run_8x8x32_quint_benchmark(Handle* handle) { | |||
benchmarker_quint8_dot.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("AARCH32_QUINT8_K4X8X4")); | |||
benchmarker_quint8_dot.set_times(RUNS) | |||
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) | |||
.set_dtype(0, | |||
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||
.set_dtype(1, | |||
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) | |||
.set_param(param) | |||
.set_display(false); | |||
@@ -262,14 +287,17 @@ void run_8x8x32_quint_benchmark(Handle* handle) { | |||
benchmarker_quint8.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("ARMV7_QUINT8_K4X8X8")); | |||
benchmarker_quint8.set_times(RUNS) | |||
.set_dtype(0, dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||
.set_dtype(1, dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||
.set_dtype(2, dtype::QuantizedS32(2.3f*3.1f)) | |||
.set_dtype(0, | |||
dtype::Quantized8Asymm(2.3f, static_cast<uint8_t>(20))) | |||
.set_dtype(1, | |||
dtype::Quantized8Asymm(3.1f, static_cast<uint8_t>(30))) | |||
.set_dtype(2, dtype::QuantizedS32(2.3f * 3.1f)) | |||
.set_param(param) | |||
.set_display(false); | |||
auto run = [&](size_t M, size_t N, size_t K) { | |||
auto dot_used = benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto dot_used = | |||
benchmarker_quint8_dot.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
auto normal_used = benchmarker_quint8.exec({{M, K}, {K, N}, {}}) / RUNS; | |||
float computations = 2.f * M * K * N * 1e-6; | |||
printf("run: {%zu{M} %zu{K} %zu{N}} dot: %f ms %f Gflops \n" | |||
@@ -351,11 +379,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x2x16) { | |||
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X2X16", handle()); | |||
} | |||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_K4x8x8) { | |||
run_8x8x16_benchmark("ARMV7_INT8X8X16_K4X8X8", handle()); | |||
} | |||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x16_MK4_K8x8x4) { | |||
run_8x8x16_benchmark("ARMV7_INT8X8X16_MK4_K8X8X4", handle(), | |||
MatrixMul::Param::Format::MK4); | |||
} | |||
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_K12x4x1) { | |||
run_16x16x32_benchmark("ARMV7_INT16X16X32_K12X4X1", handle()); | |||
} | |||
@@ -6,12 +6,13 @@ | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "test/common/matrix_mul.h" | |||
#include "src/common/utils.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
#include "test/common/matrix_mul.h" | |||
using namespace megdnn; | |||
using namespace test; | |||
@@ -39,9 +40,9 @@ std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() { | |||
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args( | |||
size_t nbase) { | |||
std::vector<TestArg> args; | |||
for (size_t m : {1, 2, 3, 4, 5}) | |||
for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11}) | |||
for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24}) | |||
for (size_t k : {1, 2, 3, 4, 5, 9, 10}) | |||
for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11}) | |||
args.emplace_back(m, n * nbase, k, 0); | |||
return args; | |||
} | |||
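The widened m and k sets above add sizes like 6, 7 and 11 that are not multiples of the pack size, so the MK-packed kernels' tail handling gets covered; n stays a multiple of nbase by construction. Illustrative use, assuming a caller passes its kernel's N granularity as nbase:

// e.g. an MK-packed kernel whose N block is 8 wide:
auto args = matrix_mul::get_matmul_mk_packed_args(8);
// yields triples such as (m = 11, n = 24 * 8 = 192, k = 11).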