Browse Source

feat(dnn/arm): add fp32 asm gemm for a53 a55 and i8i8i16 gemm for a72 a53

GitOrigin-RevId: a049c33f2b
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
6e70fa7a11
27 changed files with 9380 additions and 1964 deletions
  1. +157
    -0
      dnn/src/aarch64/matrix_mul/algos.cpp
  2. +33
    -5
      dnn/src/aarch64/matrix_mul/algos.h
  3. +144
    -15
      dnn/src/aarch64/matrix_mul/asm/common.h
  4. +1021
    -1010
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
  5. +1331
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h
  6. +1170
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h
  7. +827
    -796
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
  8. +1260
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h
  9. +1160
    -0
      dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h
  10. +112
    -62
      dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
  11. +2
    -1
      dnn/src/aarch64/matrix_mul/fp32/strategy.h
  12. +1265
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h
  13. +387
    -0
      dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h
  14. +161
    -1
      dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp
  15. +7
    -1
      dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h
  16. +7
    -2
      dnn/src/aarch64/matrix_mul/opr_impl.cpp
  17. +18
    -15
      dnn/src/aarch64/matrix_mul/opr_impl.h
  18. +0
    -1
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  19. +2
    -2
      dnn/src/common/cpuinfo_arch_vendor.cpp
  20. +2
    -2
      dnn/src/common/cpuinfo_arch_vendor.h
  21. +106
    -5
      dnn/test/aarch64/matrix_mul.cpp
  22. +8
    -2
      dnn/test/arm_common/conv_bias.cpp
  23. +132
    -35
      dnn/test/arm_common/conv_bias_multi_thread.cpp
  24. +2
    -5
      dnn/test/arm_common/cpuinfo.cpp
  25. +17
    -0
      dnn/test/arm_common/cpuinfo_help.cpp
  26. +47
    -0
      dnn/test/arm_common/cpuinfo_help.h
  27. +2
    -4
      dnn/test/x86/cpuinfo.cpp

+ 157
- 0
dnn/src/aarch64/matrix_mul/algos.cpp View File

@@ -23,6 +23,9 @@
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"

#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include "midout.h"

MIDOUT_DECL(megdnn_aarch64_matmul_kern)
@@ -80,6 +83,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
}
MIDOUT_END();
};

return f32_kern_8x12;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
@@ -837,6 +841,159 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
int16_t);

/* ===================== Int8x8x16 K16x12x4 algo ===================== */
namespace {
//! Kernel entry for the int8x8x16 MK4 GEMM with the 16x12x4 micro-kernel
//! tuned for Cortex-A53. Unpacks the runtime parameters and drives the
//! generic interleaved-GEMM executor with the a53 strategy.
void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        //! int8 inputs accumulated into an int16 output matrix.
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();

        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        //! GemmInterleaved packs A/B into the workspace and tiles the kernel.
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                             strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x16(kern_size_param) &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}

//! Prefer this algorithm only on in-order little cores (Cortex-A53/A55),
//! which the 16x12x4 kernel is scheduled for. Without cpuinfo the core
//! cannot be identified, so never prefer it.
bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred(
        const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    return false;
#else
    const auto uarch = cpuinfo_get_current_core()->uarch;
    return uarch == cpuinfo_uarch_cortex_a53 ||
           uarch == cpuinfo_uarch_cortex_a55;
#endif
}

//! Size (bytes) of the packing workspace the interleaved-GEMM driver needs
//! for the a53-tuned int8x8x16 MK4 strategy. No kernel code runs here; the
//! strategy object only carries shape and dtype information.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type,
                                                             B_type, C_type);
        //! The return statement lives inside the MIDOUT_BEGIN block; when
        //! midout strips this region the function compiles to a stub.
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB,
                                                           strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

//! The kernel re-reads all sizes from its KernParam at run time, so the
//! KernSizeParam is unused here.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_16x12x4_kern;
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_16x12x4Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t);

/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
namespace {
//! Kernel entry for the int8x8x16 MK4 GEMM with the 4x4x8 micro-kernel
//! tuned for Cortex-A72. Unpacks the runtime parameters and drives the
//! generic interleaved-GEMM executor with the a72 strategy.
void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto trA = kern_param.trA, trB = kern_param.trB;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        auto A_type = kern_param.A_type, B_type = kern_param.B_type,
             C_type = kern_param.C_type;
        //! int8 inputs accumulated into an int16 output matrix.
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int16>();

        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        //! GemmInterleaved packs A/B into the workspace and tiles the kernel.
        megdnn::matmul::GemmInterleaved<
                aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                           strategy)
                .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                         kern_param.workspace_ptr);
    }
    MIDOUT_END();
}
} // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable(
const KernSizeParam& kern_size_param) const {
return can_be_treated_as_int8x8x16(kern_size_param) &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
!kern_size_param.trA && !kern_size_param.trB &&
kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0;
}

//! Prefer this algorithm on cores other than the in-order little cores
//! (Cortex-A53/A55); those are served by the 16x12x4 kernel instead.
//! Without cpuinfo the core cannot be identified, so never prefer it.
bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred(
        const KernSizeParam&) const {
#if !MGB_ENABLE_CPUINFO
    return false;
#else
    const auto uarch = cpuinfo_get_current_core()->uarch;
    return uarch != cpuinfo_uarch_cortex_a53 &&
           uarch != cpuinfo_uarch_cortex_a55;
#endif
}

//! Size (bytes) of the packing workspace the interleaved-GEMM driver needs
//! for the a72-tuned int8x8x16 MK4 strategy. No kernel code runs here; the
//! strategy object only carries shape and dtype information.
size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
        const KernSizeParam& kern_size_param) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        auto trA = kern_size_param.trA, trB = kern_size_param.trB;
        auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
             C_type = kern_size_param.C_type;
        aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type,
                                                           B_type, C_type);
        //! The return statement lives inside the MIDOUT_BEGIN block; when
        //! midout strips this region the function compiles to a stub.
        return megdnn::matmul::GemmInterleaved<
                       matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB,
                                                         strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
}

//! The kernel re-reads all sizes from its KernParam at run time, so the
//! KernSizeParam is unused here.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
        const KernSizeParam&) const {
    return int8x8x16_mk4_4x4x8_kern;
}

MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
megdnn_aarch64_matmul_kern,
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash,
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
int8_t, int16_t);

/* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace {
void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) {


+ 33
- 5
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once
@@ -121,12 +122,9 @@ public:
#else

class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {

public:
bool is_reproducible() const override { return true; }
const char* name() const override {
return "AARCH64_INT8X8X32_MK4_4X4X16";
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
@@ -188,6 +186,36 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

//! int8x8x16 MK4-format GEMM with a 16x12x4 micro-kernel, tuned for
//! in-order little cores (Cortex-A53/A55) — see preferred().
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
    //! Integer accumulation: same inputs always give identical outputs.
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X16_MK4_16X12X4";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    //! Uses the default pack-A/pack-B path of the im2col framework.
    PackMode packmode() const override { return PackMode::DEFAULT; }

    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

//! int8x8x16 MK4-format GEMM with a 4x4x8 micro-kernel, tuned for
//! out-of-order cores such as Cortex-A72 — see preferred().
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
    //! Integer accumulation: same inputs always give identical outputs.
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    //! Uses the default pack-A/pack-B path of the im2col framework.
    PackMode packmode() const override { return PackMode::DEFAULT; }

    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }


+ 144
- 15
dnn/src/aarch64/matrix_mul/asm/common.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <cmath>
@@ -993,8 +994,8 @@ static inline void interleave_4x1_4_s(const int32_t*& inptr0,

template <typename T>
static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1,
const T*& inptr2, const T*& inptr3,
T*& outptr) {
const T*& inptr2, const T*& inptr3,
T*& outptr) {
static_assert(sizeof(T) == 4, "only support size == 4");
asm volatile(
"ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n"
@@ -1140,8 +1141,8 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1,
"stp q2, q6, [%[outptr], #64]\n"
"stp q3, q7, [%[outptr], #96]\n"

: [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1),
[ outptr ] "+r"(outptr)
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory");
}
@@ -1153,7 +1154,7 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) {
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"

: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
@@ -1550,7 +1551,7 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) {
"stp q2, q6, [%[outptr], #96] \n"
"stp q10, q3, [%[outptr], #128] \n"
"stp q7, q11, [%[outptr], #160] \n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "memory");
@@ -1564,7 +1565,7 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) {
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]]\n"
: [ inptr0 ] "+r"(inptr0), [ outptr ] "+r"(outptr)
: [inptr0] "+r"(inptr0), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "memory");
}
@@ -1681,13 +1682,12 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1,
"st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n"
"st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n"
"st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n"
: [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1),
[inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3),
[inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7),
[inptr8] "+r"(inptr8), [inptr9] "+r"(inptr9),
[inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11),
[outptr] "+r"(outptr)
:
[inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2),
[inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5),
[inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8),
[inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10),
[inptr11] "+r"(inptr11), [outptr] "+r"(outptr)
:
: "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
@@ -1972,6 +1972,135 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr,
: "v0", "v1", "v2", "v3", "v4", "memory");
}

static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0,
const int8_t* inptr1,
const int8_t* inptr2,
const int8_t* inptr3,
int16_t* outptr) {
int8x16_t row0 = vld1q_s8(inptr0);
int16x8_t row0_01 = vmovl_low_s8(row0);
int16x8_t row0_23 = vmovl_high_s8(row0);
int16x4_t row0_0 = vget_low_s16(row0_01);
int16x4_t row0_1 = vget_high_s16(row0_01);
int16x4_t row0_2 = vget_low_s16(row0_23);
int16x4_t row0_3 = vget_high_s16(row0_23);

int8x16_t row1 = vld1q_s8(inptr1);
int16x8_t row1_01 = vmovl_low_s8(row1);
int16x8_t row1_23 = vmovl_high_s8(row1);
int16x4_t row1_0 = vget_low_s16(row1_01);
int16x4_t row1_1 = vget_high_s16(row1_01);
int16x4_t row1_2 = vget_low_s16(row1_23);
int16x4_t row1_3 = vget_high_s16(row1_23);

int8x16_t row2 = vld1q_s8(inptr2);
int16x8_t row2_01 = vmovl_low_s8(row2);
int16x8_t row2_23 = vmovl_high_s8(row2);
int16x4_t row2_0 = vget_low_s16(row2_01);
int16x4_t row2_1 = vget_high_s16(row2_01);
int16x4_t row2_2 = vget_low_s16(row2_23);
int16x4_t row2_3 = vget_high_s16(row2_23);

int8x16_t row3 = vld1q_s8(inptr3);
int16x8_t row3_01 = vmovl_low_s8(row3);
int16x8_t row3_23 = vmovl_high_s8(row3);
int16x4_t row3_0 = vget_low_s16(row3_01);
int16x4_t row3_1 = vget_high_s16(row3_01);
int16x4_t row3_2 = vget_low_s16(row3_23);
int16x4_t row3_3 = vget_high_s16(row3_23);

vst1_s16(outptr, row0_0);
vst1_s16(outptr + 1 * 4, row1_0);
vst1_s16(outptr + 2 * 4, row2_0);
vst1_s16(outptr + 3 * 4, row3_0);
vst1_s16(outptr + 4 * 4, row0_1);
vst1_s16(outptr + 5 * 4, row1_1);
vst1_s16(outptr + 6 * 4, row2_1);
vst1_s16(outptr + 7 * 4, row3_1);
vst1_s16(outptr + 8 * 4, row0_2);
vst1_s16(outptr + 9 * 4, row1_2);
vst1_s16(outptr + 10 * 4, row2_2);
vst1_s16(outptr + 11 * 4, row3_2);
vst1_s16(outptr + 12 * 4, row0_3);
vst1_s16(outptr + 13 * 4, row1_3);
vst1_s16(outptr + 14 * 4, row2_3);
vst1_s16(outptr + 15 * 4, row3_3);
};
static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0,
const int8_t* inptr1,
int16_t* outptr) {
int8x16_t row0 = vld1q_s8(inptr0);
int16x8_t row0_01 = vmovl_low_s8(row0);
int16x8_t row0_23 = vmovl_high_s8(row0);
int16x4_t row0_0 = vget_low_s16(row0_01);
int16x4_t row0_1 = vget_high_s16(row0_01);
int16x4_t row0_2 = vget_low_s16(row0_23);
int16x4_t row0_3 = vget_high_s16(row0_23);

int8x16_t row1 = vld1q_s8(inptr1);
int16x8_t row1_01 = vmovl_low_s8(row1);
int16x8_t row1_23 = vmovl_high_s8(row1);
int16x4_t row1_0 = vget_low_s16(row1_01);
int16x4_t row1_1 = vget_high_s16(row1_01);
int16x4_t row1_2 = vget_low_s16(row1_23);
int16x4_t row1_3 = vget_high_s16(row1_23);

vst1_s16(outptr, row0_0);
vst1_s16(outptr + 1 * 4, row1_0);
vst1_s16(outptr + 2 * 4, row0_1);
vst1_s16(outptr + 3 * 4, row1_1);
vst1_s16(outptr + 4 * 4, row0_2);
vst1_s16(outptr + 5 * 4, row1_2);
vst1_s16(outptr + 6 * 4, row0_3);
vst1_s16(outptr + 7 * 4, row1_3);
};

static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr,
int count) {
for (; count >= 32; count -= 32) {
int8x8_t in0 = vld1_s8(inptr);
int8x8_t in1 = vld1_s8(inptr + 1 * 8);
int8x8_t in2 = vld1_s8(inptr + 2 * 8);
int8x8_t in3 = vld1_s8(inptr + 3 * 8);
vst1q_s16(outptr, vmovl_s8(in0));
vst1q_s16(outptr + 1 * 8, vmovl_s8(in1));
vst1q_s16(outptr + 2 * 8, vmovl_s8(in2));
vst1q_s16(outptr + 3 * 8, vmovl_s8(in3));
inptr += 32;
outptr += 32;
}
for (; count >= 8; count -= 8) {
int8x8_t in0 = vld1_s8(inptr);
vst1q_s16(outptr, vmovl_s8(in0));
inptr += 8;
outptr += 8;
}
for (; count > 0; --count) {
*outptr++ = (int16_t)(*inptr++);
}
}

//! Transpose a 12x4 int8 tile: the first 32 bytes are de-interleaved with a
//! strided load (vld4) and the last 16 bytes are permuted with a table
//! lookup, then the pieces are stitched into four 12-byte output rows.
//!
//! Fix vs. previous revision: `vtbl` was a function-local `static` with a
//! non-constant initializer (vld1q_u8), so every call of this hot-path
//! helper paid a thread-safe static-init guard. Only the byte table stays
//! static; the vector is loaded per call (a single ldr q instruction).
static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) {
    //! Shuffle pattern that transposes a 4x4 int8 block.
    static const uint8_t src_idx_buffer[16] = {0, 4, 8, 12, 1, 5, 9, 13,
                                               2, 6, 10, 14, 3, 7, 11, 15};
    const uint8x16_t vtbl = vld1q_u8(&src_idx_buffer[0]);
    int8x8x4_t input = vld4_s8(inptr0);
    int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl);

    //! Each output row = 8 bytes from the vld4 lanes + 4 bytes from the
    //! table-permuted tail, stored as one s32 lane.
    vst1_s8(outptr, input.val[0]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 8),
                   vreinterpretq_s32_s8(input2), 0);
    vst1_s8(outptr + 1 * 12, input.val[1]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 1 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 1);
    vst1_s8(outptr + 2 * 12, input.val[2]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 2 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 2);
    vst1_s8(outptr + 3 * 12, input.val[3]);
    vst1q_lane_s32(reinterpret_cast<int32_t*>(outptr + 3 * 12 + 8),
                   vreinterpretq_s32_s8(input2), 3);
}

} // namespace aarch64
} // namespace megdnn



+ 1021
- 1010
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h
File diff suppressed because it is too large
View File


+ 1331
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h
File diff suppressed because it is too large
View File


+ 1170
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h
File diff suppressed because it is too large
View File


+ 827
- 796
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h
File diff suppressed because it is too large
View File


+ 1260
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h
File diff suppressed because it is too large
View File


+ 1160
- 0
dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h
File diff suppressed because it is too large
View File


+ 112
- 62
dnn/src/aarch64/matrix_mul/fp32/strategy.cpp View File

@@ -6,42 +6,55 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h"
#include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/common/utils.h"

#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif

using namespace megdnn;
using namespace aarch64;
using namespace aarch64::matmul;

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16);

void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0,
kmax);
}
}

void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
int k0, int kmax, bool transpose_B) const {
if (transpose_B) {
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0,
kmax);
} else {
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax);
matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0,
kmax);
}
}

void sgemm_4x16::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
void sgemm_4x16::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
@@ -61,15 +74,17 @@ void sgemm_4x16::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K16;
}

for (; n < N; n += 4) {
matmul_general_4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
matmul_general_4x16::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -80,8 +95,8 @@ void sgemm_4x16::kern(const float* packA, const float* packB,

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12);

void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0,
int ymax, int k0, int kmax, bool transpose_A) const {
void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax,
int k0, int kmax, bool transpose_A) const {
if (transpose_A) {
matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0,
kmax);
@@ -102,16 +117,10 @@ void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax,
}
}

void sgemm_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);

template <typename gemm_class>
static inline void sgemm_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
@@ -126,16 +135,14 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_class::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_general_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(N - n, 4));
gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -146,17 +153,16 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_general_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k,
std::min<size_t>(M - m, 4));
gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4));
output += B_INTERLEAVE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_general_8x12::kern_4x4(
packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(M - m, 4),
std::min<size_t>(N - n, 4));
output += 4;
cur_packB += K4;
}
@@ -164,6 +170,33 @@ void sgemm_8x12::kern(const float* packA, const float* packB,
}
}

void sgemm_8x12::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC, bool is_first_k,
const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
#if !MGB_ENABLE_CPUINFO
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
#else
auto arch = cpuinfo_get_current_core()->uarch;
if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C,
LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C,
LDC, is_first_k);
} else {
sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
}
#endif
}

MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12);

void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0,
@@ -180,25 +213,17 @@ void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0,
matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax);
}

void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

template <typename gemm_name>
static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB,
size_t M, size_t N, size_t K, float* C,
size_t LDC, bool is_first_k) {
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;
constexpr size_t PACK_C_SIZE = 4;
constexpr size_t A_INTERLEAVE = 8;
constexpr size_t A_INTERLEAVE4 = 4;
constexpr size_t B_INTERLEAVE = 12;
const int K12 = K * 12;
const int K8 = K * 8;
const int K4 = K * 4;

size_t m = 0;
for (; m + A_INTERLEAVE <= M; m += A_INTERLEAVE) {
float* output = C + (m / PACK_C_SIZE * LDC);
@@ -206,15 +231,14 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE <= N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_name::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}

for (; n < N; n += 4) {
matmul_mk4_8x12::kern_8x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
for (; n < N; n += 4) {
gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
@@ -225,19 +249,45 @@ void sgemm_mk4_8x12::kern(const float* packA, const float* packB,
size_t n = 0;
const float* cur_packB = packB;
for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) {
matmul_mk4_8x12::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k);
gemm_name::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k);
output += B_INTERLEAVE * PACK_C_SIZE;
cur_packB += K12;
}
for (; n < N; n += 4) {
matmul_mk4_8x12::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, std::min<size_t>(N - n, 4));
gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k,
std::min<size_t>(N - n, 4));
output += 4 * PACK_C_SIZE;
cur_packB += K4;
}
packA += K4;
}
}
void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M,
size_t N, size_t K, float* C, size_t LDC,
bool is_first_k, const float*, float*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
A_dtype.enumv() == C_dtype.enumv() &&
A_dtype.enumv() == DTypeEnum::Float32);
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");
#if !MGB_ENABLE_CPUINFO
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
#else
auto arch = cpuinfo_get_current_core()->uarch;
if (arch == cpuinfo_uarch_cortex_a53) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C,
LDC, is_first_k);
} else if (arch == cpuinfo_uarch_cortex_a55) {
sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C,
LDC, is_first_k);
} else {
sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC,
is_first_k);
}
#endif
}

// vim: syntax=cpp.doxygen

+ 2
- 1
dnn/src/aarch64/matrix_mul/fp32/strategy.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/fallback/matrix_mul/gemm_common.h"


+ 1265
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h
File diff suppressed because it is too large
View File


+ 387
- 0
dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h View File

@@ -0,0 +1,387 @@
/**
* \file dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include <inttypes.h>
#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/arm_common/simd_macro/marm_neon.h"

namespace megdnn {
namespace aarch64 {
namespace matmul_mk4_4x4x8_a72 {

//! optimize for A72

// clang-format off
/**
* Overview of register layout:
*
* A 4x4x8 cell of Lhs is stored in 8bit in q0-q3, q4-q7
* A 4x4x8 cell of Rhs is stored in 8bit in q8-q11, q12-q15
* A 4x4 block of accumulators is stored in 16bit in q16-q31
*
* +------------------------+
* | q8 | q9 | q10 | q11 |
* Rhs +------------------------+
* Lhs | | | | |
* +--------+ - - - - +------------------------+
* | q0 | | q16 | q20 | q24 | q28 |
* | q1 | | q17 | q21 | q25 | q29 |
* | q2 | | q18 | q22 | q26 | q30 |
* | q3 | | q19 | q23 | q27 | q31 |
* +--------+ - - - - +------------------------+
*
* Accumulator
*/

