@@ -707,7 +707,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||||
"cmp %w[n_remain], #3\n" \ | "cmp %w[n_remain], #3\n" \ | ||||
"blt 22f\n" \ | "blt 22f\n" \ | ||||
"st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ | "st1 {v8.4s, v9.4s, v10.4s}, [%[output0]]\n" \ | ||||
"b 23f\n" \ | |||||
"b 24f\n" \ | |||||
"22:\n" \ | "22:\n" \ | ||||
"cmp %w[n_remain], #2\n" \ | "cmp %w[n_remain], #2\n" \ | ||||
"blt 23f\n" \ | "blt 23f\n" \ | ||||
@@ -85,6 +85,74 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, | |||||
"AlgoF32Impl"_hash, | "AlgoF32Impl"_hash, | ||||
armv7::matmul::sgemm_4x12, float, float); | armv7::matmul::sgemm_4x12, float, float); | ||||
/* ===================== F32 algo mk4 K4x12 ===================== */ | |||||
namespace { | |||||
void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | |||||
midout_iv("f32_mk4_pack_4x12_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto trA = kern_param.trA, trB = kern_param.trB; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
auto A_type = kern_param.A_type, B_type = kern_param.B_type, | |||||
C_type = kern_param.C_type; | |||||
const auto Aptr = kern_param.A<float>(), Bptr = kern_param.B<float>(); | |||||
auto Cptr = kern_param.C<float>(); | |||||
armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, | |||||
C_type); | |||||
megdnn::matmul::GemmInterleaved<armv7::matmul::sgemm_mk4_pack_4x12>( | |||||
M, N, K, trA, trB, strategy) | |||||
.execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, | |||||
kern_param.workspace_ptr); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32MK4Pack4x12::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB && kern_size_param.M % 4 == 0 && | |||||
kern_size_param.K % 4 == 0 && !kern_size_param.trA && | |||||
!kern_size_param.trB; | |||||
} | |||||
size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, | |||||
midout_iv("AlgoF32MK4Pack4x12::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | |||||
K = kern_size_param.K; | |||||
auto trA = kern_size_param.trA, trB = kern_size_param.trB; | |||||
auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, | |||||
C_type = kern_size_param.C_type; | |||||
armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, | |||||
C_type); | |||||
return megdnn::matmul::GemmInterleaved< | |||||
armv7::matmul::sgemm_mk4_pack_4x12>(M, N, K, trA, trB, | |||||
strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
//! Returns the kernel entry point; the size parameter is not consulted
//! because the same kernel serves every accepted problem.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
        const KernSizeParam& /* unused */) const {
    return f32_mk4_pack_4x12_kern;
}
// Presumably defines the members declared by MEGDNN_REG_GEMM_FUNC_FOR_IM2COL()
// in the class, bound to the sgemm_mk4_pack_4x12 strategy -- confirm against
// the macro definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
        AlgoF32MK4Pack4x12, megdnn_armv7_matmul_kern,
        "AlgoF32MK4Pack4x12"_hash, armv7::matmul::sgemm_mk4_pack_4x12, float,
        float);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
/* ===================== F16 K4x16x1 algo ===================== */ | /* ===================== F16 K4x16x1 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -29,6 +29,17 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
//! ARMv7 fp32 matrix-mul algorithm for the MK4 format, reported under the
//! name "ARMV7_F32_MK4_PACK_4X12".
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
    // Identification / static properties.
    const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
    bool is_reproducible() const override { return true; }
    void* type() const override { return sm_arm_common_algo_type; }

    // Problem-dependent queries, defined out of line.
    bool usable(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;

    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
@@ -1121,6 +1121,62 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, | |||||
} | } | ||||
/**
 * \brief Regroup 12 consecutive 4-element groups so that equal lanes become
 *        contiguous.
 *
 * Reads 48 32-bit elements from \p inptr0, viewed as 12 groups of 4, and
 * writes 48 elements to \p outptr as 4 runs of 12: run j holds element j of
 * groups 0..11, in group order.
 *
 * \param inptr0 source pointer; taken by reference and advanced past the 48
 *        elements that were read
 * \param outptr destination; passed by value, so the caller's pointer is not
 *        changed
 */
