GitOrigin-RevId: 2b98867e45
tags/v0.5.0
@@ -14,7 +14,6 @@ | |||||
#include "src/aarch64/matrix_mul/fp32/strategy.h" | #include "src/aarch64/matrix_mul/fp32/strategy.h" | ||||
#include "src/aarch64/matrix_mul/int16/strategy.h" | #include "src/aarch64/matrix_mul/int16/strategy.h" | ||||
#include "src/aarch64/matrix_mul/int8/strategy.h" | #include "src/aarch64/matrix_mul/int8/strategy.h" | ||||
#include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||||
#include "src/aarch64/matrix_mul/int8_dot/strategy.h" | #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | ||||
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | ||||
#include "src/aarch64/matrix_mul/quint8/strategy.h" | #include "src/aarch64/matrix_mul/quint8/strategy.h" | ||||
@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, | |||||
"AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | ||||
aarch64::matmul::gemm_s8_8x12, int8_t, | aarch64::matmul::gemm_s8_8x12, int8_t, | ||||
int32_t); | int32_t); | ||||
/* ===================== Int8x8x32 Gemv DotProd algo ===================== */ | |||||
namespace { | |||||
//! Kernel entry point for the int8x8x32 dot-product GEMV algorithm: pulls the
//! typed A/B/C pointers, the problem sizes and the leading dimensions out of
//! \p kern_param and hands them to aarch64::matmul::gemv_like_int8.
void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    const auto lhs = kern_param.A<dt_int8>();
    const auto rhs = kern_param.B<dt_int8>();
    auto dst = kern_param.C<dt_int32>();
    aarch64::matmul::gemv_like_int8(lhs, rhs, dst, kern_param.M, kern_param.N,
                                    kern_param.K, kern_param.LDA,
                                    kern_param.LDB, kern_param.LDC);
}
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return can_be_treated_as_int8x8x32(kern_size_param) && | |||||
!kern_size_param.trA && !kern_size_param.trB && | |||||
kern_size_param.N == 1 && kern_size_param.LDB == 1; | |||||
} | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
auto N = kern_size_param.N, LDB = kern_size_param.LDB; | |||||
return (N == 1 && LDB == 1); | |||||
} | |||||
//! Returns the kernel function for this algorithm. The KernSizeParam argument
//! is unnamed and not consulted: the same kernel is returned for every size.
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern( | |||||
const KernSizeParam&) const { | |||||
//! NOTE(review): MIDOUT_BEGIN / MIDOUT_END are defined outside this view;
//! presumably they tag this region under the given hash -- confirm.
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, | |||||
midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) { | |||||
return int8x8x32_gemv_dotprod_kern; | |||||
} | |||||
MIDOUT_END(); | |||||
//! Reached only if the region above did not return.
return nullptr; | |||||
} | |||||
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | ||||
namespace { | namespace { | ||||
@@ -104,21 +104,6 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
//! Matrix-mul algorithm descriptor named "AARCH64_INT8X8X32_GEMV_DOTPROD".
//! It reports itself as reproducible, needs no workspace, belongs to the GEMV
//! algo set and uses no packing; usable(), preferred() and get_kern() are
//! defined out of line.
class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { | |||||
return "AARCH64_INT8X8X32_GEMV_DOTPROD"; | |||||
} | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
//! No scratch memory is requested, whatever the problem size.
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
//! Reports the arm_common type tag rather than an aarch64-specific one.
void* type() const override { return sm_arm_common_algo_type; } | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
}; | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
@@ -174,10 +159,6 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
//! Backend-local name for the arm_common Int8x8x32 GEMV algorithm: derives
//! from arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv and adds no members.
class MatrixMulImpl::AlgoInt8x8x32Gemv final | |||||
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; | |||||
#endif | #endif | ||||
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | ||||
@@ -1,116 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||||
#include <cstddef> | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/common/unroll_macro.h" | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace { | |||||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1 && Bstride == 1); | |||||
size_t m = 0; | |||||
for (; m + 2 <= M; m += 2) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||||
int64x2_t a1 = | |||||
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||||
//! the first 8 elements is m, the last 8 elements is m + 1 | |||||
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||||
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||||
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||||
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||||
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||||
acc_neon = vdotq_s32(acc_neon, a2, b2); | |||||
acc_neon = vdotq_s32(acc_neon, a3, b3); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
zero = vdup_n_s32(0); | |||||
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1]; | |||||
C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||||
} | |||||
for (; m < M; ++m) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||||
int8x16_t b0 = vld1q_s8(B + k); | |||||
acc_neon = vdotq_s32(acc_neon, a0, b0); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||||
} | |||||
} | |||||
} // namespace | |||||
bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8( | |||||
bool transposeA, bool transposeB, size_t M, size_t N, size_t K, | |||||
size_t /* LDA */, size_t LDB, size_t /* LDC */) { | |||||
if (transposeA) | |||||
return false; | |||||
if (transposeB) | |||||
return false; | |||||
MEGDNN_MARK_USED_VAR(K); | |||||
MEGDNN_MARK_USED_VAR(M); | |||||
return (N == 1 && LDB == 1); | |||||
} | |||||
void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A, | |||||
const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, | |||||
size_t N, size_t K, size_t Astride, | |||||
size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1); | |||||
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -1,34 +0,0 @@ | |||||
/** | |||||
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cstddef> | |||||
#include <cstdint> | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace megdnn { | |||||
namespace aarch64 { | |||||
namespace matmul { | |||||
//! Returns true when neither operand is transposed and N == 1 && LDB == 1;
//! M, K, LDA and LDC do not affect the result.
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, | |||||
size_t N, size_t K, size_t LDA, size_t LDB, | |||||
size_t LDC); | |||||
//! int8 matrix (A, M rows, row stride Astride) times int8 vector (B, K
//! elements) into int32 outputs C[m * Cstride]. N must be 1 (asserted in the
//! definition).
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride); | |||||
} // namespace matmul | |||||
} // namespace aarch64 | |||||
} // namespace megdnn | |||||
#endif | |||||
// vim: syntax=cpp.doxygen |
@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; | ||||
AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod; | |||||
AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; | AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; | ||||
#else | #else | ||||
AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; | ||||
AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; | ||||
AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; | ||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||||
#endif | #endif | ||||
AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; | ||||
AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; | ||||
@@ -63,11 +61,9 @@ public: | |||||
all_algos.emplace_back(&f16_mk8_8x8); | all_algos.emplace_back(&f16_mk8_8x8); | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
all_algos.emplace_back(&int8x8x32_gemv_dotprod); | |||||
all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); | ||||
all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); | ||||
#else | #else | ||||
all_algos.emplace_back(&int8x8x32_gemv); | |||||
all_algos.emplace_back(&int8x8x32_k4x4x16); | all_algos.emplace_back(&int8x8x32_k4x4x16); | ||||
all_algos.emplace_back(&int8x8x32_k8x8x8); | all_algos.emplace_back(&int8x8x32_k8x8x8); | ||||
all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | all_algos.emplace_back(&int8x8x32_mk4_4x4x16); | ||||
@@ -34,14 +34,12 @@ private: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel | ||||
// 8x12x4 DotProduct | // 8x12x4 DotProduct | ||||
class AlgoInt8x8x32GemvDotProd; // Aarch64 Int8x8x32 Gemv DotProduct | |||||
class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel | ||||
// 8x12x4 DotProduct | // 8x12x4 DotProduct | ||||
#else | #else | ||||
class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 | ||||
class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 | ||||
class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 | ||||
class AlgoInt8x8x32Gemv; // Aarch64 Int8x8x32 Gemv | |||||
#endif | #endif | ||||
class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 | ||||
class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 | ||||
@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | |||||
return exec_int_8x8x16; | return exec_int_8x8x16; | ||||
} | } | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
/* ===================== Int8x8x32 Gemv algo ===================== */ | /* ===================== Int8x8x32 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( | |||||
const KernSizeParam&) const { | const KernSizeParam&) const { | ||||
return int8x8x32_gemv_kern; | return int8x8x32_gemv_kern; | ||||
} | } | ||||
#endif | |||||
/* ===================== F32 Gemv algo ===================== */ | /* ===================== F32 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
const auto Aptr = kern_param.A<dt_float32>(), | const auto Aptr = kern_param.A<dt_float32>(), | ||||
Bptr = kern_param.B<dt_float32>(); | Bptr = kern_param.B<dt_float32>(); | ||||
auto Cptr = kern_param.C<dt_float32>(); | auto Cptr = kern_param.C<dt_float32>(); | ||||
arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | ||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -27,11 +27,7 @@ public: | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
}; | }; | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
protected: | |||||
~AlgoInt8x8x32Gemv() = default; | |||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | ||||
@@ -43,7 +39,6 @@ public: | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
}; | }; | ||||
#endif | |||||
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | ||||
protected: | protected: | ||||
@@ -9,8 +9,6 @@ | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
#include <cstddef> | #include <cstddef> | ||||
#include "src/arm_common/matrix_mul/int8/gemv.h" | #include "src/arm_common/matrix_mul/int8/gemv.h" | ||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | using namespace arm_common; | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
namespace { | namespace { | ||||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | ||||
@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
C[m * Cstride] = acc0; | C[m * Cstride] = acc0; | ||||
} | } | ||||
} | } | ||||
} // namespace | |||||
#endif | |||||
#if __ARM_FEATURE_DOTPROD | |||||
namespace { | |||||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1 && Bstride == 1); | |||||
size_t m = 0; | |||||
for (; m + 2 <= M; m += 2) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||||
int64x2_t a1 = | |||||
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||||
//! the first 8 elements is m, the last 8 elements is m + 1 | |||||
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||||
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||||
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||||
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||||
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||||
acc_neon = vdotq_s32(acc_neon, a2, b2); | |||||
acc_neon = vdotq_s32(acc_neon, a3, b3); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
zero = vdup_n_s32(0); | |||||
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1]; | |||||
C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||||
} | |||||
for (; m < M; ++m) { | |||||
int32_t acc[4]; | |||||
int32x4_t acc_neon = vdupq_n_s32(0); | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||||
int8x16_t b0 = vld1q_s8(B + k); | |||||
acc_neon = vdotq_s32(acc_neon, a0, b0); | |||||
} | |||||
vst1q_s32(acc, acc_neon); | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||||
int8x8_t b0 = vld1_s8(B + k); | |||||
uint32x2_t zero = vdup_n_s32(0); | |||||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||||
} | |||||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||||
} | |||||
} | |||||
} // namespace | } // namespace | ||||
#endif | |||||
bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, | bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, | ||||
size_t M, size_t N, size_t K, | size_t M, size_t N, size_t K, | ||||
@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A, | |||||
} MIDOUT_END(); | } MIDOUT_END(); | ||||
} | } | ||||
#endif | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -13,7 +13,6 @@ | |||||
#include <cstddef> | #include <cstddef> | ||||
#include <cstdint> | #include <cstdint> | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
namespace matmul { | namespace matmul { | ||||
@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
#endif | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16Gemv f16gemv; | AlgoF16Gemv f16gemv; | ||||
#endif | #endif | ||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&int8x8x16); | all_algos.emplace_back(&int8x8x16); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
all_algos.emplace_back(&f16gemv); | all_algos.emplace_back(&f16gemv); | ||||
#endif | #endif | ||||
all_algos.emplace_back(&int8x8x32_gemv); | |||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
@@ -25,9 +25,7 @@ public: | |||||
protected: | protected: | ||||
static void* const sm_arm_common_algo_type; | static void* const sm_arm_common_algo_type; | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | ||||
#endif | |||||
class AlgoF32Gemv; // Arm_common F32 Gemv | class AlgoF32Gemv; // Arm_common F32 Gemv | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class AlgoF16Gemv; | class AlgoF16Gemv; | ||||
@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) { | |||||
__ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { | __ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) { | ||||
return vmovl_u32(vget_high_u32(__p0)); | return vmovl_u32(vget_high_u32(__p0)); | ||||
} | } | ||||
//! ARMv7 replacement for vzip1q_s64: returns {a[0], b[0]}, the low 64-bit
//! lanes of both inputs. The operands are taken by value so that temporaries
//! and const vectors are accepted as well as lvalues.
__ai int64x2_t vzip1q_s64(int64x2_t a, int64x2_t b) {
    return vcombine_s64(vget_low_s64(a), vget_low_s64(b));
}
//! ARMv7 replacement for vzip2q_s64: returns {a[1], b[1]}, the high 64-bit
//! lanes of both inputs. The operands are taken by value so that temporaries
//! and const vectors are accepted as well as lvalues.
__ai int64x2_t vzip2q_s64(int64x2_t a, int64x2_t b) {
    return vcombine_s64(vget_high_s64(a), vget_high_s64(b));
}
//! ARMv7 replacement for vaddv_s32: horizontal sum of the two 32-bit lanes.
__ai int32_t vaddv_s32(int32x2_t a) {
    const int32_t lane0 = vget_lane_s32(a, 0);
    const int32_t lane1 = vget_lane_s32(a, 1);
    return lane0 + lane1;
}
#endif // MEGDNN_ARMV7 | #endif // MEGDNN_ARMV7 | ||||
//! pack vmovl_low_xx() on armv7 and armv8 | //! pack vmovl_low_xx() on armv7 and armv8 | ||||
@@ -134,11 +134,6 @@ public: | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
//! Backend-local name for the arm_common Int8x8x32 GEMV algorithm: derives
//! from arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv and adds no members.
class MatrixMulImpl::AlgoInt8x8x32Gemv final | |||||
: public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {}; | |||||
#endif | |||||
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; | AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16; | ||||
AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; | AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16; | ||||
AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; | AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8; | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||||
#endif | |||||
AlgoQuint8K4x8x8 quint8_k4x8x8; | AlgoQuint8K4x8x8 quint8_k4x8x8; | ||||
AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16; | ||||
AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8; | ||||
@@ -61,9 +58,6 @@ public: | |||||
all_algos.emplace_back(&int8_k6x8x4); | all_algos.emplace_back(&int8_k6x8x4); | ||||
all_algos.emplace_back(&quint8_k4x8x4); | all_algos.emplace_back(&quint8_k4x8x4); | ||||
#endif | #endif | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
all_algos.emplace_back(&int8x8x32_gemv); | |||||
#endif | |||||
all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | all_algos.emplace_back(&int8x8x32_mk4_4x2x16); | ||||
all_algos.emplace_back(&int8x8x32_k4x2x16); | all_algos.emplace_back(&int8x8x32_k4x2x16); | ||||
all_algos.emplace_back(&int8x8x32_k4x8x8); | all_algos.emplace_back(&int8x8x32_k4x8x8); | ||||
@@ -27,9 +27,6 @@ private: | |||||
class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | class AlgoInt8x8x32K4x8x8; // Armv7 Int8x8x32 Kernel 4x8x8 | ||||
class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | class AlgoInt8x8x32K4x2x16; // Armv7 Int8x8x32 Kernel 4x2x16 | ||||
class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | class AlgoInt8x8x32MK4_4x2x16; // Armv7 Int8x8x32 Kernel MK4 4x2x16 | ||||
#if !__ARM_FEATURE_DOTPROD | |||||
class AlgoInt8x8x32Gemv; // Armv7 Int8x8x32 Gemv | |||||
#endif | |||||
class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | class AlgoQuint8K4x8x8; // Armv7 Quint8 Kernel 4x8x8 | ||||
class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | class AlgoInt8x8x16K4x2x16; // Armv7 Int8x8x16 Kernel 4x2x16 | ||||
class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | class AlgoInt8x8x16K4x8x8; // Armv7 Int8x8x16 Kernel 4x8x8 | ||||
@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) { | |||||
} | } | ||||
#endif | #endif | ||||
//! Runs MatrixMul through Checker with the algorithm pinned to
//! "ARM_COMMON_INT8X8X32_GEMV", for non-transposed QuantizedS8 inputs and a
//! QuantizedS32 output, with N fixed to 1.
TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;

    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV"));

    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());

    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = false;
        TensorShape A, B;
        A = TensorShape{M, K};
        B = TensorShape{K, N};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };

    // N = 1
    // K is chosen to reach every reduction loop of the dot-product kernel:
    // 7 -> scalar tail only, 8 -> a single 8-wide step, 24 -> 16-wide then
    // 8-wide, 31 -> 16-wide, 8-wide and scalar tail, 512/1024 -> 16-wide only.
    // M mixes odd and even values so both the paired-row loop and the
    // leftover-row loop run.
    for (size_t M : {1, 10, 16, 33, 64})
        for (size_t K : {7, 8, 24, 31, 512, 1024})
            for (size_t N : {1})
                run(M, K, N);
}
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||