// clang-format on
// Compute one 4(M) x 4(N) tile of C(int16) = A(int8) * B(int8) for the MK4
// i8i8i16 GEMM, scheduled for Cortex-A72.
//
// Each of the 16 accumulators v16-v31 holds the 8 int16 per-lane partial
// products (smlal = widening int8 multiply-accumulate) of one output
// element; after the K loop they are reduced to scalars with pairwise adds
// (addp).  packA/packB are pre-packed in k-groups of 8 and zero-padded to a
// multiple of 8 by the pack routines below, so over-reading past K is safe.
//
// remain_n < 4 stores only the leading output columns (STORE_C path).
// The unnamed bool is the is_first_k flag; the caller asserts it is true,
// so accumulators are simply zeroed and C is never loaded.
static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                            int16_t* output, int LDC, bool, int remain_n) {
    K = div_ceil(K, 8);     // number of 8-deep k-groups (packing zero-pads)
    int oddk = (K & 1);     // 1 if one unpaired k-group is left for the tail
    K = ((K + 1) / 2) - 1;  // iterations of the 2x-unrolled main loop

    const int8_t* a_ptr = packA;
    const int8_t* b_ptr = packB;

    // NOTE(review): scaled by sizeof(int8_t) although C is int16; harmless
    // because LDC is never referenced inside the asm body (stores advance
    // x0 sequentially) -- confirm before reusing this kernel elsewhere.
    LDC = LDC * sizeof(int8_t);
    // clang-format off
    // Store one 4h (4 x int16) result row, bailing out once the remaining
    // column count in w10 hits zero.
#define STORE_LINE(reg0)                 \
    "cmp w10, #0 \n"                     \
    "beq 101f\n"                         \
    "st1 {v" reg0 ".4h}, [x0], #8\n"     \
    "subs w10, w10, #1\n"

#define STORE_C                          \
    "mov w10, %w[remain_n]\n"            \
    STORE_LINE("16")                     \
    STORE_LINE("20")                     \
    STORE_LINE("24")                     \
    STORE_LINE("28")

    // clang-format on

    // Pin outptr to x0 because STORE_LINE/STORE_C hard-code [x0].
    register int16_t* outptr asm("x0") = output;
    asm volatile(
            // zero the 16 accumulators (no C load: caller asserts is_first_k)

            "1:\n"

            "eor v16.16b, v16.16b, v16.16b\n"
            "eor v17.16b, v17.16b, v17.16b\n"
            "eor v18.16b, v18.16b, v18.16b\n"
            "eor v19.16b, v19.16b, v19.16b\n"
            "eor v20.16b, v20.16b, v20.16b\n"
            "eor v21.16b, v21.16b, v21.16b\n"
            "eor v22.16b, v22.16b, v22.16b\n"
            "eor v23.16b, v23.16b, v23.16b\n"
            "eor v24.16b, v24.16b, v24.16b\n"
            "eor v25.16b, v25.16b, v25.16b\n"
            "eor v26.16b, v26.16b, v26.16b\n"
            "eor v27.16b, v27.16b, v27.16b\n"
            "eor v28.16b, v28.16b, v28.16b\n"
            "eor v29.16b, v29.16b, v29.16b\n"
            "eor v30.16b, v30.16b, v30.16b\n"
            "eor v31.16b, v31.16b, v31.16b\n"

            // preload the first k-group: A rows in v0-v3, B cols in v8-v11
            "2: \n"

            "ld1 {v0.8b, v1.8b}, [%[a_ptr]], #16\n"
            "ld1 {v2.8b, v3.8b}, [%[a_ptr]], #16\n"
            "ld1 {v8.8b, v9.8b}, [%[b_ptr]], #16\n"
            "ld1 {v10.8b, v11.8b}, [%[b_ptr]], #16\n"

            "cmp %w[K], #0\n"
            "beq 4f\n"

            // main loop: 2 k-groups per iteration, loads for the next
            // group interleaved with the multiplies of the current one
            "3: \n"
            //! k = 0
            "smlal v16.8h, v0.8b, v8.8b\n"
            "ld1 {v4.8b}, [%[a_ptr]], #8\n"
            "smlal v17.8h, v1.8b, v8.8b\n"
            "smlal v18.8h, v2.8b, v8.8b\n"
            "ld1 {v5.8b}, [%[a_ptr]], #8\n"
            "smlal v19.8h, v3.8b, v8.8b\n"
            "smlal v20.8h, v0.8b, v9.8b\n"
            "ld1 {v6.8b}, [%[a_ptr]], #8\n"
            "smlal v21.8h, v1.8b, v9.8b\n"
            "smlal v22.8h, v2.8b, v9.8b\n"
            "ld1 {v7.8b}, [%[a_ptr]], #8\n"
            "smlal v23.8h, v3.8b, v9.8b\n"
            "smlal v24.8h, v0.8b, v10.8b\n"
            "ld1 {v12.8b}, [%[b_ptr]], #8\n"
            "smlal v25.8h, v1.8b, v10.8b\n"
            "smlal v26.8h, v2.8b, v10.8b\n"
            "ld1 {v13.8b}, [%[b_ptr]], #8\n"
            "smlal v27.8h, v3.8b, v10.8b\n"
            "smlal v28.8h, v0.8b, v11.8b\n"
            "ld1 {v14.8b}, [%[b_ptr]], #8\n"
            "smlal v29.8h, v1.8b, v11.8b\n"
            "smlal v30.8h, v2.8b, v11.8b\n"
            "ld1 {v15.8b}, [%[b_ptr]], #8\n"
            "smlal v31.8h, v3.8b, v11.8b\n"
            //! k = 8
            "smlal v16.8h, v4.8b, v12.8b\n"
            "ld1 {v0.8b}, [%[a_ptr]], #8\n"
            "smlal v17.8h, v5.8b, v12.8b\n"
            "smlal v18.8h, v6.8b, v12.8b\n"
            "ld1 {v1.8b}, [%[a_ptr]], #8\n"
            "smlal v19.8h, v7.8b, v12.8b\n"
            "smlal v20.8h, v4.8b, v13.8b\n"
            "ld1 {v2.8b}, [%[a_ptr]], #8\n"
            "smlal v21.8h, v5.8b, v13.8b\n"
            "smlal v22.8h, v6.8b, v13.8b\n"
            "ld1 {v3.8b}, [%[a_ptr]], #8\n"
            "smlal v23.8h, v7.8b, v13.8b\n"
            "smlal v24.8h, v4.8b, v14.8b\n"
            "ld1 {v8.8b}, [%[b_ptr]], #8\n"
            "smlal v25.8h, v5.8b, v14.8b\n"
            "smlal v26.8h, v6.8b, v14.8b\n"
            "ld1 {v9.8b}, [%[b_ptr]], #8\n"
            "smlal v27.8h, v7.8b, v14.8b\n"
            "smlal v28.8h, v4.8b, v15.8b\n"
            "ld1 {v10.8b}, [%[b_ptr]], #8\n"
            "smlal v29.8h, v5.8b, v15.8b\n"
            "smlal v30.8h, v6.8b, v15.8b\n"
            "ld1 {v11.8b}, [%[b_ptr]], #8\n"
            "smlal v31.8h, v7.8b, v15.8b\n"

            "subs %w[K], %w[K], #1\n"
            "bne 3b\n"

            "4:\n"
            "cmp %w[oddk], #1\n"
            "beq 5f\n"
            //! even tail: two k-groups remain (first already preloaded)
            //! k = 0
            "smlal v16.8h, v0.8b, v8.8b\n"
            "ld1 {v4.8b}, [%[a_ptr]], #8\n"
            "smlal v17.8h, v1.8b, v8.8b\n"
            "smlal v18.8h, v2.8b, v8.8b\n"
            "ld1 {v5.8b}, [%[a_ptr]], #8\n"
            "smlal v19.8h, v3.8b, v8.8b\n"
            "smlal v20.8h, v0.8b, v9.8b\n"
            "ld1 {v6.8b}, [%[a_ptr]], #8\n"
            "smlal v21.8h, v1.8b, v9.8b\n"
            "smlal v22.8h, v2.8b, v9.8b\n"
            "ld1 {v7.8b}, [%[a_ptr]], #8\n"
            "smlal v23.8h, v3.8b, v9.8b\n"
            "smlal v24.8h, v0.8b, v10.8b\n"
            "ld1 {v12.8b}, [%[b_ptr]], #8\n"
            "smlal v25.8h, v1.8b, v10.8b\n"
            "smlal v26.8h, v2.8b, v10.8b\n"
            "ld1 {v13.8b}, [%[b_ptr]], #8\n"
            "smlal v27.8h, v3.8b, v10.8b\n"
            "smlal v28.8h, v0.8b, v11.8b\n"
            "ld1 {v14.8b}, [%[b_ptr]], #8\n"
            "smlal v29.8h, v1.8b, v11.8b\n"
            "smlal v30.8h, v2.8b, v11.8b\n"
            "ld1 {v15.8b}, [%[b_ptr]], #8\n"
            "smlal v31.8h, v3.8b, v11.8b\n"
            //! k = 8
            "smlal v16.8h, v4.8b, v12.8b\n"
            "smlal v17.8h, v5.8b, v12.8b\n"
            "smlal v18.8h, v6.8b, v12.8b\n"
            "smlal v19.8h, v7.8b, v12.8b\n"
            "smlal v20.8h, v4.8b, v13.8b\n"
            "smlal v21.8h, v5.8b, v13.8b\n"
            "smlal v22.8h, v6.8b, v13.8b\n"
            "smlal v23.8h, v7.8b, v13.8b\n"
            "smlal v24.8h, v4.8b, v14.8b\n"
            "smlal v25.8h, v5.8b, v14.8b\n"
            "smlal v26.8h, v6.8b, v14.8b\n"
            "smlal v27.8h, v7.8b, v14.8b\n"
            "smlal v28.8h, v4.8b, v15.8b\n"
            "smlal v29.8h, v5.8b, v15.8b\n"
            "smlal v30.8h, v6.8b, v15.8b\n"
            "smlal v31.8h, v7.8b, v15.8b\n"
            "b 6f\n"

            "5:\n"
            //! odd tail: a single (preloaded) k-group remains
            "smlal v16.8h, v0.8b, v8.8b\n"
            "smlal v17.8h, v1.8b, v8.8b\n"
            "smlal v18.8h, v2.8b, v8.8b\n"
            "smlal v19.8h, v3.8b, v8.8b\n"
            "smlal v20.8h, v0.8b, v9.8b\n"
            "smlal v21.8h, v1.8b, v9.8b\n"
            "smlal v22.8h, v2.8b, v9.8b\n"
            "smlal v23.8h, v3.8b, v9.8b\n"
            "smlal v24.8h, v0.8b, v10.8b\n"
            "smlal v25.8h, v1.8b, v10.8b\n"
            "smlal v26.8h, v2.8b, v10.8b\n"
            "smlal v27.8h, v3.8b, v10.8b\n"
            "smlal v28.8h, v0.8b, v11.8b\n"
            "smlal v29.8h, v1.8b, v11.8b\n"
            "smlal v30.8h, v2.8b, v11.8b\n"
            "smlal v31.8h, v3.8b, v11.8b\n"

            "6:\n"
            //! reduce: pairwise-add each accumulator's 8 lanes down to one
            //! int16 per output element; results land in v16/v20/v24/v28
            "addp v16.8h, v16.8h, v17.8h\n"
            "addp v18.8h, v18.8h, v19.8h\n"
            "addp v20.8h, v20.8h, v21.8h\n"
            "addp v22.8h, v22.8h, v23.8h\n"
            "addp v24.8h, v24.8h, v25.8h\n"
            "addp v26.8h, v26.8h, v27.8h\n"

            "addp v16.8h, v16.8h, v18.8h\n"
            "addp v28.8h, v28.8h, v29.8h\n"
            "addp v30.8h, v30.8h, v31.8h\n"
            "addp v20.8h, v20.8h, v22.8h\n"

            "addp v16.8h, v16.8h, v16.8h\n"
            "addp v20.8h, v20.8h, v20.8h\n"

            "addp v24.8h, v24.8h, v26.8h\n"
            "addp v24.8h, v24.8h, v24.8h\n"

            "addp v28.8h, v28.8h, v30.8h\n"
            "addp v28.8h, v28.8h, v28.8h\n"

            // full-tile fast path vs. remain_n partial store
            "cmp %w[remain_n], #4\n"
            "bne 7f\n"

            "st1 {v16.4h}, [x0], #8\n"
            "st1 {v20.4h}, [x0], #8\n"
            "st1 {v24.4h}, [x0], #8\n"
            "st1 {v28.4h}, [x0], #8\n"
            "b 101f\n"

            "7:\n" STORE_C

            "101:\n"
            : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K),
              [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr),
              [remain_n] "+r"(remain_n)
            :
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
              "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19",
              "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28",
              "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7",
              "x8", "x9", "x10", "cc", "memory");