template <typename T>
static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x12_4_s only support sizeof(T) == 4");
    asm volatile(
            // Each vld4.32 reads two groups and de-interleaves them, so the
            // four d registers of a load hold lane 0..3 of that group pair.
            "vld4.32 {d0-d3}, [%[src]]!\n"
            "vld4.32 {d4-d7}, [%[src]]!\n"
            "vld4.32 {d8-d11}, [%[src]]!\n"
            "vld4.32 {d12-d15}, [%[src]]!\n"
            "vld4.32 {d16-d19}, [%[src]]!\n"
            "vld4.32 {d20-d23}, [%[src]]!\n"

            // Pair up neighbouring loads so every q register carries one
            // lane of four consecutive groups.
            "vswp d1, d4\n"
            "vswp d3, d6\n"
            "vswp d9, d12\n"
            "vswp d11, d14\n"
            "vswp d17, d20\n"
            "vswp d19, d22\n"

            // lane 0 of groups 0..11
            "vst1.32 {d0-d1}, [%[dst]]!\n"
            "vst1.32 {d8-d9}, [%[dst]]!\n"
            "vst1.32 {d16-d17}, [%[dst]]!\n"
            // lane 1 of groups 0..11
            "vst1.32 {d4-d5}, [%[dst]]!\n"
            "vst1.32 {d12-d13}, [%[dst]]!\n"
            "vst1.32 {d20-d21}, [%[dst]]!\n"
            // lane 2 of groups 0..11
            "vst1.32 {d2-d3}, [%[dst]]!\n"
            "vst1.32 {d10-d11}, [%[dst]]!\n"
            "vst1.32 {d18-d19}, [%[dst]]!\n"
            // lane 3 of groups 0..11
            "vst1.32 {d6-d7}, [%[dst]]!\n"
            "vst1.32 {d14-d15}, [%[dst]]!\n"
            "vst1.32 {d22-d23}, [%[dst]]!\n"
            : [src] "+r"(inptr0), [dst] "+r"(outptr)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "memory");
}
/**
 * \brief Transpose a block of 4 consecutive 4-element groups.
 *
 * Reads 16 32-bit elements from \p inptr0, viewed as 4 groups of 4, and
 * writes 16 elements to \p outptr as 4 runs of 4: run j holds element j of
 * groups 0..3, in group order.
 *
 * \param inptr0 source pointer; taken by reference and advanced past the 16
 *        elements that were read
 * \param outptr destination; passed by value, so the caller's pointer is not
 *        changed
 */
template <typename T>
static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
    static_assert(sizeof(T) == 4,
                  "transpose_1x4_4_s only support sizeof(T) == 4");
    asm volatile(
            // De-interleaving loads: d0..d3 / d4..d7 hold lane 0..3 of
            // groups (0,1) / (2,3).
            "vld4.32 {d0-d3}, [%[src]]!\n"
            "vld4.32 {d4-d7}, [%[src]]!\n"

            // After the swaps: q0 = lane 0, q2 = lane 1, q1 = lane 2,
            // q3 = lane 3 of groups 0..3.
            "vswp d1, d4\n"
            "vswp d3, d6\n"

            "vst1.32 {d0-d1}, [%[dst]]!\n"
            "vst1.32 {d4-d5}, [%[dst]]!\n"
            "vst1.32 {d2-d3}, [%[dst]]!\n"
            "vst1.32 {d6-d7}, [%[dst]]!\n"
            : [src] "+r"(inptr0), [dst] "+r"(outptr)
            :
            : "q0", "q1", "q2", "q3", "memory");
}
template <typename T> | |||||
static inline void transpose_4(const T*& inptr0, const T*& inptr1, | static inline void transpose_4(const T*& inptr0, const T*& inptr1, | ||||
const T*& inptr2, const T*& inptr3, T* outptr, | const T*& inptr2, const T*& inptr3, T* outptr, | ||||
int interleave, int size, T val = 0) { | int interleave, int size, T val = 0) { | ||||
@@ -18,6 +18,9 @@ namespace matmul { | |||||
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, | MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, | ||||
sgemm_4x12); | sgemm_4x12); | ||||
// Packed MK4 fp32 strategy with a 4x12 tile.
// NOTE(review): the meaning of the two bool arguments is defined by
// MEGDNN_REG_GEMM_STRATEGY -- confirm against the macro.
MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, false,
                         sgemm_mk4_pack_4x12);
MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, | MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, | ||||
sgemm_nopack_4x8); | sgemm_nopack_4x8); | ||||
@@ -0,0 +1,451 @@ | |||||
/** | |||||
* \file dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/armv7/matrix_mul/fp32/strategy.h" | |||||
#include "src/armv7/matrix_mul/asm/common.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace armv7; | |||||
using namespace armv7::matmul; | |||||
namespace { | |||||
// Overview of register layout:
//
// One k step of Rhs (1x12, 32bit) is streamed through q2-q3 (d4-d7), four
// values at a time.
// A 4x1 cell of Lhs is stored in 32bit in q0 (q1 holds the cell of the
// following k step inside the unrolled loop).
// A 4x12 block of accumulators is stored in 32bit in q4-q15; each q register
// holds the 4 rows of one output column.
//
//                    +--------+--------+- - - -+---------+
//              Rhs   |  b[0]  |  b[1]  |  ...  |  b[11]  |
//                    +--------+--------+- - - -+---------+
//
//     Lhs
//
//  +---------+ - - - +--------+--------+- - - -+---------+
//  | q0[0-3] |       | q4[0-3]| q5[0-3]|  ...  |q15[0-3] |
//  +---------+ - - - +--------+--------+- - - -+---------+
//
//                              Accumulator
/**
 * \brief 4x12 micro kernel: output (+)= packA * packB over K steps.
 *
 * Per k step the kernel consumes 4 floats of \p packA and 12 floats of
 * \p packB. The result tile is 48 consecutive floats at \p output: 12 groups
 * of 4, one group per output column.
 *
 * \param packA packed Lhs, 4 floats per k step
 * \param packB packed Rhs, 12 floats per k step
 * \param K number of k steps (the code assumes K >= 1)
 * \param output destination tile of 48 floats
 * \param LDC unused
 * \param is_first_k true: start from zero accumulators;
 *        false: accumulate onto the values already stored at \p output
 */
