GitOrigin-RevId: 2b98867e45
tags/v0.5.0
@@ -14,7 +14,6 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
                                     "AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
                                     aarch64::matmul::gemm_s8_8x12, int8_t,
                                     int32_t);
/* ===================== Int8x8x32 Gemv DotProd algo ===================== */
namespace {
void int8x8x32_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
    const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
    auto Cptr = kern_param.C<dt_int32>();
    aarch64::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
}  // anonymous namespace

bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::usable(
        const KernSizeParam& kern_size_param) const {
    return can_be_treated_as_int8x8x32(kern_size_param) &&
           !kern_size_param.trA && !kern_size_param.trB &&
           kern_size_param.N == 1 && kern_size_param.LDB == 1;
}

bool MatrixMulImpl::AlgoInt8x8x32GemvDotProd::preferred(
        const KernSizeParam& kern_size_param) const {
    auto N = kern_size_param.N, LDB = kern_size_param.LDB;
    return (N == 1 && LDB == 1);
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvDotProd::get_kern(
        const KernSizeParam&) const {
    MIDOUT_BEGIN(megdnn_aarch64_matmul_kern,
                 midout_iv("AlgoInt8x8x32GemvDotProd::get_kern"_hash)) {
        return int8x8x32_gemv_dotprod_kern;
    }
    MIDOUT_END();
    return nullptr;
}
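
// [Illustrative sketch, not MegEngine code] The GEMV path selected above
// (N == 1, LDB == 1, no transposes) must reproduce a plain int32-accumulated
// matrix-vector product; a minimal scalar reference of that contract:
#include <cstddef>
#include <cstdint>

void gemv_int8_reference(const int8_t* A, const int8_t* B, int32_t* C,
                         size_t M, size_t K, size_t LDA, size_t LDC) {
    // C[m] = sum_k A[m, k] * B[k] -- the result the NEON kernels must match.
    for (size_t m = 0; m < M; ++m) {
        int32_t acc = 0;
        for (size_t k = 0; k < K; ++k)
            acc += static_cast<int32_t>(A[m * LDA + k]) *
                   static_cast<int32_t>(B[k]);
        C[m * LDC] = acc;
    }
}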
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace {
@@ -104,21 +104,6 @@ public:
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32GemvDotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "AARCH64_INT8X8X32_GEMV_DOTPROD";
    }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override { return 0; }
    kern_t get_kern(const KernSizeParam&) const override;
    void* type() const override { return sm_arm_common_algo_type; }
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};

class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
@@ -174,10 +159,6 @@ public:
    void* type() const override { return sm_arm_common_algo_type; }
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

class MatrixMulImpl::AlgoInt8x8x32Gemv final
        : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif

class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
@@ -1,116 +0,0 @@
/**
 * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/aarch64/matrix_mul/int8_dot/gemv.h" | |||
#include <cstddef> | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/common/utils.h" | |||
#include "src/common/unroll_macro.h" | |||
#if __ARM_FEATURE_DOTPROD | |||
namespace { | |||
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||
size_t Astride, size_t Bstride, size_t Cstride) { | |||
megdnn_assert(N == 1 && Bstride == 1); | |||
size_t m = 0; | |||
for (; m + 2 <= M; m += 2) { | |||
int32_t acc[4]; | |||
int32x4_t acc_neon = vdupq_n_s32(0); | |||
size_t k = 0; | |||
for (; k + 16 <= K; k += 16) { | |||
int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); | |||
int64x2_t a1 = | |||
vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); | |||
//! the first 8 elements is m, the last 8 elements is m + 1 | |||
int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); | |||
int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); | |||
int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k)); | |||
int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0)); | |||
int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0)); | |||
acc_neon = vdotq_s32(acc_neon, a2, b2); | |||
acc_neon = vdotq_s32(acc_neon, a3, b3); | |||
} | |||
vst1q_s32(acc, acc_neon); | |||
for (; k + 8 <= K; k += 8) { | |||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k); | |||
int8x8_t b0 = vld1_s8(B + k); | |||
uint32x2_t zero = vdup_n_s32(0); | |||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
zero = vdup_n_s32(0); | |||
acc[3] += vaddv_s32(vdot_s32(zero, a1, b0)); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k]; | |||
} | |||
C[m * Cstride] = acc[0] + acc[1]; | |||
C[(m + 1) * Cstride] = acc[2] + acc[3]; | |||
} | |||
for (; m < M; ++m) { | |||
int32_t acc[4]; | |||
int32x4_t acc_neon = vdupq_n_s32(0); | |||
size_t k = 0; | |||
for (; k + 16 <= K; k += 16) { | |||
int8x16_t a0 = vld1q_s8(A + m * Astride + k); | |||
int8x16_t b0 = vld1q_s8(B + k); | |||
acc_neon = vdotq_s32(acc_neon, a0, b0); | |||
} | |||
vst1q_s32(acc, acc_neon); | |||
for (; k + 8 <= K; k += 8) { | |||
int8x8_t a0 = vld1_s8(A + m * Astride + k); | |||
int8x8_t b0 = vld1_s8(B + k); | |||
uint32x2_t zero = vdup_n_s32(0); | |||
acc[0] += vaddv_s32(vdot_s32(zero, a0, b0)); | |||
} | |||
for (; k < K; ++k) { | |||
acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k]; | |||
} | |||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | |||
} | |||
} | |||
} // namespace | |||
bool megdnn::aarch64::matmul::is_gemv_like_preferred_int8(
        bool transposeA, bool transposeB, size_t M, size_t N, size_t K,
        size_t /* LDA */, size_t LDB, size_t /* LDC */) {
    if (transposeA)
        return false;
    if (transposeB)
        return false;
    MEGDNN_MARK_USED_VAR(K);
    MEGDNN_MARK_USED_VAR(M);
    return (N == 1 && LDB == 1);
}
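
// [Usage sketch; shapes illustrative, wrapper function hypothetical] The
// predicate only prefers a contiguous, untransposed matrix-vector product:
#include <cassert>

void preference_example() {
    bool prefer = megdnn::aarch64::matmul::is_gemv_like_preferred_int8(
            /*transposeA=*/false, /*transposeB=*/false,
            /*M=*/64, /*N=*/1, /*K=*/1024,
            /*LDA=*/1024, /*LDB=*/1, /*LDC=*/1);
    assert(prefer);  // true: N == 1 && LDB == 1
}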
void megdnn::aarch64::matmul::gemv_like_int8(const int8_t* __restrict A,
                                             const int8_t* __restrict B,
                                             int32_t* __restrict C, size_t M,
                                             size_t N, size_t K, size_t Astride,
                                             size_t Bstride, size_t Cstride) {
    megdnn_assert(N == 1);
    return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
#endif

// vim: syntax=cpp.doxygen
@@ -1,34 +0,0 @@
/**
 * \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include <cstddef>
#include <cstdint>

#if __ARM_FEATURE_DOTPROD
namespace megdnn {
namespace aarch64 {
namespace matmul {

bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
                                 size_t N, size_t K, size_t LDA, size_t LDB,
                                 size_t LDC);

void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
                    int32_t* __restrict C, size_t M, size_t N, size_t K,
                    size_t Astride, size_t Bstride, size_t Cstride);

}  // namespace matmul
}  // namespace aarch64
}  // namespace megdnn
#endif

// vim: syntax=cpp.doxygen
@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
#if __ARM_FEATURE_DOTPROD
    AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
    AlgoInt8x8x32GemvDotProd int8x8x32_gemv_dotprod;
    AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod;
#else
    AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
    AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
    AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8;
    AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif
    AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
    AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
@@ -63,11 +61,9 @@ public:
        all_algos.emplace_back(&f16_mk8_8x8);
#endif
#if __ARM_FEATURE_DOTPROD
        all_algos.emplace_back(&int8x8x32_gemv_dotprod);
        all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
        all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
#else
        all_algos.emplace_back(&int8x8x32_gemv);
        all_algos.emplace_back(&int8x8x32_k4x4x16);
        all_algos.emplace_back(&int8x8x32_k8x8x8);
        all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
@@ -34,14 +34,12 @@ private:
#if __ARM_FEATURE_DOTPROD
    class AlgoInt8x8x32K8x12x4DotProd;     // Aarch64 Int8x8x32 Kernel
                                           // 8x12x4 DotProduct
    class AlgoInt8x8x32GemvDotProd;        // Aarch64 Int8x8x32 Gemv DotProduct
    class AlgoInt8x8x32MK4_8x12x4DotProd;  // Aarch64 nchw44 Int8x8x32 Kernel
                                           // 8x12x4 DotProduct
#else
    class AlgoInt8x8x32MK4_4x4x16;  // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
    class AlgoInt8x8x32K4x4x16;     // Aarch64 Int8x8x32 Kernel 4x4x16
    class AlgoInt8x8x32K8x8x8;      // Aarch64 Int8x8x32 Kernel 8x8x8
    class AlgoInt8x8x32Gemv;        // Aarch64 Int8x8x32 Gemv
#endif
    class AlgoInt8x8x16K8x8x8;      // Aarch64 Int8x8x16 Kernel 8x8x8
    class AlgoInt8x8x16K4x4x16;     // Aarch64 Int8x8x16 Kernel 4x4x16
@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
    return exec_int_8x8x16;
}

#if !__ARM_FEATURE_DOTPROD
/* ===================== Int8x8x32 Gemv algo ===================== */
namespace {
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
        const KernSizeParam&) const {
    return int8x8x32_gemv_kern;
}
#endif

/* ===================== F32 Gemv algo ===================== */
namespace {
@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
    const auto Aptr = kern_param.A<dt_float32>(),
               Bptr = kern_param.B<dt_float32>();
    auto Cptr = kern_param.C<dt_float32>();
    arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
}  // anonymous namespace
@@ -27,11 +27,7 @@ public:
    PackMode packmode() const override { return PackMode::NO_PACK; }
};

#if !__ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
protected:
    ~AlgoInt8x8x32Gemv() = default;

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
@@ -43,7 +39,6 @@ public:
    AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
};
#endif

class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
protected:
@@ -9,8 +9,6 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#if !__ARM_FEATURE_DOTPROD
#include <cstddef>
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/simd_macro/marm_neon.h"
@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
using namespace megdnn;
using namespace arm_common;

#if !__ARM_FEATURE_DOTPROD
namespace {
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
        C[m * Cstride] = acc0;
    }
}
}  // namespace
#endif
#if __ARM_FEATURE_DOTPROD
namespace {
void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
                  int32_t* __restrict C, size_t M, size_t N, size_t K,
                  size_t Astride, size_t Bstride, size_t Cstride) {
    megdnn_assert(N == 1 && Bstride == 1);
    size_t m = 0;
    for (; m + 2 <= M; m += 2) {
        int32_t acc[4];
        int32x4_t acc_neon = vdupq_n_s32(0);
        size_t k = 0;
        for (; k + 16 <= K; k += 16) {
            int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k));
            int64x2_t a1 =
                    vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k));
            //! the first 8 elements are from row m, the last 8 from row m + 1
            int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1));
            int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1));
            int64x2_t b0 = vreinterpretq_s64_s8(vld1q_s8(B + k));
            int8x16_t b2 = vreinterpretq_s8_s64(vzip1q_s64(b0, b0));
            int8x16_t b3 = vreinterpretq_s8_s64(vzip2q_s64(b0, b0));
            acc_neon = vdotq_s32(acc_neon, a2, b2);
            acc_neon = vdotq_s32(acc_neon, a3, b3);
        }
        vst1q_s32(acc, acc_neon);
        for (; k + 8 <= K; k += 8) {
            int8x8_t a0 = vld1_s8(A + m * Astride + k);
            int8x8_t a1 = vld1_s8(A + (m + 1) * Astride + k);
            int8x8_t b0 = vld1_s8(B + k);
            int32x2_t zero = vdup_n_s32(0);
            acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
            zero = vdup_n_s32(0);
            acc[3] += vaddv_s32(vdot_s32(zero, a1, b0));
        }
        for (; k < K; ++k) {
            acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
            acc[3] += static_cast<int32_t>(A[(m + 1) * Astride + k]) * B[k];
        }
        C[m * Cstride] = acc[0] + acc[1];
        C[(m + 1) * Cstride] = acc[2] + acc[3];
    }
    for (; m < M; ++m) {
        int32_t acc[4];
        int32x4_t acc_neon = vdupq_n_s32(0);
        size_t k = 0;
        for (; k + 16 <= K; k += 16) {
            int8x16_t a0 = vld1q_s8(A + m * Astride + k);
            int8x16_t b0 = vld1q_s8(B + k);
            acc_neon = vdotq_s32(acc_neon, a0, b0);
        }
        vst1q_s32(acc, acc_neon);
        for (; k + 8 <= K; k += 8) {
            int8x8_t a0 = vld1_s8(A + m * Astride + k);
            int8x8_t b0 = vld1_s8(B + k);
            int32x2_t zero = vdup_n_s32(0);
            acc[0] += vaddv_s32(vdot_s32(zero, a0, b0));
        }
        for (; k < K; ++k) {
            acc[0] += static_cast<int32_t>(A[m * Astride + k]) * B[k];
        }
        C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
    }
}
}  // namespace
#endif
bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
                                         size_t M, size_t N, size_t K,
@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A,
    } MIDOUT_END();
}
#endif

// vim: syntax=cpp.doxygen
@@ -13,7 +13,6 @@
#include <cstddef>
#include <cstdint>

#if !__ARM_FEATURE_DOTPROD
namespace megdnn {
namespace arm_common {
namespace matmul {
@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
}  // namespace matmul
}  // namespace arm_common
}  // namespace megdnn
#endif

// vim: syntax=cpp.doxygen
@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    AlgoF16Gemv f16gemv;
#endif
    AlgoInt8x8x32Gemv int8x8x32_gemv;

public:
    AlgoPack() {
        all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
        all_algos.emplace_back(&f16gemv);
#endif
        all_algos.emplace_back(&int8x8x32_gemv);
    }
    SmallVector<AlgoBase*> all_algos;
};
@@ -25,9 +25,7 @@ public:
protected:
    static void* const sm_arm_common_algo_type;

#if !__ARM_FEATURE_DOTPROD
    class AlgoInt8x8x32Gemv;  // Arm_common Int 8x8x32 Gemv
#endif
    class AlgoF32Gemv;  // Arm_common F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    class AlgoF16Gemv;
@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) {
__ai uint64x2_t vmovl_high_u32(uint32x4_t __p0) {
    return vmovl_u32(vget_high_u32(__p0));
}
//! armv7 fallbacks for the aarch64-only intrinsics used by the int8 GEMV:
//! vzip1q_s64 / vzip2q_s64 combine the low / high 64-bit lanes of the two
//! operands, and vaddv_s32 is a horizontal add across both int32 lanes.
__ai int64x2_t vzip1q_s64(int64x2_t& a, int64x2_t& b) {
    return vcombine_s64(vget_low_s64(a), vget_low_s64(b));
}
__ai int64x2_t vzip2q_s64(int64x2_t& a, int64x2_t& b) {
    return vcombine_s64(vget_high_s64(a), vget_high_s64(b));
}
__ai int32_t vaddv_s32(int32x2_t a) {
    return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
}
#endif  // MEGDNN_ARMV7
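
// [Illustrative sketch; not part of the tree] A scalar model of the three
// fallbacks above, mirroring their definitions for a host-side sanity check:
#include <cassert>
#include <cstdint>

void check_fallback_models() {
    // vzip1q_s64 keeps the low 64-bit lane of each operand, vzip2q_s64 the
    // high lane.
    int64_t a[2] = {1, 2}, b[2] = {3, 4};
    int64_t lo[2] = {a[0], b[0]};
    int64_t hi[2] = {a[1], b[1]};
    assert(lo[0] == 1 && lo[1] == 3);
    assert(hi[0] == 2 && hi[1] == 4);
    // vaddv_s32 reduces both int32 lanes to one scalar.
    int32_t v[2] = {5, 7};
    assert(v[0] + v[1] == 12);
}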
//! pack vmovl_low_xx() on armv7 and armv8
@@ -134,11 +134,6 @@ public:
    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};

#if !__ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32Gemv final
        : public arm_common::MatrixMulImpl::AlgoInt8x8x32Gemv {};
#endif

class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase {
public:
    bool is_reproducible() const override { return true; }
@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
    AlgoInt8x8x32MK4_4x2x16 int8x8x32_mk4_4x2x16;
    AlgoInt8x8x32K4x2x16 int8x8x32_k4x2x16;
    AlgoInt8x8x32K4x8x8 int8x8x32_k4x8x8;
#if !__ARM_FEATURE_DOTPROD
    AlgoInt8x8x32Gemv int8x8x32_gemv;
#endif
    AlgoQuint8K4x8x8 quint8_k4x8x8;
    AlgoInt8x8x16K4x2x16 int8x8x16_k4x2x16;
    AlgoInt8x8x16K4x8x8 int8x8x16_k4x8x8;
@@ -61,9 +58,6 @@ public:
        all_algos.emplace_back(&int8_k6x8x4);
        all_algos.emplace_back(&quint8_k4x8x4);
#endif
#if !__ARM_FEATURE_DOTPROD
        all_algos.emplace_back(&int8x8x32_gemv);
#endif
        all_algos.emplace_back(&int8x8x32_mk4_4x2x16);
        all_algos.emplace_back(&int8x8x32_k4x2x16);
        all_algos.emplace_back(&int8x8x32_k4x8x8);
@@ -27,9 +27,6 @@ private:
    class AlgoInt8x8x32K4x8x8;      // Armv7 Int8x8x32 Kernel 4x8x8
    class AlgoInt8x8x32K4x2x16;     // Armv7 Int8x8x32 Kernel 4x2x16
    class AlgoInt8x8x32MK4_4x2x16;  // Armv7 Int8x8x32 Kernel MK4 4x2x16
#if !__ARM_FEATURE_DOTPROD
    class AlgoInt8x8x32Gemv;        // Armv7 Int8x8x32 Gemv
#endif
    class AlgoQuint8K4x8x8;         // Armv7 Quint8 Kernel 4x8x8
    class AlgoInt8x8x16K4x2x16;     // Armv7 Int8x8x16 Kernel 4x2x16
    class AlgoInt8x8x16K4x8x8;      // Armv7 Int8x8x16 Kernel 4x8x8
@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) {
}
#endif

TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
    Checker<MatrixMul> checker(handle());
    using Param = MatrixMul::Param;

    checker.set_before_exec_callback(
            AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV"));

    std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
    checker.set_rng(0, rng.get()).set_rng(1, rng.get());

    auto run = [&](size_t M, size_t K, size_t N) {
        Param param;
        param.transposeA = false;
        param.transposeB = false;
        TensorShape A, B;
        A = TensorShape{M, K};
        B = TensorShape{K, N};
        checker.set_param(param)
                .set_dtype(0, dtype::QuantizedS8(2.5f))
                .set_dtype(1, dtype::QuantizedS8(2.5f))
                .set_dtype(2, dtype::QuantizedS32(6.25f))
                .execs({A, B, {}});
    };

    // N = 1
    for (size_t M : {1, 10, 16, 33, 64})
        for (size_t K : {7, 512, 1024})
            for (size_t N : {1})
                run(M, K, N);
}
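
// [Worked note, not in the original test] The product of two quantized
// tensors carries the product of their scales, so QuantizedS8(2.5f) inputs
// accumulate into QuantizedS32(6.25f):
constexpr float kScaleA = 2.5f, kScaleB = 2.5f;
static_assert(kScaleA * kScaleB == 6.25f,
              "accumulator scale = product of input scales");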
#if MEGDNN_WITH_BENCHMARK