#undef STORE_C
#undef STORE_LINE
}
static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) {
int8x8x4_t in0 = vld4_s8(inptr);
vst1_s8(outptr + 0 * 8, in0.val[0]);
vst1_s8(outptr + 1 * 8, in0.val[1]);
vst1_s8(outptr + 2 * 8, in0.val[2]);
vst1_s8(outptr + 3 * 8, in0.val[3]);
}

static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2,
dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vld1q_s8(inptr2);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}

static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) {
int8x16_t in0 = vld1q_s8(inptr);
int8x16_t in1 = vdupq_n_s8(0);
int32x4x2_t in_x2 = {
{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}};
vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2);
}

// Pack rows [m0, mmax) x cols [k0, kmax) of A (MK4 source layout: 4
// channels interleaved, row stride ldin) into 4(m) x 8(k) tiles for
// kern_4x4.  Each full 8-deep k-group is reordered via transpose_8x4_b --
// presumably converting the channel-interleaved order into per-row runs of
// 8 consecutive k values matching the kernel's ld1 pattern (confirm
// against kern_4x4's loads).  The k remainder (ksize % 8, i.e. 4 since K
// is a multiple of 4) is staged through a zero-initialized buffer so the
// packed stream is always padded to whole 8-deep groups.
static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in,
                                          int ldin, int m0, int mmax, int k0,
                                          int kmax) {
    megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4");
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");
    constexpr int pack_m = 4;     // rows per packed tile
    constexpr int pack_k = 8;     // k depth per packed tile
    constexpr int pack_size = 4;  // MK4: 4 interleaved channels in the source
    const int ksize = kmax - k0;
    const int remain_k = ksize % pack_k;
    const int kend = kmax - remain_k;  // last k fully covered by 8-deep groups
    int8_t tmpbuff[pack_m * pack_k]{0};  // zeroed staging buffer for padding

    for (int m_idx = m0; m_idx < mmax; m_idx += pack_m) {
        // one MK4 row-group (4 rows) occupies a single ldin-strided row
        const int8_t* inptr0 = in + m_idx / pack_size * ldin + k0;

        for (int k_idx = k0; k_idx < kend; k_idx += pack_k) {
            transpose_8x4_b(inptr0, out);
            inptr0 += pack_m * pack_k;
            out += pack_m * pack_k;
        }
        if (remain_k > 0) {
            // only 16 bytes (4 rows x 4 k) remain: copy into the zeroed
            // buffer so transpose_8x4_b can read a full 32-byte tile
            int8x16_t tmp = vld1q_s8(inptr0);
            vst1q_s8(&tmpbuff[0], tmp);
            transpose_8x4_b(&tmpbuff[0], out);
            inptr0 += pack_m * pack_size;
            out += pack_m * pack_k;  // output still advances a full tile
        }
    }
}

