GitOrigin-RevId: 4c746ef228
tags/v0.5.0
@@ -707,11 +707,11 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, | |||||
armv7::matmul::gemm_dot_quint8_4x8, | armv7::matmul::gemm_dot_quint8_4x8, | ||||
uint8_t, int32_t); | uint8_t, int32_t); | ||||
/* ======================== Int8 MK4 8x6x4 dot algo ======================== */ | |||||
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */ | |||||
namespace { | namespace { | ||||
void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | ||||
midout_iv("int8_mk4_8x6x4_dotprod_kern"_hash)) { | |||||
midout_iv("int8_mk4_8x4x4_dotprod_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | ||||
auto trA = kern_param.trA, trB = kern_param.trB; | auto trA = kern_param.trA, trB = kern_param.trB; | ||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | ||||
@@ -720,9 +720,9 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
const auto Aptr = kern_param.A<dt_int8>(), | const auto Aptr = kern_param.A<dt_int8>(), | ||||
Bptr = kern_param.B<dt_int8>(); | Bptr = kern_param.B<dt_int8>(); | ||||
auto Cptr = kern_param.C<dt_int32>(); | auto Cptr = kern_param.C<dt_int32>(); | ||||
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type, | |||||
armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, | |||||
C_type); | C_type); | ||||
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x6>( | |||||
megdnn::matmul::GemmInterleaved<armv7::matmul::gemm_mk4_dots8_8x4>( | |||||
M, N, K, trA, trB, strategy) | M, N, K, trA, trB, strategy) | ||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | ||||
kern_param.workspace_ptr); | kern_param.workspace_ptr); | ||||
@@ -731,7 +731,7 @@ void int8_mk4_8x6x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
} | } | ||||
} // namespace | } // namespace | ||||
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable( | |||||
bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( | |||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && | ||||
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | ||||
@@ -743,35 +743,35 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::usable( | |||||
!kern_size_param.trA && !kern_size_param.trB; | !kern_size_param.trA && !kern_size_param.trB; | ||||
} | } | ||||
// Workspace query for the int8 MK4 dot-product algorithm.
// Copies the problem shape (M, N, K), the transpose flags and the operand
// dtypes out of kern_size_param, builds the GEMM strategy from them and
// returns whatever GemmInterleaved<strategy>::get_workspace_size() reports.
// NOTE(review): this span shows the old (8x6) and new (8x4) revision lines
// back to back; only the strategy / class name differs between them.
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace( | |||||
size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
MIDOUT_BEGIN( | MIDOUT_BEGIN( | ||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
midout_iv("AlgoInt8x8x32MK4_8x6x4DotProd::get_workspace"_hash)) { | |||||
midout_iv("AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | auto M = kern_size_param.M, N = kern_size_param.N, | ||||
K = kern_size_param.K; | K = kern_size_param.K; | ||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | auto trA = kern_size_param.trA, trB = kern_size_param.trB; | ||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | ||||
C_type = kern_size_param.C_type; | C_type = kern_size_param.C_type; | ||||
armv7::matmul::gemm_mk4_dots8_8x6 strategy(M, N, K, A_type, B_type, | |||||
armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, | |||||
C_type); | C_type); | ||||
return megdnn::matmul::GemmInterleaved< | return megdnn::matmul::GemmInterleaved< | ||||
armv7::matmul::gemm_mk4_dots8_8x6>(M, N, K, trA, trB, | |||||
armv7::matmul::gemm_mk4_dots8_8x4>(M, N, K, trA, trB, | |||||
strategy) | strategy) | ||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
} | } | ||||
// Returns the file-local kernel entry point for this algorithm. The
// KernSizeParam argument is unnamed and ignored: the same function is
// returned for every problem size.
// NOTE(review): old (8x6) and new (8x4) revision lines appear back to back.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd::get_kern( | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( | |||||
const KernSizeParam&) const { | const KernSizeParam&) const { | ||||
return int8_mk4_8x6x4_dotprod_kern; | |||||
return int8_mk4_8x4x4_dotprod_kern; | |||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x6x4DotProd, | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32MK4_8x6x4DotProd"_hash, | |||||
armv7::matmul::gemm_mk4_dots8_8x6, int8_t, | |||||
"AlgoInt8x8x32MK4_8x4x4DotProd"_hash, | |||||
armv7::matmul::gemm_mk4_dots8_8x4, int8_t, | |||||
int32_t); | int32_t); | ||||
#endif | #endif | ||||
@@ -94,11 +94,11 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x6x4DotProd final : public AlgoBase { | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH32_INT8_MK4_8X6X4_DOTPROD"; | |||||
return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; | |||||
} | } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h | |||||
* \file dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | ||||
* | * | ||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | ||||
@@ -17,205 +17,7 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace armv7 { | namespace armv7 { | ||||
namespace matmul_mk4_dot_8x6x4 { | |||||
// Overview of register layout:
//
// A 1x6x4 cell of Rhs is stored in 8bit in q0 and d2 (low half of q1).
// A 2x1x4x4 cell of Lhs is stored in 8bit in q2, q3.
// A 2x6x4 block of accumulators is stored in 32bit in q4-q15.
//
//                              +--------+
//                          Rhs |q0[0-16]|
//                              |d2[0-8] |
//                              +--------+
//     Lhs                      |        |
//  +---------+ - - - - - - - - +--------+
//  | q2[0-16]|                 | q4[0-4]|
//  | q3[0-16]|                 | q5[0-4]|
//  +---------+                 | q6[0-4]|
//                              | q7[0-4]|
//                              | q8[0-4]|
//                              | q9[0-4]|
//                              |q10[0-4]|
//                              |q11[0-4]|
//                              |q12[0-4]|
//                              |q13[0-4]|
//                              |q14[0-4]|
//                              |q15[0-4]|
//                              +--------+
//                              Accumulator

/**
 * \brief 8x6 int8 dot-product micro kernel (two groups of 4 rows, 6 columns).
 *
 * Per depth step of 4 the kernel reads 32 bytes from \p packA (q2, q3) and
 * 24 bytes from \p packB (q0, d2) and accumulates with vsdot into twelve
 * 4-lane int32 accumulators. q4-q9 are written to \p output, q10-q15 to
 * \p output + \p LDC.
 *
 * \param packA packed Lhs panel
 * \param packB packed Rhs panel
 * \param K depth in int8 elements; it is divided by 4 and any remainder is
 *        dropped. K must be >= 4: with K / 4 == 0 the loop counter starts
 *        negative and the main loop does not terminate as intended.
 * \param output first output row group (6 x 4 int32 values, contiguous)
 * \param LDC distance between the two output row groups, in int32 elements
 * \param is_first_k when true the accumulators start from zero, otherwise
 *        they are loaded from \p output first
 */
static void kern_8x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;

    // The asm adds LDC to a byte address, so convert elements to bytes.
    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;
    // Written by the first asm instruction; initialised only so that the
    // "+r" operand does not read an indeterminate value.
    int32_t* outptr1 = nullptr;

    asm volatile(
            // load accumulator C
            "add %[outptr1], %[outptr0], %[LDC]\n"
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d8, d9}, [%[outptr0]]!\n"
            "vld1.32 {d10, d11}, [%[outptr0]]!\n"
            "vld1.32 {d12, d13}, [%[outptr0]]!\n"
            "vld1.32 {d14, d15}, [%[outptr0]]!\n"
            "vld1.32 {d16, d17}, [%[outptr0]]!\n"
            "vld1.32 {d18, d19}, [%[outptr0]]!\n"
            "vld1.32 {d20, d21}, [%[outptr1]]!\n"
            "vld1.32 {d22, d23}, [%[outptr1]]!\n"
            "vld1.32 {d24, d25}, [%[outptr1]]!\n"
            "vld1.32 {d26, d27}, [%[outptr1]]!\n"
            "vld1.32 {d28, d29}, [%[outptr1]]!\n"
            "vld1.32 {d30, d31}, [%[outptr1]]!\n"
            // The six post-incremented loads advanced each pointer by
            // 6 * 16 = 96 bytes; rewind so the stores below hit the same
            // addresses the accumulators were loaded from.
            "sub %[outptr0], %[outptr0], #96\n"
            "sub %[outptr1], %[outptr1], #96\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q4, q4, q4\n"
            "veor.s32 q5, q5, q5\n"
            "veor.s32 q6, q6, q6\n"
            "veor.s32 q7, q7, q7\n"
            "veor.s32 q8, q8, q8\n"
            "veor.s32 q9, q9, q9\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"

            "cmp %[k], #0 \n"
            "beq 4f \n"

            // main loop: two depth steps per iteration
            "3:\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"

            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K / 4 is 1 or 2 (i.e. zero iterations of
            // the main loop)
            "4:\n"

            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            // even tail: two remaining depth steps
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q2}, [%[a_ptr]]!\n"
            "vld1.s8 {q3}, [%[a_ptr]]!\n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            "b 6f\n"

            // odd tail: one remaining depth step
            "5: \n"
            "vsdot.s8 q4 , q2, d0[0]\n"
            "vsdot.s8 q5 , q2, d0[1]\n"
            "vsdot.s8 q6 , q2, d1[0]\n"
            "vst1.32 {d8, d9}, [%[outptr0]]!\n"
            "vsdot.s8 q7 , q2, d1[1]\n"
            "vsdot.s8 q8 , q2, d2[0]\n"
            "vsdot.s8 q9 , q2, d2[1]\n"
            "vst1.32 {d10, d11}, [%[outptr0]]!\n"
            "vsdot.s8 q10 , q3, d0[0]\n"
            "vsdot.s8 q11 , q3, d0[1]\n"
            "vsdot.s8 q12 , q3, d1[0]\n"
            "vst1.32 {d12, d13}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q3, d1[1]\n"
            "vsdot.s8 q14 , q3, d2[0]\n"
            "vsdot.s8 q15 , q3, d2[1]\n"

            // store the remaining accumulators
            "6: \n"
            "vst1.32 {d14, d15}, [%[outptr0]]!\n"
            "vst1.32 {d16, d17}, [%[outptr0]]!\n"
            "vst1.32 {d18, d19}, [%[outptr0]]!\n"
            "vst1.32 {d20, d21}, [%[outptr1]]!\n"
            "vst1.32 {d22, d23}, [%[outptr1]]!\n"
            "vst1.32 {d24, d25}, [%[outptr1]]!\n"
            "vst1.32 {d26, d27}, [%[outptr1]]!\n"
            "vst1.32 {d28, d29}, [%[outptr1]]!\n"
            "vst1.32 {d30, d31}, [%[outptr1]]!\n"

            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1)
            :
            // q13 is an accumulator (d26/d27) and must be declared clobbered
            // together with every other q register the kernel writes.
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
namespace matmul_mk4_dot_8x4x4 { | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
@@ -392,144 +194,6 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
// Overview of register layout: | // Overview of register layout: | ||||
// | // | ||||
// A 1x6x4 pingpong cell of Rhs is stored in 8bit in q0-q3
// (q0 + d2 for one step, q2 + d6 for the other).
// A 1x1x4x4 pingpong cell of Lhs is stored in 8bit in q4-q5.
// A 1x6x4 block of accumulators is stored in 32bit in q10-q15.
//
//                              +--------+
//                          Rhs |q0[0-16]|
//                              |d2[0-8] |
//                              +--------+
//     Lhs                      |        |
//  +---------+ - - - - - - - - +--------+
//  | q4[0-16]|                 |q10[0-4]|
//  | q5[0-16]|                 |q11[0-4]|
//  +---------+                 |q12[0-4]|
//                              |q13[0-4]|
//                              |q14[0-4]|
//                              |q15[0-4]|
//                              +--------+
//                              Accumulator

/**
 * \brief 4x6 int8 dot-product micro kernel (one group of 4 rows, 6 columns).
 *
 * Per depth step of 4 the kernel reads 16 bytes from \p packA and 24 bytes
 * from \p packB and accumulates with vsdot into six 4-lane int32
 * accumulators (q10-q15), which are written contiguously to \p output.
 *
 * \param packA packed Lhs panel
 * \param packB packed Rhs panel
 * \param K depth in int8 elements; it is divided by 4 and any remainder is
 *        dropped. K must be >= 4: with K / 4 == 0 the loop counter starts
 *        negative and the main loop does not terminate as intended.
 * \param output output row group (6 x 4 int32 values, contiguous)
 * \param LDC output leading dimension in int32 elements; it is passed to the
 *        asm statement but no instruction references it
 * \param is_first_k when true the accumulators start from zero, otherwise
 *        they are loaded from \p output first
 */
static void kern_4x6(const int8_t* packA, const int8_t* packB, int K,
                     int32_t* output, int LDC, bool is_first_k) {
    K /= 4;
    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;
    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k = (K + 1) / 2 - 1;

    LDC = LDC * sizeof(int32_t);
    int32_t* outptr0 = output;

    asm volatile(
            // load accumulator C
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            "vld1.32 {d20, d21}, [%[outptr0]]!\n"
            "vld1.32 {d22, d23}, [%[outptr0]]!\n"
            "vld1.32 {d24, d25}, [%[outptr0]]!\n"
            "vld1.32 {d26, d27}, [%[outptr0]]!\n"
            "vld1.32 {d28, d29}, [%[outptr0]]!\n"
            "vld1.32 {d30, d31}, [%[outptr0]]!\n"
            // The six post-incremented loads advanced outptr0 by
            // 6 * 16 = 96 bytes; rewind so the stores below hit the same
            // addresses the accumulators were loaded from.
            "sub %[outptr0], %[outptr0], #96\n"
            "b 2f\n"

            "1:\n"
            "veor.s32 q10, q10, q10\n"
            "veor.s32 q11, q11, q11\n"
            "veor.s32 q12, q12, q12\n"
            "veor.s32 q13, q13, q13\n"
            "veor.s32 q14, q14, q14\n"
            "veor.s32 q15, q15, q15\n"

            "2: \n"
            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"

            "cmp %[k], #0 \n"
            "beq 4f \n"

            // main loop: two depth steps per iteration, ping-ponging
            // between (q0, d2, q4) and (q2, d6, q5)
            "3:\n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"

            "vld1.s8 {q0}, [%[b_ptr]]!\n"
            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vld1.s8 {d2}, [%[b_ptr]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vld1.s8 {q4}, [%[a_ptr]]!\n"

            "subs %[k], %[k], #1\n"
            "bne 3b\n"

            // Target to use when K / 4 is 1 or 2 (i.e. zero iterations of
            // the main loop)
            "4:\n"

            "cmp %[oddk], #0 \n"
            "bne 5f \n"

            // even tail: two remaining depth steps
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vld1.s8 {q2}, [%[b_ptr]]!\n"
            "vld1.s8 {d6}, [%[b_ptr]]!\n"
            "vld1.s8 {q5}, [%[a_ptr]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"

            "vsdot.s8 q10 , q5, d4[0]\n"
            "vsdot.s8 q11 , q5, d4[1]\n"
            "vsdot.s8 q12 , q5, d5[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q5, d5[1]\n"
            "vsdot.s8 q14 , q5, d6[0]\n"
            "vsdot.s8 q15 , q5, d6[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"

            "b 6f\n"

            // odd tail: one remaining depth step
            "5: \n"
            "vsdot.s8 q10 , q4, d0[0]\n"
            "vsdot.s8 q11 , q4, d0[1]\n"
            "vsdot.s8 q12 , q4, d1[0]\n"
            "vst1.32 {d20, d21}, [%[outptr0]]!\n"
            "vsdot.s8 q13 , q4, d1[1]\n"
            "vsdot.s8 q14 , q4, d2[0]\n"
            "vsdot.s8 q15 , q4, d2[1]\n"
            "vst1.32 {d22, d23}, [%[outptr0]]!\n"

            // store the remaining accumulators
            "6: \n"
            "vst1.32 {d24, d25}, [%[outptr0]]!\n"
            "vst1.32 {d26, d27}, [%[outptr0]]!\n"
            "vst1.32 {d28, d29}, [%[outptr0]]!\n"
            "vst1.32 {d30, d31}, [%[outptr0]]!\n"

            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC),
              [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [k] "+r"(k),
              [outptr0] "+r"(outptr0)
            :
            // q13 is an accumulator (d26/d27) and must be declared clobbered.
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10",
              "q11", "q12", "q13", "q14", "q15", "cc", "memory");
}
// Overview of register layout: | |||||
// | |||||
// A 2x4x4 cell of Rhs is stored in 8bit in q1, q3. | // A 2x4x4 cell of Rhs is stored in 8bit in q1, q3. | ||||
// A 1x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7 | // A 1x2x4x4 ping-pong cell of Lhs is stored in 8bit in q5, q7 | ||||
// A 1x4x4 block of accumulators is stored in 8bit in q0, q2, q4, q6 | // A 1x4x4 block of accumulators is stored in 8bit in q0, q2, q4, q6 | ||||
@@ -671,7 +335,7 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
#undef STORE_C | #undef STORE_C | ||||
} | } | ||||
static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
static void gemm_dots8_8x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
int ldin, int y0, int ymax, int k0, | int ldin, int y0, int ymax, int k0, | ||||
int kmax) { | int kmax) { | ||||
int y = y0, y_start = y0 / 4; | int y = y0, y_start = y0 / 4; | ||||
@@ -692,14 +356,12 @@ static void gemm_dots8_8x6_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
} | } | ||||
} | } | ||||
static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
static void gemm_dots8_8x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax) { | int x0, int xmax, int k0, int kmax) { | ||||
const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
const int ksize6 = ksize * 6; | |||||
int8_t* outptr = out; | int8_t* outptr = out; | ||||
int8_t* outptr_base = out; | int8_t* outptr_base = out; | ||||
int8_t* outptr_base4 = out + ((xmax - x0) / 6) * ksize6; | |||||
int k = k0; | int k = k0; | ||||
for (; k + 3 < kmax; k += 4) { | for (; k + 3 < kmax; k += 4) { | ||||
@@ -708,13 +370,6 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
outptr = outptr_base; | outptr = outptr_base; | ||||
int x = x0; | int x = x0; | ||||
for (; x + 5 < xmax; x += 6) { | |||||
memcpy(outptr, inptr, sizeof(int8_t) * 24); | |||||
outptr += ksize6; | |||||
inptr += 24; | |||||
} | |||||
outptr = outptr_base4; | |||||
for (; x + 3 < xmax; x += 4) { | for (; x + 3 < xmax; x += 4) { | ||||
memcpy(outptr, inptr, sizeof(int8_t) * 16); | memcpy(outptr, inptr, sizeof(int8_t) * 16); | ||||
outptr += ksize4; | outptr += ksize4; | ||||
@@ -735,12 +390,11 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
*outptr++ = 0; | *outptr++ = 0; | ||||
} | } | ||||
} | } | ||||
outptr_base += 24; | |||||
outptr_base4 += 16; | |||||
outptr_base += 16; | |||||
} | } | ||||
} | } | ||||
} // namespace matmul_mk4_dot_8x6x4 | |||||
} // namespace matmul_mk4_dot_8x4x4 | |||||
} // namespace armv7 | } // namespace armv7 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | #endif |
@@ -16,7 +16,7 @@ | |||||
#include "src/armv7/matrix_mul/int8/kernel_4x8x8.h" | #include "src/armv7/matrix_mul/int8/kernel_4x8x8.h" | ||||
#include "src/armv7/matrix_mul/int8/kernel_6x8x4.h" | #include "src/armv7/matrix_mul/int8/kernel_6x8x4.h" | ||||
#include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h" | #include "src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h" | ||||
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x6x4.h" | |||||
#include "src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
@@ -254,10 +254,10 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
} | } | ||||
} | } | ||||
// ===========================gemm_mk4_dots8_8x6====================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x6); | |||||
// ===========================gemm_mk4_dots8_8x4====================================== | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x4); | |||||
void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
void gemm_mk4_dots8_8x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
int y0, int ymax, int k0, int kmax, | int y0, int ymax, int k0, int kmax, | ||||
bool transpose) const { | bool transpose) const { | ||||
megdnn_assert(!transpose, | megdnn_assert(!transpose, | ||||
@@ -266,49 +266,39 @@ void gemm_mk4_dots8_8x6::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
"mk4 format matmul with m is not times of 4."); | "mk4 format matmul with m is not times of 4."); | ||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | ||||
"mk4 format matmul with k is not times of 4."); | "mk4 format matmul with k is not times of 4."); | ||||
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_A(out, in, ldin, y0, ymax, k0, | |||||
matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_A(out, in, ldin, y0, ymax, k0, | |||||
kmax); | kmax); | ||||
} | } | ||||
// Strategy hook that packs the B operand.
// Asserts that no transpose was requested and that both k0 and kmax are
// multiples of 4, then forwards every argument except `transpose` to the
// namespace-level pack_B helper.
// NOTE(review): old (8x6) and new (8x4) revision lines appear back to back;
// they differ only in the strategy / namespace / helper names.
void gemm_mk4_dots8_8x6::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
void gemm_mk4_dots8_8x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
int x0, int xmax, int k0, int kmax, | int x0, int xmax, int k0, int kmax, | ||||
bool transpose) const { | bool transpose) const { | ||||
megdnn_assert(!transpose, | megdnn_assert(!transpose, | ||||
"matrix mul mk4 with transposed matrix B is not supported"); | "matrix mul mk4 with transposed matrix B is not supported"); | ||||
megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | ||||
"mk4 format matmul with k is not times of 4."); | "mk4 format matmul with k is not times of 4."); | ||||
matmul_mk4_dot_8x6x4::gemm_dots8_8x6_pack_B(out, in, ldin, x0, xmax, k0, | |||||
matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_B(out, in, ldin, x0, xmax, k0, | |||||
kmax); | kmax); | ||||
} | } | ||||
void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB, | |||||
void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
size_t M, size_t N, size_t K, dt_int32* C, | size_t M, size_t N, size_t K, dt_int32* C, | ||||
size_t LDC, bool is_first_k, const dt_int32* bias, | size_t LDC, bool is_first_k, const dt_int32* bias, | ||||
dt_int32* workspace) const { | dt_int32* workspace) const { | ||||
MEGDNN_MARK_USED_VAR(bias); | MEGDNN_MARK_USED_VAR(bias); | ||||
constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
constexpr size_t B_INTERLEAVE = 6; | |||||
//! K is packed to times of 4 | //! K is packed to times of 4 | ||||
K = round_up<size_t>(K, 4); | K = round_up<size_t>(K, 4); | ||||
const int K4 = K * 4; | const int K4 = K * 4; | ||||
const int K6 = K * 6; | |||||
const int K8 = K * 8; | const int K8 = K * 8; | ||||
size_t m = 0; | size_t m = 0; | ||||
for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | for (; m + A_INTERLEAVE - 1 < M; m += A_INTERLEAVE) { | ||||
int32_t* output = C + ((m >> 2) * LDC); | int32_t* output = C + ((m >> 2) * LDC); | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | |||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||||
matmul_mk4_dot_8x6x4::kern_8x6(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
output += 24; | |||||
cur_packB += K6; | |||||
} | |||||
for (; n < N; n += 4) { | |||||
for (size_t n = 0; n < N; n += 4) { | |||||
size_t n_remain = std::min<size_t>(N - n, 4); | size_t n_remain = std::min<size_t>(N - n, 4); | ||||
matmul_mk4_dot_8x6x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
matmul_mk4_dot_8x4x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, n_remain); | is_first_k, n_remain); | ||||
output += 16; | output += 16; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
@@ -318,16 +308,9 @@ void gemm_mk4_dots8_8x6::kern(const dt_int8* packA, const dt_int8* packB, | |||||
for (; m < M; m += 4) { | for (; m < M; m += 4) { | ||||
int32_t* output = C + ((m >> 2) * LDC); | int32_t* output = C + ((m >> 2) * LDC); | ||||
const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
size_t n = 0; | |||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||||
matmul_mk4_dot_8x6x4::kern_4x6(packA, cur_packB, K, output, LDC, | |||||
is_first_k); | |||||
output += 24; | |||||
cur_packB += K6; | |||||
} | |||||
for (; n < N; n += 4) { | |||||
for (size_t n = 0; n < N; n += 4) { | |||||
size_t n_remain = std::min<size_t>(N - n, 4); | size_t n_remain = std::min<size_t>(N - n, 4); | ||||
matmul_mk4_dot_8x6x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
matmul_mk4_dot_8x4x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
is_first_k, n_remain); | is_first_k, n_remain); | ||||
output += 16; | output += 16; | ||||
cur_packB += K4; | cur_packB += K4; | ||||
@@ -27,8 +27,8 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, | MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, | ||||
gemm_dots8_6x8); | gemm_dots8_6x8); | ||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 6, 4, false, false, | |||||
gemm_mk4_dots8_8x6); | |||||
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 4, 4, false, false, | |||||
gemm_mk4_dots8_8x4); | |||||
#endif | #endif | ||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace armv7 | } // namespace armv7 | ||||
@@ -29,7 +29,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
AlgoInt8x8x32K6x8x4 int8_k6x8x4; | AlgoInt8x8x32K6x8x4 int8_k6x8x4; | ||||
AlgoQuint8DotK4x8x4 quint8_k4x8x4; | AlgoQuint8DotK4x8x4 quint8_k4x8x4; | ||||
AlgoInt8x8x32MK4_8x6x4DotProd int8x8x32_mk4_8x6x4_dotprod; | |||||
AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod; | |||||
#endif | #endif | ||||
AlgoF32Gemv f32_gemv; | AlgoF32Gemv f32_gemv; | ||||
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; | AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; | ||||
@@ -57,7 +57,7 @@ public: | |||||
all_algos.emplace_back(&f16_mk8_4x8); | all_algos.emplace_back(&f16_mk8_4x8); | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
all_algos.emplace_back(&int8x8x32_mk4_8x6x4_dotprod); | |||||
all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); | |||||
all_algos.emplace_back(&int8_k6x8x4); | all_algos.emplace_back(&int8_k6x8x4); | ||||
all_algos.emplace_back(&quint8_k4x8x4); | all_algos.emplace_back(&quint8_k4x8x4); | ||||
#endif | #endif | ||||
@@ -42,7 +42,7 @@ private: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 | ||||
class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 4x8x4 | class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 4x8x4 | ||||
class AlgoInt8x8x32MK4_8x6x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x6x4 | |||||
class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 | |||||
// DotProduct | // DotProduct | ||||
#endif | #endif | ||||
class AlgoPack; | class AlgoPack; | ||||
@@ -94,7 +94,7 @@ TEST_F(ARMV7, MATRIX_MUL_MK4_DOT_INT8) { | |||||
for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) | for (size_t k : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 33, 34}) | ||||
args.emplace_back(m, n, k, 0); | args.emplace_back(m, n, k, 0); | ||||
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, | ||||
handle(), "AARCH32_INT8_MK4_8X6X4_DOTPROD", | |||||
handle(), "AARCH32_INT8_MK4_8X4X4_DOTPROD", | |||||
param::MatrixMul::Format::MK4_DOT, 1, 1e-3, | param::MatrixMul::Format::MK4_DOT, 1, 1e-3, | ||||
std::move(args)); | std::move(args)); | ||||
} | } | ||||
@@ -315,7 +315,7 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_MK4_DOT) { | |||||
param.format = MatrixMul::Param::Format::MK4_DOT; | param.format = MatrixMul::Param::Format::MK4_DOT; | ||||
Benchmarker<MatrixMul> benchmarker_mk4_dot(handle()); | Benchmarker<MatrixMul> benchmarker_mk4_dot(handle()); | ||||
benchmarker_mk4_dot.set_before_exec_callback( | benchmarker_mk4_dot.set_before_exec_callback( | ||||
AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X6X4_DOTPROD")); | |||||
AlgoChecker<MatrixMul>("AARCH32_INT8_MK4_8X4X4_DOTPROD")); | |||||
benchmarker_mk4_dot.set_param(param) | benchmarker_mk4_dot.set_param(param) | ||||
.set_dtype(0, dtype::Int8()) | .set_dtype(0, dtype::Int8()) | ||||
.set_dtype(1, dtype::Int8()) | .set_dtype(1, dtype::Int8()) | ||||