void kern_4x12(const float* packA, const float* packB, int K, float* output,
               int LDC, bool is_first_k) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;
    float* output0 = output;

    // The main loop handles two k steps per iteration; the last one or two
    // steps are peeled into the odd / even tail, hence the "- 1".
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;

    asm volatile(
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            // Accumulate path: reload the existing 4x12 tile into q4-q15.
            "mov r1, %[output0]\n"
            "vld1.32 {d8-d11}, [r1]!\n"
            "vld1.32 {d12-d15}, [r1]!\n"
            "vld1.32 {d16-d19}, [r1]!\n"
            "vld1.32 {d20-d23}, [r1]!\n"
            "vld1.32 {d24-d27}, [r1]!\n"
            "vld1.32 {d28-d31}, [r1]!\n"

            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            // Preload q2 AND q3 exactly like the first-k path: the loop and
            // both tails read d6-d7 before reloading them, so loading only
            // d4-d5 here would use an uninitialised q3 and leave b_ptr 16
            // bytes behind.
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "b 2f\n"

            "1:\n"
            // First-k path: zero the accumulators, interleaved with the
            // initial Lhs / Rhs loads.
            "veor.32 q4, q4, q4\n"
            "pld [%[output0]]\n"
            "veor.32 q5, q4, q4\n"
            "veor.32 q6, q4, q4\n"
            "veor.32 q7, q4, q4\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "veor.32 q8, q4, q4\n"
            "veor.32 q9, q4, q4\n"
            "veor.32 q10, q4, q4\n"
            "veor.32 q11, q4, q4\n"
            "vld1.32 {d4-d7}, [%[b_ptr]]!\n"
            "veor.32 q12, q4, q4\n"
            "veor.32 q13, q4, q4\n"
            "veor.32 q14, q4, q4\n"
            "veor.32 q15, q4, q4\n"

            "2: \n"
            "cmp %[K], #0\n"
            "beq 4f\n"

            // Main loop: two k steps per iteration (q0 then q1 as Lhs).
            "3:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"

            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "bne 3b\n"

            "4:\n"
            "cmp %[oddk], #1\n"
            "beq 5f\n"

            // Even tail: two remaining k steps, stores interleaved with the
            // last multiply-accumulates.
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"

            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q1, d4[0]\n"
            "vmla.f32 q9, q1, d4[1]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q10, q1, d5[0]\n"
            "vmla.f32 q11, q1, d5[1]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q12, q1, d6[0]\n"
            "vmla.f32 q13, q1, d6[1]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q14, q1, d7[0]\n"
            "vmla.f32 q15, q1, d7[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"
            "b 6f\n"

            // Odd tail: a single remaining k step.
            "5:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "vmla.f32 q8, q0, d6[0]\n"
            "vst1.32 {d8-d11}, [%[output0]]!\n"
            "vmla.f32 q9, q0, d6[1]\n"
            "vmla.f32 q10, q0, d7[0]\n"
            "vst1.32 {d12-d15}, [%[output0]]!\n"
            "vmla.f32 q11, q0, d7[1]\n"
            "vmla.f32 q12, q0, d4[0]\n"
            "vst1.32 {d16-d19}, [%[output0]]!\n"
            "vmla.f32 q13, q0, d4[1]\n"
            "vst1.32 {d20-d23}, [%[output0]]!\n"
            "vmla.f32 q14, q0, d5[0]\n"
            "vst1.32 {d24-d27}, [%[output0]]!\n"
            "vmla.f32 q15, q0, d5[1]\n"
            "vst1.32 {d28-d31}, [%[output0]]!\n"

            "6:\n"
            : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
              [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
              [ output0 ] "+r"(output0)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
              "q10", "q11", "q12", "q13", "q14", "q15", "r1", "cc", "memory");
}
// Overview of register layout:
//
// One k step of Rhs (1x4) is stored in 32bit in q2 (q3 holds the following
// k step inside the unrolled loop).
// A 4x1 cell of Lhs is stored in 32bit in q0 (q1 holds the following k step).
// A 4x4 block of accumulators is stored in 32bit in q4-q7; each q register
// holds the 4 rows of one output column.
//
//                    +--------+--------+--------+--------+
//              Rhs   | q2[0]  | q2[1]  | q2[2]  | q2[3]  |
//                    +--------+--------+--------+--------+
//
//     Lhs
//
//  +---------+ - - - +--------+--------+--------+--------+
//  | q0[0-3] |       | q4[0-3]| q5[0-3]| q6[0-3]| q7[0-3]|
//  +---------+ - - - +--------+--------+--------+--------+
//
//                              Accumulator
/**
 * \brief 4x4 micro kernel for the column remainder: output (+)= packA * packB.
 *
 * Per k step the kernel consumes 4 floats of \p packA and 4 floats of
 * \p packB. All four accumulator columns are always computed; \p n_remain
 * only limits how many 4-float columns are loaded from / stored to
 * \p output (>= 4: four, 3: three, 2: two, otherwise one).
 *
 * \param packA packed Lhs, 4 floats per k step
 * \param packB packed Rhs, 4 floats per k step
 * \param K number of k steps (the code assumes K >= 1)
 * \param output destination tile
 * \param LDC unused
 * \param is_first_k true: start from zero accumulators;
 *        false: accumulate onto the values already stored at \p output
 * \param n_remain number of valid output columns
 */