// Pack cols [n0, nmax) x rows [k0, kmax) of B (MK4 source layout: 4
// interleaved values per element, k-group stride ldin) into 4(n) x 8(k)
// tiles.  Two consecutive 4-deep source k-groups are zipped together by
// interleve_8x4_b to form one 8-deep packed group; the trailing k
// remainder is padded with zeros (interleve_8x4_b_pad), and a trailing n
// remainder is staged through zeroed buffers.  Output tiles for one k
// position are strided by pack_n * packed_ksize so kern_4x4 can walk B
// tile-by-tile along n.
static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in,
                                          int ldin, int n0, int nmax, int k0,
                                          int kmax) {
    megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4");

    constexpr int pack_n = 4;     // columns per packed tile
    constexpr int pack_k = 8;     // k depth per packed tile
    constexpr int pack_size = 4;  // MK4 interleave factor in the source
    const int ksize = kmax - k0;
    const int packed_ksize = round_up(ksize, pack_k);  // k after zero padding
    const int remain_k = ksize % pack_k;
    const int kend = kmax - remain_k;  // last k fully covered by 8-deep groups
    const int nsize = nmax - n0;
    const int remain_n = nsize % pack_n;
    const int nend = nmax - remain_n;  // last n fully covered by 4-wide tiles
    const int stride_input = pack_size * nsize;  // bytes per 4-deep k-group row
    int8_t tmpbuff[pack_n * pack_k]{0};   // zeroed staging for remain_n, group 1
    int8_t tmpbuff2[pack_n * pack_k]{0};  // zeroed staging for remain_n, group 2

    for (int k_idx = k0; k_idx < kend; k_idx += pack_k) {
        // inptr / inptr2: two consecutive 4-deep k-groups to be zipped
        const int8_t* inptr = in + k_idx / pack_size * ldin + n0 * pack_size;
        const int8_t* inptr2 = inptr + stride_input;
        int8_t* outptr = out + k_idx * pack_n;
        for (int n_idx = n0; n_idx < nend; n_idx += pack_n) {
            interleve_8x4_b(inptr, inptr2, outptr);
            inptr += pack_n * pack_size;
            inptr2 += pack_n * pack_size;
            outptr += pack_n * packed_ksize;
        }
        if (remain_n > 0) {
            // partial-width tile: copy the valid columns into the zeroed
            // buffers so the interleave reads full 16-byte groups
            memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t));
            memcpy(&tmpbuff2[0], inptr2, remain_n * pack_size * sizeof(int8_t));
            interleve_8x4_b(&tmpbuff[0], &tmpbuff2[0], outptr);
            outptr += pack_n * packed_ksize;
        }
    }
    if (remain_k > 0) {
        // a single 4-deep k-group is left: zip it with zeros to pad the
        // packed stream up to a full 8-deep group
        const int8_t* inptr = in + kend / pack_size * ldin + n0 * pack_size;
        int8_t* outptr = out + kend * pack_n;
        for (int n_idx = n0; n_idx < nend; n_idx += pack_n) {
            interleve_8x4_b_pad(inptr, outptr);
            inptr += pack_n * pack_size;
            outptr += pack_n * packed_ksize;
        }
        if (remain_n > 0) {
            memcpy(&tmpbuff[0], inptr, remain_n * pack_size * sizeof(int8_t));
            interleve_8x4_b_pad(&tmpbuff[0], outptr);
            outptr += pack_n * packed_ksize;
        }
    }
}

} // namespace matmul_mk4_4x4x8_a72
} // namespace aarch64
} // namespace megdnn

// vim: syntax=cpp.doxygen

+ 161
- 1
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp View File

@@ -6,12 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "src/aarch64/matrix_mul/asm/common.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h"
#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
@@ -197,4 +200,161 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB,
packA += K4;
}
}

// ===========================gemm_s8x8x16_mk4_16x12==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53);

void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in,
int ldin, int y0, int ymax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0,
ymax, k0, kmax);
}

void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in,
int ldin, int x0, int xmax, int k0,
int kmax, bool) const {
matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0,
xmax, k0, kmax);
}

void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA,
const dt_int8* packB, size_t M, size_t N,
size_t K, dt_int16* C, size_t LDC,
bool is_first_k, const dt_int16*,
dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

constexpr size_t pack_size = 4;
constexpr size_t pack_m = 16;
constexpr size_t pack_n = 12;
const size_t remain_n = N % pack_n;
size_t remain_m = M % pack_m;

size_t m_idx = 0;
for (; m_idx + pack_m <= M; m_idx += pack_m) {
int16_t* output = C + (m_idx / pack_size * LDC);

size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
packA += pack_m * K;
}

if (remain_m >= 8) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
packA += 8 * K;
m_idx += 8;
remain_m -= 8;
}

if (remain_m == 4) {
int16_t* output = C + (m_idx / pack_size * LDC);
size_t n_idx = 0;
const int8_t* cur_packB = packB;
for (; n_idx + pack_n <= N; n_idx += pack_n) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * K;
}
if (remain_n > 0) {
matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * K;
}
}
}

// ===========================gemm_s8x8x16_mk4_4x4_a72==================================
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72);

void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin,
int y0, int ymax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax,
k0, kmax);
}

void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin,
int x0, int xmax, int k0, int kmax,
bool) const {
matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax,
k0, kmax);
}

void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB,
size_t M, size_t N, size_t K, dt_int16* C,
size_t LDC, bool is_first_k,
const dt_int16*, dt_int16*) const {
megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&
C_dtype.enumv() == DTypeEnum::Int16 &&
A_dtype.enumv() == DTypeEnum::Int8);
megdnn_assert(is_first_k == true, "only impl is_first_k");
MEGDNN_MARK_USED_VAR(A_dtype);
MEGDNN_MARK_USED_VAR(B_dtype);
MEGDNN_MARK_USED_VAR(C_dtype);
megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4");

constexpr size_t pack_size = 4;
constexpr size_t pack_m = 4;
constexpr size_t pack_n = 4;
constexpr size_t pack_k = 8;
const size_t remain_n = N % pack_n;
const size_t nend = N - remain_n;
const size_t packed_k = round_up(K, pack_k);

for (size_t m_idx = 0; m_idx < M; m_idx += pack_m) {
int16_t* output = C + (m_idx / pack_size * LDC);

const int8_t* cur_packB = packB;
for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, pack_n);
output += pack_n * pack_size;
cur_packB += pack_n * packed_k;
}
if (remain_n > 0) {
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC,
is_first_k, remain_n);
output += remain_n * pack_size;
cur_packB += pack_n * packed_k;
}
packA += pack_m * packed_k;
}
}

// vim: syntax=cpp.doxygen

+ 7
- 1
dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once

@@ -20,6 +21,11 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true,
gemm_s8x8x16_8x8);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true,
gemm_s8x8x16_4x4);
MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false,
gemm_s8x8x16_mk4_4x4_a72);
MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16,
16, 12, 4, false, false,
gemm_s8x8x16_mk4_16x12_a53);

} // namespace matmul
} // namespace aarch64


+ 7
- 2
dnn/src/aarch64/matrix_mul/opr_impl.cpp View File

@@ -6,10 +6,11 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/aarch64/matrix_mul/algos.h"
#include "src/aarch64/matrix_mul/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"

@@ -36,6 +37,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4;
AlgoInt8x8x16MK4_4x4x8 int8x8x16_mk4_4x4x8;

AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
@@ -70,6 +73,8 @@ public:
#endif
all_algos.emplace_back(&int8x8x16_k4x4x16);
all_algos.emplace_back(&int8x8x16_k8x8x8);
all_algos.emplace_back(&int8x8x16_mk4_4x4x8);
all_algos.emplace_back(&int8x8x16_mk4_16x12x4);

all_algos.emplace_back(&int16x16x32_k12x8x1);
all_algos.emplace_back(&int16x16x32_mk8_8x8);


+ 18
- 15
dnn/src/aarch64/matrix_mul/opr_impl.h View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/arm_common/matrix_mul/opr_impl.h"
@@ -21,28 +22,30 @@ public:
SmallVector<AlgoBase*> algo_pack() override;

private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv; // Aarch64 F32 Gemv
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
class AlgoF32MK4_8x12x1; // Aarch64 F32 Kernel MK4 8x12x1
class AlgoF32K4x16x1; // Aarch64 F32 Kernel 4x16x1
class AlgoF32MK4_4x16; // Aarch64 F32 Format MK4 block 16x4
class AlgoF32Gemv; // Aarch64 F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16K8x24x1; // Aarch64 F16 Kernel 8x24x1
class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8
#endif

#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
#else
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8
#endif
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16
class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16
class AlgoInt8x8x16MK4_4x4x8; // Aarch64 Int8x8x16 Kernel 4x4x8

class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1
class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8
@@ -52,7 +55,7 @@ private:
// 8x8x4 DotProduct
class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif

class AlgoPack;


+ 0
- 1
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -214,7 +214,6 @@ void* const ConvBiasImpl::sm_arm_common_algo_type =

bool ConvBiasImpl::is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& param) const {
// fallback::ConvBiasImpl::NCBKernParam conv_ncb_param;
fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
param, 0, param::MatrixMul::Format::DEFAULT, {}, 0,
BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY);


+ 2
- 2
dnn/src/common/cpuinfo_arch_vendor.cpp View File

@@ -9,8 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO

#include "cpuinfo_arch_vendor.h"



+ 2
- 2
dnn/src/common/cpuinfo_arch_vendor.h View File

@@ -11,8 +11,8 @@
*/

#pragma once
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO

#include <cpuinfo.h>



+ 106
- 5
dnn/test/aarch64/matrix_mul.cpp View File

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/aarch64/fixture.h"

@@ -16,6 +17,7 @@
#include "test/common/matrix_mul.h"
#include "test/common/rng.h"

#include "test/arm_common/cpuinfo_help.h"
using namespace megdnn;
using namespace test;