void kern_4x4(const float* packA, const float* packB, int K, float* output,
              int LDC, bool is_first_k, int n_remain) {
    MEGDNN_MARK_USED_VAR(LDC);
    const float* a_ptr = packA;
    const float* b_ptr = packB;

    // The main loop handles two k steps per iteration; the last one or two
    // steps are peeled into the odd / even tail, hence the "- 1".
    int oddk = (K & 1);
    K = ((K + 1) / 2) - 1;

// The markers need the space after "//" to be recognised by clang-format;
// without it the macro bodies below are not protected from reformatting.
// clang-format off
// LOAD_C: read n_remain columns (4 floats each) from r1 into q4..q7.
#define LOAD_C                                    \
    "cmp %[n_remain], #4\n"                       \
    "blt 11f\n"                                   \
    "vld1.32 {d8-d11}, [r1]!\n"                   \
    "vld1.32 {d12-d15}, [r1]!\n"                  \
    "b 14f\n"                                     \
    "11:\n"                                       \
    "cmp %[n_remain], #3\n"                       \
    "blt 12f\n"                                   \
    "vld1.32 {d8-d11}, [r1]!\n"                   \
    "vld1.32 {d12-d13}, [r1]!\n"                  \
    "b 14f\n"                                     \
    "12:\n"                                       \
    "cmp %[n_remain], #2\n"                       \
    "blt 13f\n"                                   \
    "vld1.32 {d8-d11}, [r1]\n"                    \
    "b 14f\n"                                     \
    "13:\n"                                       \
    "vld1.32 {d8-d9}, [r1]\n"                     \
    "14:\n"

// STORE_C: write n_remain columns (4 floats each) from q4..q7 to output.
#define STORE_C                                   \
    "cmp %[n_remain], #4\n"                       \
    "blt 21f\n"                                   \
    "vst1.32 {d8-d11}, [%[output]]!\n"            \
    "vst1.32 {d12-d15}, [%[output]]!\n"           \
    "b 24f\n"                                     \
    "21:\n"                                       \
    "cmp %[n_remain], #3\n"                       \
    "blt 22f\n"                                   \
    "vst1.32 {d8-d11}, [%[output]]!\n"            \
    "vst1.32 {d12-d13}, [%[output]]!\n"           \
    "b 24f\n"                                     \
    "22:\n"                                       \
    "cmp %[n_remain], #2\n"                       \
    "blt 23f\n"                                   \
    "vst1.32 {d8-d11}, [%[output]]!\n"            \
    "b 24f\n"                                     \
    "23:\n"                                       \
    "vst1.32 {d8-d9}, [%[output]]!\n"             \
    "24:\n"
// clang-format on

    asm volatile(
            "cmp %[is_first_k], #1\n"
            "beq 1f\n"
            // Accumulate path: reload the valid part of the existing tile.
            "mov r1, %[output]\n" LOAD_C

            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "b 2f\n"

            "1:\n"
            // First-k path: zero the accumulators, interleaved with the
            // initial Lhs / Rhs loads.
            "veor.32 q4, q4, q4\n"
            "pld [%[output]]\n"
            "veor.32 q5, q4, q4\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "veor.32 q6, q4, q4\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"
            "veor.32 q7, q4, q4\n"

            "2: \n"
            "cmp %[K], #0\n"
            "beq 4f\n"

            // Main loop: two k steps per iteration (q0/q2 then q1/q3).
            "3:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"
            "vld1.32 {d4-d5}, [%[b_ptr]]!\n"

            "vmla.f32 q4, q1, d6[0]\n"
            "subs %[K], %[K], #1\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vld1.32 {d0-d1}, [%[a_ptr]]!\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "bne 3b\n"

            "4:\n"
            "cmp %[oddk], #1\n"
            "beq 5f\n"

            // Even tail: two remaining k steps.
            "vmla.f32 q4, q0, d4[0]\n"
            "vld1.32 {d2-d3}, [%[a_ptr]]!\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vld1.32 {d6-d7}, [%[b_ptr]]!\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"

            "vmla.f32 q4, q1, d6[0]\n"
            "vmla.f32 q5, q1, d6[1]\n"
            "vmla.f32 q6, q1, d7[0]\n"
            "vmla.f32 q7, q1, d7[1]\n"
            "b 6f\n"

            // Odd tail: a single remaining k step.
            "5:\n"
            "vmla.f32 q4, q0, d4[0]\n"
            "vmla.f32 q5, q0, d4[1]\n"
            "vmla.f32 q6, q0, d5[0]\n"
            "vmla.f32 q7, q0, d5[1]\n"

            "6:\n" STORE_C

            : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K),
              [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk),
              [ output ] "+r"(output), [ n_remain ] "+r"(n_remain)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc",
              "memory");