@@ -24,6 +26,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32K8X12) {
dtype::Float32{}, handle(),
"AARCH64_F32K8X12X1");
}
#if MGB_ENABLE_CPUINFO
TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A53) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(),
"AARCH64_F32K8X12X1");
}
TEST_F(AARCH64, MATRIX_MUL_FP32K8X12_A55) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, handle(),
"AARCH64_F32K8X12X1");
}
#endif

TEST_F(AARCH64, MATRIX_MUL_FP32K4X16) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},
@@ -36,6 +52,20 @@ TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4) {
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
#if MGB_ENABLE_CPUINFO
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A53) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
TEST_F(AARCH64, MATRIX_MUL_FP32_PACK_MK4_A55) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
matrix_mul::check_matrix_mul(
dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(),
"AARCH64_F32_MK4_K8X12X1", param::MatrixMul::Format::MK4, 1);
}
#endif

TEST_F(AARCH64, MATRIX_MUL_FP32_MK4) {
matrix_mul::check_matrix_mul(
@@ -92,6 +122,18 @@ TEST_F(AARCH64, MATRIX_MUL_INT8_MK4) {
std::move(args));
}

TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16_4x4) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "AARCH64_INT8X8X16_MK4_4X4X8",
param::MatrixMul::Format::MK4, 1);
}

TEST_F(AARCH64, MATRIX_MUL_MK4_8x8x16) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int16{},
handle(), "AARCH64_INT8X8X16_MK4_16X12X4",
param::MatrixMul::Format::MK4, 1);
}

TEST_F(AARCH64, MATRIX_MUL_INT8x8x32_K8x8x8) {
matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
handle(), "AARCH64_INT8X8X32_K8X8X8");
@@ -172,6 +214,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_FP32_K4X16) {
};

run(256, 256, 128);
run(384, 384, 384);

for (size_t k = 4; k <= 256; k *= 8) {
for (size_t m = 4; m <= 256; m *= 4) {
@@ -235,7 +278,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_8X8X8) {
int32_used / int_used);
};

run(256, 256, 128);
run(256, 256, 256);

for (size_t k = 4; k <= 256; k *= 8) {
for (size_t m = 4; m <= 256; m *= 4) {
@@ -297,6 +340,62 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT32_MK_4X4X16) {
}
}

TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_MK4_8x8x16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
Benchmarker<MatrixMul> benchmarker_mk4(handle());
Benchmarker<MatrixMul> benchmarker_mk4_16x12(handle());
benchmarker.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);
benchmarker.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_K4X4X16"));

param.format = MatrixMul::Param::Format::MK4;
benchmarker_mk4.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_4X4X8"));
benchmarker_mk4.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);

benchmarker_mk4_16x12.set_before_exec_callback(
AlgoChecker<MatrixMul>("AARCH64_INT8X8X16_MK4_16X12X4"));
benchmarker_mk4_16x12.set_times(RUNS)
.set_dtype(0, dtype::Int8{})
.set_dtype(1, dtype::Int8{})
.set_dtype(2, dtype::Int16{})
.set_param(param)
.set_display(false);

auto run = [&](size_t M, size_t N, size_t K) {
auto default_used = benchmarker.exec({{M, K}, {K, N}, {}}) / RUNS;
auto mk_used = benchmarker_mk4.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
auto mk4_16x12_used =
benchmarker_mk4_16x12.exec(
{{M / 4, K / 4, 4, 4}, {K / 4, N, 4}, {}}) /
RUNS;
float computations = 2.f * M * K * N * 1e-6;
printf("run: {%zu{M} %zu{K} %zu{N}} normal: %f ms %f Gflops mk4: %f ms "
"%f Gflops speedup: %f, mk4_16x12 %f Gflops speedup: %f\n",
M, K, N, default_used, computations / default_used, mk_used,
computations / mk_used, default_used / mk_used,
computations / mk4_16x12_used, default_used / mk4_16x12_used);
};

run(384, 384, 384);
}

TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) {
constexpr size_t RUNS = 50;
param::MatrixMul param;
@@ -350,9 +449,11 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT16_4X4X16) {

run(256, 256, 128);

for (size_t k = 4; k <= 16; k *= 2) {
for (size_t m = 4; m <= 64; m *= 2) {
for (size_t n = 4; n <= 64; n *= 2) {
run(256, 256, 256);

for (size_t k = 4; k <= 256; k *= 4) {
for (size_t m = 4; m <= 256; m *= 4) {
for (size_t n = 4; n <= 256; n *= 4) {
run(m, n, k);
}
}


+ 8
- 2
dnn/test/arm_common/conv_bias.cpp View File

@@ -736,15 +736,21 @@ TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x32) {
}
#endif

#if MEGDNN_ARMV7
TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NCHW44_INT8x8x16) {
#if MEGDNN_ARMV7
const char* default_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8";
const char* mk4_algo = "IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4";
printf("compare %s vs %s \n", default_algo, mk4_algo);
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3,
dtype::Int8(), dtype::Int16());
}
#else
const char* default_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16";
const char* mk4_algo = "IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8";
printf("compare %s vs %s \n", default_algo, mk4_algo);
BENCHMARK_IM2COL_NCHW44_VS_NCHW(default_algo, mk4_algo, handle(), 3,
dtype::Int8(), dtype::Int16());
#endif
}

TEST_F(ARM_COMMON, BENCHMARK_GROUP_CONV_NCHW44_INT8x8x32_VS_INT8x8x16_STRIDE1) {
BENCHMARK_GROUPCONV_NCHW44_int8x8x16VS_int8x8x32("S8_CHAN_WISE_STRD1_NCHW44",


+ 132
- 35
dnn/test/arm_common/conv_bias_multi_thread.cpp View File

@@ -14,6 +14,8 @@
#include "test/common/benchmarker.h"
#include "test/common/conv_bias.h"

#include "test/arm_common/cpuinfo_help.h"

using namespace megdnn;
using namespace test;
using namespace conv_bias;
@@ -487,11 +489,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}

TEST_F(ARM_COMMON,
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
Checker<ConvBias> checker(handle());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int16());
@@ -505,8 +506,8 @@ TEST_F(ARM_COMMON,
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
Checker<ConvBias> checker(handle());
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
"S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
checker.set_dtype(0, dtype::Int8());
checker.set_dtype(1, dtype::Int8());
checker.set_dtype(2, dtype::Int16());
@@ -1803,8 +1804,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE2_PREPROCESS) {
handle(), nullptr, 0.001, dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_F32K8X12X1")
cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
cb("IM2COLMATMUL:AARCH64_F32K8X12X1") cb("IM2COLMATMUL:AARCH64_F32K4X16X1")
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:ARMV7_F32")
#endif
@@ -1858,6 +1858,94 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_STRIDE1) {
#undef cb
}

//! CPUINFO ralated test
#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A55) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
#define cb(name,stride) \
check_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \
handle(), name);

cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1)
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2)
#undef cb
}
#endif
#endif

#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COL_FP32_A53) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
#define cb(name,stride) \
check_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, stride, false, false, false), \
handle(), name);

cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 1)
cb("IM2COLMATMUL:AARCH64_F32K8X12X1", 2)
#undef cb
}
#endif
#endif

#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A55) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{2, 3, 7}, 1, false, false, false, false, false, true, true);
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
args = get_nchw44_conv_bias_args(
{2, 3, 7}, 2, false, false, false, false, false, true, true);
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
}
#endif
#endif

#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_MK4_PACK_F32_A53) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{2, 3, 7}, 1, false, false, false, false, false, true, true);
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
args = get_nchw44_conv_bias_args(
{2, 3, 7}, 2, false, false, false, false, false, true, true);
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
}
#endif
#endif


#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A55) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a55);
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false);
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
}
#endif
#endif

#if MEGDNN_AARCH64
#if MGB_ENABLE_CPUINFO
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_MK4_PACK_F32_A53) {
CpuInfoTmpReplace cpu_replace_guard(cpuinfo_uarch_cortex_a53);
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false);
check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
}
#endif
#endif

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
UniformIntRNG rng{-50, 50};

@@ -2216,7 +2304,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
@@ -2247,7 +2336,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32_FILTERPREPRO
#undef cb
}


TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
@@ -2276,19 +2364,21 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K8X8X8");
cb("IM2COLMATMUL:AARCH64_INT8X8X16_K4X4X16");
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_4X4X8");
cb_nchw44("IM2COLMATMUL:AARCH64_INT8X8X16_MK4_16X12X4");
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X8X8");
cb("IM2COLMATMUL:ARMV7_INT8X8X16_K4X2X16");
cb_nchw44("IM2COLMATMUL:ARMV7_INT8X8X16_MK4_K8X8X4");
#endif
cb("IM2COLMATMUL:ARM_COMMON_INT8X8X16");

#undef cb
#undef cb_nchw44
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
@@ -2311,7 +2401,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
@@ -2415,8 +2506,9 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
}
}

void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args,
Handle* handle, const char* algo_name) {
void checker_conv_bias_int8x8x32_preprocess(
std::vector<conv_bias::TestArg> args, Handle* handle,
const char* algo_name) {
using namespace conv_bias;

Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
@@ -2461,7 +2553,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true);
@@ -2490,7 +2583,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
@@ -2541,7 +2635,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
#undef cb
}


TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) {
UniformIntRNG rng{-50, 50};
@@ -2678,7 +2771,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COLMATMUL_INT8X8X32_FILTER_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
@@ -2722,7 +2816,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{2, 4, 7}, 1, false, false, false, false, false, true,true);
{2, 4, 7}, 1, false, false, false, false, false, true, true);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \
dtype::Float32(), dtype::Float32(), \
@@ -2748,7 +2842,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
#undef cb
}

TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_IM2COL_S2_MK4_PACK_F32_FUSE_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{3}, 2, false, false, false, false, false, true, true, false);
@@ -2884,12 +2979,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) {
NormalRNG rng(1);
#if MEGDNN_AARCH64
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
"CONV1x1:AARCH64_F16_K8X24X1:48");
dtype::Float16{}, dtype::Float16{},
dtype::Float16{},
"CONV1x1:AARCH64_F16_K8X24X1:48");
#elif MEGDNN_ARMV7
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
"CONV1x1:AARCH32_F16_K4X16X1:24");
dtype::Float16{}, dtype::Float16{},
dtype::Float16{},
"CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
}

@@ -2951,7 +3048,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) {
#undef cb
}


#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
UniformIntRNG rng{-50, 50};
@@ -3074,7 +3170,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb

}

TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
@@ -3095,6 +3190,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_4X4X8:48");
cb_nchw44("CONV1x1:AARCH64_INT8X8X16_MK4_16X12X4:48");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
@@ -3128,11 +3225,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) {
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test
cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test
cb("CONV1x1:ARM_COMMON_INT8X8X16:24"); //! add nopack test
#endif
#undef cb
}
@@ -3245,11 +3342,11 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) {

UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
check_conv_bias_preprocess(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
#define cb(name) \
check_conv_bias_preprocess( \
get_nchw44_conv_bias_args({1}, 1, true, false, false), handle(), \
&rng, epsilon, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), \
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7


+ 2
- 5
dnn/test/arm_common/cpuinfo.cpp View File

@@ -9,7 +9,8 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include <cpuinfo.h>
#include <inttypes.h>
#include "gtest/gtest.h"
@@ -18,7 +19,6 @@ namespace megdnn {
namespace test {

TEST(ARM_RUNTIME, CPUINFO_KIRIN980) {

ASSERT_TRUE(cpuinfo_initialize());

int right_soc = strcmp(cpuinfo_get_package(0)->name, "HiSilicon Kirin 980");
@@ -68,7 +68,6 @@ TEST(ARM_RUNTIME, CPUINFO_KIRIN980) {
}

TEST(ARM_RUNTIME, CPUINFO_SDM8150) {

ASSERT_TRUE(cpuinfo_initialize());

int right_soc =
@@ -119,7 +118,6 @@ TEST(ARM_RUNTIME, CPUINFO_SDM8150) {
}

TEST(ARM_RUNTIME, CPUINFO_SDM660) {

ASSERT_TRUE(cpuinfo_initialize());

int right_soc =
@@ -173,4 +171,3 @@ TEST(ARM_RUNTIME, CPUINFO_SDM660) {
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen


+ 17
- 0
dnn/test/arm_common/cpuinfo_help.cpp View File

@@ -0,0 +1,17 @@
/**
* \file dnn/test/arm_common/cpuinfo_help.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/common/utils.h"
#include "test/arm_common/cpuinfo_help.h"
#if MGB_ENABLE_CPUINFO
std::mutex CpuInfoTmpReplace::m_cpuinfo_lock;
#endif
// vim: syntax=cpp.doxygen

+ 47
- 0
dnn/test/arm_common/cpuinfo_help.h View File

@@ -0,0 +1,47 @@
/**
* \file dnn/test/arm_common/cpuinfo_help.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include <mutex>
#include <vector>
#include "src/common/utils.h"

#if MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
extern const struct cpuinfo_core** cpuinfo_linux_cpu_to_core_map;
//! RAII helper that temporarily rewrites the uarch cpuinfo reports for every
//! core, so uarch-specific kernels (A53/A55 tuned gemms) can be exercised on
//! any machine. A static mutex serializes replacements across tests; it is
//! held via a std::lock_guard member for the whole lifetime of the object,
//! which is exception-safe (the original manual lock()/unlock() pair would
//! leave the mutex locked if the constructor threw) and makes the class
//! non-copyable (a copy would otherwise unlock the mutex twice).
class CpuInfoTmpReplace {
public:
    explicit CpuInfoTmpReplace(enum cpuinfo_uarch arch)
            : m_lock(m_cpuinfo_lock) {
        for (uint32_t i = 0; i < cpuinfo_get_cores_count(); ++i) {
            //! save the real uarch so the destructor can restore it
            m_arch_bak_vec.push_back(cpuinfo_linux_cpu_to_core_map[i]->uarch);
            //! cpuinfo exposes the core map as const; cast constness away to
            //! patch the reported uarch in place (test-only hack)
            const_cast<struct cpuinfo_core*>(cpuinfo_linux_cpu_to_core_map[i])
                    ->uarch = arch;
        }
    }
    ~CpuInfoTmpReplace() {
        //! restore exactly as many entries as were saved
        for (uint32_t i = 0; i < m_arch_bak_vec.size(); ++i) {
            const_cast<struct cpuinfo_core*>(cpuinfo_linux_cpu_to_core_map[i])
                    ->uarch = m_arch_bak_vec[i];
        }
    }

private:
    static std::mutex m_cpuinfo_lock;
    //! held for the object's entire lifetime; declared before the vector so
    //! the lock is taken before any cpuinfo state is read
    std::lock_guard<std::mutex> m_lock;
    std::vector<cpuinfo_uarch> m_arch_bak_vec;
};
#endif

// vim: syntax=cpp.doxygen

+ 2
- 4
dnn/test/x86/cpuinfo.cpp View File

@@ -9,7 +9,8 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifdef MGB_ENABLE_CPUINFO_CHECK
#include "src/common/utils.h"
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include <cpuinfo.h>
#include <inttypes.h>
#include "gtest/gtest.h"
@@ -18,14 +19,12 @@ namespace megdnn {
namespace test {

TEST(X86_RUNTIME, CPUINFO_XEON6130) {

ASSERT_TRUE(cpuinfo_initialize());

int right_cpu =
strcmp(cpuinfo_get_package(0)->name, "Intel Xeon Gold 6130");

if (!right_cpu) {

ASSERT_TRUE(cpuinfo_get_processors());

ASSERT_TRUE(cpuinfo_has_x86_avx2());
@@ -44,4 +43,3 @@ TEST(X86_RUNTIME, CPUINFO_XEON6130) {
} // namespace megdnn
#endif
// vim: syntax=cpp.doxygen


Loading…
Cancel
Save