#undef LOAD_C
#undef STORE_C
}
} // namespace | |||||
MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12); | |||||
//! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy | |||||
//! the weight | |||||
void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0, | |||||
int ymax, int k0, int kmax, bool) const { | |||||
megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||||
constexpr int PACK_C_SIZE = 4; | |||||
size_t cp_length = (kmax - k0) * PACK_C_SIZE; | |||||
for (int m = y0; m < ymax; m += 4) { | |||||
const float* src = in + (m / PACK_C_SIZE) * ldin + k0 * PACK_C_SIZE; | |||||
memcpy(out, src, cp_length * sizeof(float)); | |||||
out += cp_length; | |||||
} | |||||
} | |||||
void sgemm_mk4_pack_4x12::pack_B(float* out, const float* in, int ldin, int x0, | |||||
int xmax, int k0, int kmax, | |||||
bool transpose_B) const { | |||||
megdnn_assert(!transpose_B); | |||||
megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||||
float tmpbuff[16] = {0.0f}; | |||||
constexpr int PACK_C_SIZE = 4; | |||||
int ksize = kmax - k0; | |||||
int ksize12 = ksize * 12; | |||||
int ksize4 = (ksize << 2); | |||||
float* outptr_base = out; | |||||
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12; | |||||
int k = k0; | |||||
for (; k + 3 < kmax; k += 4) { | |||||
const float* inptr = in + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE; | |||||
prefetch_3x(inptr); | |||||
int x = x0; | |||||
auto outptr = outptr_base; | |||||
for (; x + 12 <= xmax; x += 12) { | |||||
auto outptr_interleave = outptr; | |||||
transpose_1x12_4_s(inptr, outptr_interleave); | |||||
outptr += ksize12; | |||||
} | |||||
outptr = outptr_base4; | |||||
for (; x + 4 <= xmax; x += 4) { | |||||
auto outptr_interleave = outptr; | |||||
transpose_1x4_4_s(inptr, outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
if (x < xmax) { | |||||
memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||||
auto outptr_interleave = outptr; | |||||
const float* tmp_ptr = &tmpbuff[0]; | |||||
transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | |||||
outptr += ksize4; | |||||
} | |||||
outptr_base += 12 * PACK_C_SIZE; | |||||
outptr_base4 += 4 * PACK_C_SIZE; | |||||
} | |||||
} | |||||
void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB, size_t M, | |||||
size_t N, size_t K, float* C, size_t LDC, | |||||
bool is_first_k, const float*, float*) const { | |||||
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
A_dtype.enumv() == C_dtype.enumv() && | |||||
A_dtype.enumv() == DTypeEnum::Float32); | |||||
constexpr int PACK_C_SIZE = 4; | |||||
constexpr size_t A_INTERLEAVE = 4; | |||||
constexpr size_t B_INTERLEAVE = 12; | |||||
const int K12 = K * 12; | |||||
const int K4 = K * 4; | |||||
size_t m = 0; | |||||
for (; m < M; m += A_INTERLEAVE) { | |||||
float* output = C + (m / 4 * LDC); | |||||
size_t n = 0; | |||||
const float* cur_packB = packB; | |||||
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||||
kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
output += PACK_C_SIZE * B_INTERLEAVE; | |||||
cur_packB += K12; | |||||
} | |||||
for (; n < N; n += 4) { | |||||
kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
std::min<size_t>(N - n, 4)); | |||||
output += PACK_C_SIZE * 4; | |||||
cur_packB += K4; | |||||
} | |||||
packA += K4; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -20,6 +20,7 @@ using namespace armv7; | |||||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | class MatrixMulImpl::AlgoPack : NonCopyableObj { | ||||
AlgoF32 f32; | AlgoF32 f32; | ||||
AlgoF32MK4Pack4x12 f32_mk4_pack_4x12; | |||||
AlgoF32MK4_4x8 f32_mk4_4x8; | AlgoF32MK4_4x8 f32_mk4_4x8; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16K4x16x1 f16_k4x16x1; | AlgoF16K4x16x1 f16_k4x16x1; | ||||
@@ -48,6 +49,7 @@ public: | |||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&f32_gemv); | all_algos.emplace_back(&f32_gemv); | ||||
all_algos.emplace_back(&f32); | all_algos.emplace_back(&f32); | ||||
all_algos.emplace_back(&f32_mk4_pack_4x12); | |||||
all_algos.emplace_back(&f32_mk4_4x8); | all_algos.emplace_back(&f32_mk4_4x8); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
all_algos.emplace_back(&f16_k4x16x1); | all_algos.emplace_back(&f16_k4x16x1); | ||||
@@ -21,6 +21,7 @@ public: | |||||
SmallVector<AlgoBase*> algo_pack() override; | SmallVector<AlgoBase*> algo_pack() override; | ||||
private: | private: | ||||
class AlgoF32; // Armv7 F32 | class AlgoF32; // Armv7 F32 | ||||
class AlgoF32MK4Pack4x12; // Armv7 F32 Kernel 4x12 with pack | |||||
class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | class AlgoF32MK4_4x8; // Armv7 F32 Kernel 4x8 nopack | ||||
class AlgoF32Gemv; // Armv7 F32 Gemv | class AlgoF32Gemv; // Armv7 F32 Gemv | ||||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | ||||
@@ -1287,23 +1287,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { | |||||
#undef cb | #undef cb | ||||
} | } | ||||
#if MEGDNN_AARCH64 | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({2, 4, 7}, 1); | get_nchw44_conv_bias_args({2, 4, 7}, 1); | ||||
#if MEGDNN_AARCH64 | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
} | |||||
#elif MEGDNN_ARMV7 | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"); | |||||
#endif | #endif | ||||
} | |||||
#if MEGDNN_AARCH64 | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({3, 5, 6}, 2); | get_nchw44_conv_bias_args({3, 5, 6}, 2); | ||||
#if MEGDNN_AARCH64 | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); | ||||
} | |||||
#elif MEGDNN_ARMV7 | |||||
check_conv_bias(args, handle(), "IM2COLMATMUL:ARMV7_F32_MK4_PACK_4X12"); | |||||
#endif | #endif | ||||
} | |||||
/***************************** Conv1x1 Algo Test ***********************/ | /***************************** Conv1x1 Algo Test ***********************/ | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | ||||
@@ -1316,14 +1320,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||||
#endif | #endif | ||||
} | } | ||||
#if MEGDNN_AARCH64 | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
std::vector<conv_bias::TestArg> args = | std::vector<conv_bias::TestArg> args = | ||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | get_nchw44_conv_bias_args({1}, 1, true, false, false); | ||||
#if MEGDNN_AARCH64 | |||||
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24"); | ||||
} | |||||
#elif MEGDNN_ARMV7 | |||||
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24"); | |||||
#endif | #endif | ||||
} | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -28,6 +28,12 @@ TEST_F(ARMV7, MATRIX_MUL_MK4) { | |||||
"ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); | "ARMV7_F32_MK4_4x8", param::MatrixMul::Format::MK4, 4); | ||||
} | } | ||||
// Checks the packed MK4 4x12 fp32 algorithm via check_matrix_mul.
TEST_F(ARMV7, MATRIX_MUL_PACK_MK4) {
    const auto format = param::MatrixMul::Format::MK4;
    matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
                                 dtype::Float32{}, handle(),
                                 "ARMV7_F32_MK4_PACK_4X12", format, 1);
}
TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) { | TEST_F(ARMV7, MATRIX_MUL_MK4_INT8) { | ||||
std::vector<matrix_mul::TestArg> args; | std::vector<matrix_mul::TestArg> args; | ||||
for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11}) | for (size_t m : {1, 2, 3, 4, 5, 7, 10, 11}) | ||||
@@ -349,6 +355,15 @@ TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_MK4) { | |||||
dtype::Float32{}); | dtype::Float32{}); | ||||
} | } | ||||
// Benchmarks the packed MK4 4x12 fp32 algorithm via benchmark_with_contrast.
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_PACK_MK4) {
    const auto format = param::MatrixMul::Format::MK4;
    auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8);
    matrix_mul::benchmark_with_contrast(
            handle(), args, dtype::Float32{}, dtype::Float32{},
            dtype::Float32{}, "ARMV7_F32_MK4_PACK_4X12", format,
            dtype::Float32{}, dtype::Float32{}, dtype::Float32{});
}
TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { | TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT16x16x32_MK8) { | ||||
auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4); | auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(4); | ||||
matrix_mul::benchmark_with_contrast( | matrix_mul::benchmark_with_contrast( | ||||