GitOrigin-RevId: f8b6d7a1b7
release-0.6
@@ -210,27 +210,33 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||||
DEFAULT \ | DEFAULT \ | ||||
} | } | ||||
#define FOR_BIAS(_bias_mode, OH, OW) \ | |||||
switch (_bias_mode) { \ | |||||
case megdnn::BiasMode::NO_BIAS: \ | |||||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \ | |||||
break; \ | |||||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
if (pack_oc_size == 1) { \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||||
} else { \ | |||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||||
} \ | |||||
break; \ | |||||
default: \ | |||||
if (OH * OW == 1) { \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||||
break; \ | |||||
} \ | |||||
megdnn_throw("quantized unsupported biasmode"); \ | |||||
break; \ | |||||
#define FOR_BIAS(_bias_mode, OH, OW) \ | |||||
switch (_bias_mode) { \ | |||||
case megdnn::BiasMode::NO_BIAS: \ | |||||
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \ | |||||
break; \ | |||||
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ | |||||
if (pack_oc_size == 1) { \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||||
} else { \ | |||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||||
} \ | |||||
break; \ | |||||
default: \ | |||||
if (OH * OW == 1) { \ | |||||
if (pack_oc_size == 1) { \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ | |||||
} else { \ | |||||
megdnn_assert(pack_oc_size == 4, \ | |||||
"Only support nchw44 in ARM"); \ | |||||
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ | |||||
} \ | |||||
break; \ | |||||
} \ | |||||
megdnn_throw("quantized unsupported biasmode"); \ | |||||
break; \ | |||||
} | } | ||||
template <typename opctype, typename opdtype> | template <typename opctype, typename opdtype> | ||||
@@ -101,6 +101,91 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( | |||||
return int8x8x32_gemv_kern; | return int8x8x32_gemv_kern; | ||||
} | } | ||||
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */ | |||||
namespace { | |||||
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
auto M = kern_size_param.M; | |||||
auto N = kern_size_param.N; | |||||
auto K = kern_size_param.K; | |||||
auto LDB = kern_size_param.LDB; | |||||
bool is_dtype_ok = | |||||
kern_size_param.A_type == kern_size_param.B_type && | |||||
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && | |||||
M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||||
} | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||||
return true; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||||
const KernSizeParam&) const { | |||||
return int8x8x32_gemv_mk4_kern; | |||||
} | |||||
#if __ARM_FEATURE_DOTPROD | |||||
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | |||||
namespace { | |||||
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
auto M = kern_size_param.M; | |||||
auto N = kern_size_param.N; | |||||
auto K = kern_size_param.K; | |||||
auto LDB = kern_size_param.LDB; | |||||
bool is_dtype_ok = | |||||
kern_size_param.A_type == kern_size_param.B_type && | |||||
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 || | |||||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && | |||||
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 || | |||||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4_DOT && | |||||
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && | |||||
M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; | |||||
} | |||||
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
return true; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern( | |||||
const KernSizeParam&) const { | |||||
return int8x8x32_gemv_mk4_dot_kern; | |||||
} | |||||
#endif | |||||
/* ===================== F32 Gemv algo ===================== */ | /* ===================== F32 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -137,6 +222,46 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||||
return f32_gemv_kern; | return f32_gemv_kern; | ||||
} | } | ||||
/* ================== F32 Gemv MK4 algo ================== */ | |||||
namespace { | |||||
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
} // anonymous namespace | |||||
bool MatrixMulImpl::AlgoF32GemvMK4::usable( | |||||
const KernSizeParam& kern_size_param) const { | |||||
// enumerate the M, N, K, only usable when preferred | |||||
auto M = kern_size_param.M; | |||||
auto N = kern_size_param.N; | |||||
auto K = kern_size_param.K; | |||||
auto LDB = kern_size_param.LDB; | |||||
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && | |||||
kern_size_param.format == param::MatrixMul::Format::MK4 && | |||||
kern_size_param.B_type == kern_size_param.A_type && | |||||
kern_size_param.C_type == kern_size_param.A_type && | |||||
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && | |||||
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && | |||||
LDB == 4; | |||||
} | |||||
bool MatrixMulImpl::AlgoF32GemvMK4::preferred( | |||||
const KernSizeParam& kern_size_param) const { | |||||
MEGDNN_MARK_USED_VAR(kern_size_param); | |||||
return true; | |||||
} | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||||
const KernSizeParam&) const { | |||||
return f32_gemv_mk4_kern; | |||||
} | |||||
/* ===================== F32 Gevm algo ===================== */ | /* ===================== F32 Gevm algo ===================== */ | ||||
namespace { | namespace { | ||||
template <typename stype, typename dtype> | template <typename stype, typename dtype> | ||||
@@ -43,6 +43,36 @@ public: | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | |||||
#if __ARM_FEATURE_DOTPROD | |||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | |||||
#endif | |||||
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | ||||
protected: | protected: | ||||
~AlgoF32Gemv() = default; | ~AlgoF32Gemv() = default; | ||||
@@ -60,6 +90,20 @@ public: | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||||
public: | |||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||||
bool usable(const KernSizeParam&) const override; | |||||
bool preferred(const KernSizeParam&) const override; | |||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | |||||
kern_t get_kern(const KernSizeParam&) const override; | |||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4) | |||||
}; | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | ||||
public: | public: | ||||
@@ -87,10 +131,9 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4) | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -13,11 +13,11 @@ | |||||
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | ||||
#include <cstddef> | #include <cstddef> | ||||
#include "include/megdnn/oprs.h" | #include "include/megdnn/oprs.h" | ||||
#include "midout.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/common/unroll_macro.h" | #include "src/common/unroll_macro.h" | ||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "midout.h" | |||||
MIDOUT_DECL(megdnn_fp32_sgemv) | MIDOUT_DECL(megdnn_fp32_sgemv) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -68,18 +68,10 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B, | |||||
#if !defined(__aarch64__) | #if !defined(__aarch64__) | ||||
#undef vaddvq_f32 | #undef vaddvq_f32 | ||||
#endif | #endif | ||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
void gemv_like(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); | |||||
if (N == 1) { | |||||
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
void sgemv_naive_m(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
size_t m = 0; | size_t m = 0; | ||||
for (; m + 4 <= M; m += 4) { | for (; m + 4 <= M; m += 4) { | ||||
size_t k = 0; | size_t k = 0; | ||||
@@ -762,6 +754,85 @@ void gemv_like(const float* __restrict A, const float* __restrict B, | |||||
} | } | ||||
} | } | ||||
} | } | ||||
void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
constexpr size_t PACK_SIZE = 4; | |||||
megdnn_assert(N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && | |||||
K % PACK_SIZE == 0); | |||||
auto Aptr = A; | |||||
auto Cptr = C; | |||||
size_t m = 0; | |||||
while (m < M) { | |||||
auto Aptr0 = Aptr; | |||||
auto Cptr0 = Cptr; | |||||
float32x4_t c[4]; | |||||
#define INIT(step) c[step] = vdupq_n_f32(0.0f); | |||||
UNROLL_CALL_RAW(4, INIT) | |||||
#undef INIT | |||||
auto Bptr = B; | |||||
size_t k = 0; | |||||
while (k < K) { | |||||
float32x4_t b = vld1q_f32(Bptr); | |||||
float32x4x2_t a[2]; | |||||
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8); | |||||
UNROLL_CALL_RAW(2, LOAD_A) | |||||
#undef LOAD_A | |||||
#define COMPT(step) \ | |||||
c[step] = vfmaq_laneq_f32(c[step], a[step / 2].val[step % 2], b, step % 4); | |||||
UNROLL_CALL_RAW(4, COMPT) | |||||
#undef COMPT | |||||
Bptr += Bstride; | |||||
Aptr0 += PACK_SIZE * PACK_SIZE; | |||||
k += PACK_SIZE; | |||||
} | |||||
#define ADD_C(step, stride) c[step] = vaddq_f32(c[step], c[step + stride]); | |||||
UNROLL_CALL_RAW(2, ADD_C, 2) | |||||
UNROLL_CALL_RAW(1, ADD_C, 1) | |||||
#undef ADD_C | |||||
vst1q_f32(Cptr0, c[0]); | |||||
Aptr += Astride; | |||||
Cptr += Cstride; | |||||
m += PACK_SIZE; | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace arm_common { | |||||
void gemv_like(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); | |||||
if (N == 1) { | |||||
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_N"_hash)) { | |||||
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} else { | |||||
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_M"_hash)) { | |||||
return sgemv_naive_m(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} | |||||
void gemv_like_mk4(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
megdnn_assert(N == 1 && Bstride == 4); | |||||
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW44_N"_hash)) { | |||||
return sgemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -24,6 +24,9 @@ void gemv_like(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | float* __restrict C, size_t M, size_t N, size_t K, | ||||
size_t Astride, size_t Bstride, size_t Cstride); | size_t Astride, size_t Bstride, size_t Cstride); | ||||
void gemv_like_mk4(const float* __restrict A, const float* __restrict B, | |||||
float* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride); | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -10,8 +10,8 @@ | |||||
*/ | */ | ||||
#include <cstddef> | #include <cstddef> | ||||
#include "src/arm_common/matrix_mul/int8/gemv.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
#include "src/arm_common/matrix_mul/int8/gemv.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
@@ -95,6 +95,80 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
C[m * Cstride] = acc0; | C[m * Cstride] = acc0; | ||||
} | } | ||||
} | } | ||||
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
constexpr size_t PACK_SIZE = 4; | |||||
megdnn_assert(N == 1 && Bstride == 4); | |||||
auto Aptr = A; | |||||
size_t m = 0; | |||||
for (; m < M; m += PACK_SIZE) { | |||||
auto Bptr = B; | |||||
auto Aptr0 = Aptr; | |||||
int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; | |||||
size_t k = 0; | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16x4_t a = vld4q_s8(Aptr0); | |||||
int8x16_t b = vld1q_s8(Bptr); | |||||
int16x8_t c[4]; | |||||
c[0] = vmull_s8(vget_low_s8(a.val[0]), vget_low_s8(b)); | |||||
c[1] = vmull_s8(vget_low_s8(a.val[1]), vget_low_s8(b)); | |||||
c[2] = vmull_s8(vget_low_s8(a.val[2]), vget_low_s8(b)); | |||||
c[3] = vmull_s8(vget_low_s8(a.val[3]), vget_low_s8(b)); | |||||
c[0] = vmlal_high_s8(c[0], a.val[0], b); | |||||
c[1] = vmlal_high_s8(c[1], a.val[1], b); | |||||
c[2] = vmlal_high_s8(c[2], a.val[2], b); | |||||
c[3] = vmlal_high_s8(c[3], a.val[3], b); | |||||
acc0 += vaddlvq_s16(c[0]); | |||||
acc1 += vaddlvq_s16(c[1]); | |||||
acc2 += vaddlvq_s16(c[2]); | |||||
acc3 += vaddlvq_s16(c[3]); | |||||
Bptr += 16; | |||||
Aptr0 += PACK_SIZE * 16; | |||||
} | |||||
for (; k + 8 <= K; k += 8) { | |||||
int8x8x4_t a = vld4_s8(Aptr0); | |||||
int8x8_t b = vld1_s8(Bptr); | |||||
int16x8_t c[4]; | |||||
c[0] = vmull_s8(a.val[0], b); | |||||
c[1] = vmull_s8(a.val[1], b); | |||||
c[2] = vmull_s8(a.val[2], b); | |||||
c[3] = vmull_s8(a.val[3], b); | |||||
acc0 += vaddlvq_s16(c[0]); | |||||
acc1 += vaddlvq_s16(c[1]); | |||||
acc2 += vaddlvq_s16(c[2]); | |||||
acc3 += vaddlvq_s16(c[3]); | |||||
Bptr += 8; | |||||
Aptr0 += PACK_SIZE * 8; | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k]; | |||||
acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k]; | |||||
acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k]; | |||||
acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k]; | |||||
Aptr0 += 4; | |||||
} | |||||
C[0] = acc0; | |||||
C[1] = acc1; | |||||
C[2] = acc2; | |||||
C[3] = acc3; | |||||
Aptr += Astride; | |||||
C += Cstride; | |||||
} | |||||
} | |||||
} // namespace | } // namespace | ||||
#endif | #endif | ||||
@@ -169,6 +243,139 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3]; | ||||
} | } | ||||
} | } | ||||
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride) { | |||||
constexpr size_t PACK_SIZE = 4; | |||||
megdnn_assert(N == 1 && Bstride == 4); | |||||
auto Aptr = A; | |||||
size_t m = 0; | |||||
for (; m < M; m += PACK_SIZE) { | |||||
auto Bptr = B; | |||||
auto Aptr0 = Aptr; | |||||
int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; | |||||
size_t k = 0; | |||||
if (k + 16 <= K) { | |||||
int32x4_t acc_neon[4]; | |||||
acc_neon[0] = vdupq_n_s32(0); | |||||
acc_neon[1] = vdupq_n_s32(0); | |||||
acc_neon[2] = vdupq_n_s32(0); | |||||
acc_neon[3] = vdupq_n_s32(0); | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16x4_t a = vld4q_s8(Aptr0); | |||||
int8x16_t b = vld1q_s8(Bptr); | |||||
acc_neon[0] = vdotq_s32(acc_neon[0], a.val[0], b); | |||||
acc_neon[1] = vdotq_s32(acc_neon[1], a.val[1], b); | |||||
acc_neon[2] = vdotq_s32(acc_neon[2], a.val[2], b); | |||||
acc_neon[3] = vdotq_s32(acc_neon[3], a.val[3], b); | |||||
Bptr += 16; | |||||
Aptr0 += PACK_SIZE * 16; | |||||
} | |||||
acc0 = vaddvq_s32(acc_neon[0]); | |||||
acc1 = vaddvq_s32(acc_neon[1]); | |||||
acc2 = vaddvq_s32(acc_neon[2]); | |||||
acc3 = vaddvq_s32(acc_neon[3]); | |||||
} | |||||
if (k + 8 <= K) { | |||||
int32x2_t acc_neon[4]; | |||||
acc_neon[0] = vdup_n_s32(0); | |||||
acc_neon[1] = vdup_n_s32(0); | |||||
acc_neon[2] = vdup_n_s32(0); | |||||
acc_neon[3] = vdup_n_s32(0); | |||||
int8x8x4_t a = vld4_s8(Aptr0); | |||||
int8x8_t b = vld1_s8(Bptr); | |||||
acc_neon[0] = vdot_s32(acc_neon[0], a.val[0], b); | |||||
acc_neon[1] = vdot_s32(acc_neon[1], a.val[1], b); | |||||
acc_neon[2] = vdot_s32(acc_neon[2], a.val[2], b); | |||||
acc_neon[3] = vdot_s32(acc_neon[3], a.val[3], b); | |||||
Bptr += 8; | |||||
Aptr0 += PACK_SIZE * 8; | |||||
k += 8; | |||||
acc0 += vaddv_s32(acc_neon[0]); | |||||
acc1 += vaddv_s32(acc_neon[1]); | |||||
acc2 += vaddv_s32(acc_neon[2]); | |||||
acc3 += vaddv_s32(acc_neon[3]); | |||||
} | |||||
for (; k < K; ++k) { | |||||
acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k]; | |||||
acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k]; | |||||
acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k]; | |||||
acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k]; | |||||
Aptr0 += 4; | |||||
} | |||||
C[0] = acc0; | |||||
C[1] = acc1; | |||||
C[2] = acc2; | |||||
C[3] = acc3; | |||||
Aptr += Astride; | |||||
C += Cstride; | |||||
} | |||||
} | |||||
void gemv_naive_n_mk4_dot(const int8_t* __restrict A, | |||||
const int8_t* __restrict B, int32_t* __restrict C, | |||||
size_t M, size_t N, size_t K, size_t Astride, | |||||
size_t Bstride, size_t Cstride) { | |||||
constexpr size_t PACK_SIZE = 4; | |||||
megdnn_assert(N == 1 && Bstride == 4); | |||||
auto Aptr = A; | |||||
size_t m = 0; | |||||
for (; m < M; m += PACK_SIZE) { | |||||
auto Bptr = B; | |||||
auto Aptr0 = Aptr; | |||||
size_t k = 0; | |||||
int32x4_t acc_neon; | |||||
acc_neon = vdupq_n_s32(0); | |||||
for (; k + 16 <= K; k += 16) { | |||||
int8x16_t a0 = vld1q_s8(Aptr0); | |||||
int8x16_t a1 = vld1q_s8(Aptr0 + 16); | |||||
int8x16_t a2 = vld1q_s8(Aptr0 + 32); | |||||
int8x16_t a3 = vld1q_s8(Aptr0 + 48); | |||||
int8x16_t b = vld1q_s8(Bptr); | |||||
acc_neon = vdotq_laneq_s32(acc_neon, a0, b, 0); | |||||
acc_neon = vdotq_laneq_s32(acc_neon, a1, b, 1); | |||||
acc_neon = vdotq_laneq_s32(acc_neon, a2, b, 2); | |||||
acc_neon = vdotq_laneq_s32(acc_neon, a3, b, 3); | |||||
Bptr += 16; | |||||
Aptr0 += PACK_SIZE * 16; | |||||
} | |||||
if (k + 8 <= K) { | |||||
int8x16_t a0 = vld1q_s8(Aptr0); | |||||
int8x16_t a1 = vld1q_s8(Aptr0 + 16); | |||||
int8x8_t b = vld1_s8(Bptr); | |||||
acc_neon = vdotq_lane_s32(acc_neon, a0, b, 0); | |||||
acc_neon = vdotq_lane_s32(acc_neon, a1, b, 1); | |||||
Bptr += 8; | |||||
Aptr0 += PACK_SIZE * 8; | |||||
k += 8; | |||||
} | |||||
if (k + 4 <= K) { | |||||
int8x16_t a = vld1q_s8(Aptr0); | |||||
int32_t tmp = *(reinterpret_cast<const int32_t*>(Bptr)); | |||||
int8x8_t b = vdup_n_s32(tmp); | |||||
acc_neon = vdotq_lane_s32(acc_neon, a, b, 0); | |||||
} | |||||
vst1q_s32(C, acc_neon); | |||||
Aptr += Astride; | |||||
C += Cstride; | |||||
} | |||||
} | |||||
} // namespace | } // namespace | ||||
#endif | #endif | ||||
@@ -201,4 +408,33 @@ void arm_common::gemv_like(const int8_t* __restrict A, | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
} | } | ||||
void arm_common::gemv_like_mk4(const int8_t* __restrict A, | |||||
const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, | |||||
size_t K, size_t Astride, size_t Bstride, | |||||
size_t Cstride) { | |||||
megdnn_assert(N == 1); | |||||
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | |||||
midout_iv("INT8_gemv_like_mk4"_hash)) { | |||||
return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
#if __ARM_FEATURE_DOTPROD | |||||
void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, | |||||
const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, | |||||
size_t K, size_t Astride, size_t Bstride, | |||||
size_t Cstride) { | |||||
megdnn_assert(N == 1); | |||||
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, | |||||
midout_iv("INT8_gemv_like_mk4_dot"_hash)) { | |||||
return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride, | |||||
Cstride); | |||||
} | |||||
MIDOUT_END(); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -24,6 +24,16 @@ void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | int32_t* __restrict C, size_t M, size_t N, size_t K, | ||||
size_t Astride, size_t Bstride, size_t Cstride); | size_t Astride, size_t Bstride, size_t Cstride); | ||||
void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride); | |||||
#if __ARM_FEATURE_DOTPROD | |||||
void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, | |||||
int32_t* __restrict C, size_t M, size_t N, size_t K, | |||||
size_t Astride, size_t Bstride, size_t Cstride); | |||||
#endif | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -28,14 +28,24 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoF16Gemv f16gemv; | AlgoF16Gemv f16gemv; | ||||
#endif | #endif | ||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | AlgoInt8x8x32Gemv int8x8x32_gemv; | ||||
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | |||||
#if __ARM_FEATURE_DOTPROD | |||||
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; | |||||
#endif | |||||
AlgoGevm gevm; | AlgoGevm gevm; | ||||
AlgoF32GemvMK4 f32_gemv_mk4; | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&int8x8x16); | all_algos.emplace_back(&int8x8x16); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
all_algos.emplace_back(&f16gemv); | all_algos.emplace_back(&f16gemv); | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | |||||
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); | |||||
#endif | |||||
all_algos.emplace_back(&int8x8x32_gemv); | all_algos.emplace_back(&int8x8x32_gemv); | ||||
all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||||
all_algos.emplace_back(&f32_gemv_mk4); | |||||
all_algos.emplace_back(&gevm); | all_algos.emplace_back(&gevm); | ||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
@@ -25,12 +25,17 @@ public: | |||||
protected: | protected: | ||||
static void* const sm_arm_common_algo_type; | static void* const sm_arm_common_algo_type; | ||||
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv | |||||
class AlgoF32Gemv; // Arm_common F32 Gemv | |||||
class AlgoGevm; // Arm_common Gemv(support int8 and fp32) | |||||
class AlgoF32Gemv; // Arm_common F32 Gemv | |||||
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | |||||
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | |||||
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 | |||||
class AlgoGevm; // Arm_common Gevm(support int8 and fp32) | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class AlgoF16Gemv; | class AlgoF16Gemv; | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | |||||
class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT | |||||
#endif | |||||
class AlgoInt8x8x16; // Arm_common Int 8x8x16 | class AlgoInt8x8x16; // Arm_common Int 8x8x16 | ||||
class AlgoPack; | class AlgoPack; | ||||
}; | }; | ||||
@@ -407,6 +407,16 @@ __ai int32_t vaddv_s32(int32x2_t a) { | |||||
return vget_lane_s32(a, 0) + vget_lane_s32(a, 1); | return vget_lane_s32(a, 0) + vget_lane_s32(a, 1); | ||||
} | } | ||||
__ai int32_t vaddvq_s32(int32x4_t a) { | |||||
return vgetq_lane_s32(a, 0) + vgetq_lane_s32(a, 1) + | |||||
vgetq_lane_s32(a, 2) + vgetq_lane_s32(a, 3); | |||||
} | |||||
__ai float32_t vaddvq_f32(float32x4_t a) { | |||||
return vgetq_lane_f32(a, 0) + vgetq_lane_f32(a, 1) + | |||||
vgetq_lane_f32(a, 2) + vgetq_lane_f32(a, 3); | |||||
} | |||||
#endif // MEGDNN_ARMV7 | #endif // MEGDNN_ARMV7 | ||||
//! pack vmovl_low_xx() on armv7 and armv8 | //! pack vmovl_low_xx() on armv7 and armv8 | ||||
@@ -42,14 +42,27 @@ using namespace conv1x1; | |||||
namespace { | namespace { | ||||
#if MEGDNN_X86 | |||||
template <typename stype, typename btype, param::ConvBias::Format F> | template <typename stype, typename btype, param::ConvBias::Format F> | ||||
struct GemvLike { | struct GemvLike { | ||||
inline static void do_gemv(const stype* A, const stype* B, btype* C, | inline static void do_gemv(const stype* A, const stype* B, btype* C, | ||||
size_t M, size_t N, size_t K, size_t LDA, | size_t M, size_t N, size_t K, size_t LDA, | ||||
size_t LDB, size_t LDC, DType src, | size_t LDB, size_t LDC, DType src, | ||||
DType filter) { | DType filter) { | ||||
megdnn_throw("x86 conv1x1 gemv only supports format : NCHW"); | |||||
MEGDNN_MARK_USED_VAR(A); | |||||
MEGDNN_MARK_USED_VAR(B); | |||||
MEGDNN_MARK_USED_VAR(C); | |||||
MEGDNN_MARK_USED_VAR(M); | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
MEGDNN_MARK_USED_VAR(K); | |||||
MEGDNN_MARK_USED_VAR(LDA); | |||||
MEGDNN_MARK_USED_VAR(LDB); | |||||
MEGDNN_MARK_USED_VAR(LDC); | |||||
MEGDNN_MARK_USED_VAR(src); | |||||
MEGDNN_MARK_USED_VAR(filter); | |||||
megdnn_assert(false, | |||||
"unspported conv1x1 gemv : \nsrc_type : " | |||||
"%s\nfilter_type : %s\n", | |||||
src.name(), filter.name()); | |||||
} | } | ||||
}; | }; | ||||
@@ -66,39 +79,29 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> { | |||||
} | } | ||||
}; | }; | ||||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
template <typename stype, typename btype, param::ConvBias::Format F> | |||||
struct GemvLike { | |||||
inline static void do_gemv(const stype* A, const stype* B, btype* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | |||||
megdnn_throw("arm conv1x1 gemv only supports format : NCHW"); | |||||
} | |||||
}; | |||||
template <typename stype, typename btype> | |||||
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const stype* A, const stype* B, btype* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
template <> | |||||
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B, | |||||
dt_int32* C, size_t M, size_t N, size_t K, | |||||
size_t LDA, size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | DType filter) { | ||||
MEGDNN_MARK_USED_VAR(src); | |||||
MEGDNN_MARK_USED_VAR(filter); | |||||
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC); | |||||
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point; | |||||
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point; | |||||
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA, | |||||
LDB, LDC, zp0, zp1); | |||||
} | } | ||||
}; | }; | ||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
template <> | template <> | ||||
struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
struct GemvLike<dt_float32, dt_float32, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const dt_float32* A, const dt_float32* B, | |||||
dt_float32* C, size_t M, size_t N, size_t K, | |||||
size_t LDA, size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | DType filter) { | ||||
MEGDNN_MARK_USED_VAR(src); | MEGDNN_MARK_USED_VAR(src); | ||||
MEGDNN_MARK_USED_VAR(filter); | MEGDNN_MARK_USED_VAR(filter); | ||||
megdnn::fallback::gemv_like<dt_int8, dt_int16>(A, B, C, M, N, K, LDA, | |||||
LDB, LDC); | |||||
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC); | |||||
} | } | ||||
}; | }; | ||||
@@ -118,21 +121,47 @@ struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> { | |||||
} | } | ||||
}; | }; | ||||
#endif | #endif | ||||
#endif | |||||
template <> | template <> | ||||
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B, | |||||
dt_int32* C, size_t M, size_t N, size_t K, | |||||
size_t LDA, size_t LDB, size_t LDC, DType src, | |||||
struct GemvLike<dt_int8, dt_int32, param::ConvBias::Format::NCHW> { | |||||
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int32* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | DType filter) { | ||||
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point; | |||||
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point; | |||||
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA, | |||||
LDB, LDC, zp0, zp1); | |||||
MEGDNN_MARK_USED_VAR(src); | |||||
MEGDNN_MARK_USED_VAR(filter); | |||||
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC); | |||||
} | } | ||||
}; | }; | ||||
template <typename stype, typename btype> | |||||
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> { | |||||
inline static void do_gemv(const stype* A, const stype* B, btype* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | |||||
MEGDNN_MARK_USED_VAR(src); | |||||
MEGDNN_MARK_USED_VAR(filter); | |||||
megdnn::arm_common::gemv_like_mk4(A, B, C, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
}; | |||||
#if __ARM_FEATURE_DOTPROD | |||||
template <typename stype, typename btype> | |||||
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> { | |||||
inline static void do_gemv(const stype* A, const stype* B, btype* C, | |||||
size_t M, size_t N, size_t K, size_t LDA, | |||||
size_t LDB, size_t LDC, DType src, | |||||
DType filter) { | |||||
MEGDNN_MARK_USED_VAR(src); | |||||
MEGDNN_MARK_USED_VAR(filter); | |||||
megdnn::arm_common::gemv_like_mk4_dot(A, B, C, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
}; | |||||
#endif | |||||
#endif | |||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
typename op_ctype, typename op_dtype, | typename op_ctype, typename op_dtype, | ||||
megdnn::PostprocessMode postprocess_mode, | megdnn::PostprocessMode postprocess_mode, | ||||
@@ -185,19 +214,18 @@ struct Conv1x1GemvWorker { | |||||
is_dst_8bit ? matmul_temp_dst | is_dst_8bit ? matmul_temp_dst | ||||
: reinterpret_cast<bias_ctype*>(conv_bias_dst); | : reinterpret_cast<bias_ctype*>(conv_bias_dst); | ||||
size_t pack_size = megdnn::fallback::pack_size(format); | |||||
GemvLike<src_ctype, bias_ctype, format>::do_gemv( | GemvLike<src_ctype, bias_ctype, format>::do_gemv( | ||||
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1, | |||||
ncb_param.filter_type, ncb_param.src_type); | |||||
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC * pack_size, | |||||
pack_size, pack_size, ncb_param.filter_type, | |||||
ncb_param.src_type); | |||||
//! do postprocess | //! do postprocess | ||||
void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
if (param.bias_mode == megdnn::BiasMode::BIAS) { | |||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS) { | |||||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | ||||
ncb_param.bias<bias_ctype>(batch_id, group_id) + | ncb_param.bias<bias_ctype>(batch_id, group_id) + | ||||
numbers_of_ncb_dst_offset)); | numbers_of_ncb_dst_offset)); | ||||
} else { | |||||
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>( | |||||
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start)); | |||||
} | } | ||||
PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | PostProcess<op_ctype, op_dtype, postprocess_mode>::run( | ||||
@@ -211,9 +239,13 @@ struct Conv1x1GemvWorker { | |||||
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic( | size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
size_t OC = param.filter_meta.ocpg; | |||||
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); | |||||
return round_up<size_t>(oc_block_size_one_thread, 16); | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, | |||||
midout_iv("AlgoConv1x1Gemv::get_oc_tile"_hash)) { | |||||
size_t OC = param.filter_meta.ocpg; | |||||
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads); | |||||
return round_up<size_t>(oc_block_size_one_thread, 16); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( | size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( | ||||
@@ -286,6 +318,11 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16, | cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16, | ||||
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash); | PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash); | ||||
#else | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16, | |||||
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | |||||
#endif | |||||
#endif | #endif | ||||
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | ||||
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | ||||
@@ -311,6 +348,37 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||||
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash); | "NCHW::GEMV::QUINT8x8x32_QUINT8"_hash); | ||||
break; | break; | ||||
case param::ConvBias::Format::NCHW44: | |||||
cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32, | |||||
PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash); | |||||
cb2(param::ConvBias::Format::NCHW44, dt_int8, dt_int32, dt_int32, | |||||
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NCHW44::GEMV::INT8x8x32_INT32"_hash); | |||||
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NCHW44::GEMV::QINT8x8x32_QINT32"_hash); | |||||
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||||
dt_int8, PostprocessMode::QUANTIZED, | |||||
"NCHW44::GEMV::QINT8x8x32_QINT8"_hash); | |||||
break; | |||||
case param::ConvBias::Format::NCHW44_DOT: | |||||
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||||
dt_int32, dt_int8, dt_int32, dt_int32, | |||||
PostprocessMode::NO_PROCESS, | |||||
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash); | |||||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32, | |||||
dt_int32, PostprocessMode::NO_PROCESS, | |||||
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash); | |||||
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8, | |||||
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32, | |||||
dt_int8, PostprocessMode::QUANTIZED, | |||||
"NCHW44_DOT::GEMV::QINT8x8x32_QINT8"_hash); | |||||
break; | |||||
default: | default: | ||||
megdnn_throw("Invalid Format"); | megdnn_throw("Invalid Format"); | ||||
break; | break; | ||||
@@ -338,6 +406,16 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr, | |||||
AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, | MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, | ||||
midout_iv("AlgoConv1x1Gemv::usable"_hash)) { | midout_iv("AlgoConv1x1Gemv::usable"_hash)) { | ||||
#if MEGDNN_X86 | |||||
if (opr->param().format != param::ConvBias::Format::NCHW) | |||||
return false; | |||||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
if (opr->param().format != param::ConvBias::Format::NCHW && | |||||
opr->param().format != param::ConvBias::Format::NCHW44 && | |||||
opr->param().format != param::ConvBias::Format::NCHW44_DOT) | |||||
return false; | |||||
#endif | |||||
//! whether 1x1 | //! whether 1x1 | ||||
size_t FH = param.filter_meta.spatial[0], | size_t FH = param.filter_meta.spatial[0], | ||||
FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
@@ -390,59 +468,43 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr, | |||||
param.src_type.enumv() != DTypeEnum::Float32) { | param.src_type.enumv() != DTypeEnum::Float32) { | ||||
return false; | return false; | ||||
} | } | ||||
bool is_param_ok = | |||||
(param.filter_meta.dilation[0] == | |||||
param.filter_meta.dilation[1] && | |||||
param.filter_meta.dilation[0] == 1) && | |||||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT; | |||||
bool is_format_and_dtype_ok = false; | |||||
#if MEGDNN_X86 | |||||
if (opr->param().format == param::ConvBias::Format::NCHW) { | |||||
//! x86 supports all dtypes in NCHW | |||||
is_format_and_dtype_ok = true; | |||||
} | |||||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
//! add NCHW44 and NCHW44_DOT support in the future | |||||
if (opr->param().format == param::ConvBias::Format::NCHW) { | |||||
//! NCHW format supports all dtype | |||||
is_format_and_dtype_ok = true; | |||||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44) { | |||||
if (param.src_type.enumv() != DTypeEnum::Float32 && | |||||
param.src_type.enumv() != DTypeEnum::Int8 && | |||||
param.src_type.enumv() != DTypeEnum::QuantizedS8) { | |||||
return false; | |||||
} | |||||
} else if (opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||||
if (param.src_type.enumv() != DTypeEnum::Int8 && | |||||
param.src_type.enumv() != DTypeEnum::QuantizedS8) { | |||||
return false; | |||||
} | |||||
} | } | ||||
#endif | #endif | ||||
return is_param_ok && is_format_and_dtype_ok; | |||||
return (param.filter_meta.dilation[0] == | |||||
param.filter_meta.dilation[1] && | |||||
param.filter_meta.dilation[0] == 1) && | |||||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT; | |||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | return false; | ||||
} | } | ||||
bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred( | bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred( | ||||
ConvBiasImpl*, const NCBKernSizeParam& param) const { | |||||
size_t OC = param.filter_meta.ocpg; | |||||
if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32) | |||||
return true; | |||||
ConvBiasImpl* opr, const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, | |||||
midout_iv("AlgoConv1x1Gemv::is_preferred"_hash)) { | |||||
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | #if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | ||||
//! maybe add support for QuantizedAsym in the future | |||||
return (param.src_type.enumv() == DTypeEnum::Int8 && | |||||
param.filter_type.enumv() == DTypeEnum::Int8 && | |||||
param.dst_type.enumv() == DTypeEnum::Int32) || | |||||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || | |||||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) || | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
(param.src_type.enumv() == DTypeEnum::Float16 && | |||||
param.filter_type.enumv() == DTypeEnum::Float16 && | |||||
param.dst_type.enumv() == DTypeEnum::Float16) || | |||||
if (opr->param().format == param::ConvBias::Format::NCHW && | |||||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
return false; | |||||
} | |||||
#endif | #endif | ||||
(param.src_type.enumv() == DTypeEnum::Float32 && | |||||
param.filter_type.enumv() == DTypeEnum::Float32 && | |||||
param.dst_type.enumv() == DTypeEnum::Float32); | |||||
#else | |||||
return true; | |||||
} | |||||
MIDOUT_END(); | |||||
return false; | return false; | ||||
#endif | |||||
} | } | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -2036,7 +2036,6 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle, | |||||
RUNS; | RUNS; | ||||
auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS; | auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS; | ||||
printf("\n%s: ", matmul_algo_name); | |||||
printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops " | printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops " | ||||
"speedup: " | "speedup: " | ||||
"%f\n", | "%f\n", | ||||
@@ -2120,6 +2119,82 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) { | |||||
#endif | #endif | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args; | |||||
param::ConvBias conv_param; | |||||
conv_param.stride_h = 1; | |||||
conv_param.stride_w = 1; | |||||
conv_param.pad_h = 0; | |||||
conv_param.pad_w = 0; | |||||
conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | |||||
auto run = [&](size_t M, size_t K){ | |||||
args.emplace_back(conv_param, TensorShape{1, K, 1, 1}, | |||||
TensorShape{M, K, 1, 1}, TensorShape{}); | |||||
}; | |||||
for (size_t M : {4, 64, 1024, 4096}) | |||||
for (size_t K : {128, 256, 1024, 4096}) | |||||
run(M, K); | |||||
constexpr size_t RUNS = 50; | |||||
param::MatrixMul param; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
Benchmarker<MatrixMul> benchmark_matmul(handle()); | |||||
benchmark_matmul.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV")); | |||||
benchmark_matmul.set_times(RUNS) | |||||
.set_dtype(0, dtype::Float32{}) | |||||
.set_dtype(1, dtype::Float32{}) | |||||
.set_dtype(2, dtype::Float32{}) | |||||
.set_param(param) | |||||
.set_display(false); | |||||
Benchmarker<ConvBias> benchmark_conv1x1(handle()); | |||||
benchmark_conv1x1.set_before_exec_callback( | |||||
conv_bias::ConvBiasAlgoChecker<ConvBias>("CONV1x1_GEMV")); | |||||
benchmark_conv1x1.set_times(RUNS) | |||||
.set_dtype(0, dtype::Float32{}) | |||||
.set_dtype(1, dtype::Float32{}) | |||||
.set_dtype(2, dtype::Float32{}) | |||||
.set_dtype(4, dtype::Float32{}) | |||||
.set_display(false); | |||||
std::cout << "warm up:\n"; | |||||
for (int i = 0; i < 50; i++) { | |||||
benchmark_matmul.exec({{1, 1024}, {1024, 512}, {}}); | |||||
benchmark_matmul.set_display(true); | |||||
} | |||||
for (auto&& arg : args) { | |||||
size_t IC = arg.src[1]; | |||||
size_t OH = arg.src[2]; | |||||
size_t OW = arg.src[3]; | |||||
size_t OC = arg.filter[0]; | |||||
size_t M = OC; | |||||
size_t K = IC; | |||||
size_t N = OH * OW; | |||||
float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3; | |||||
TensorShape A, B; | |||||
A = TensorShape{M, K}; | |||||
B = TensorShape{K, N}; | |||||
auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec( | |||||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||||
RUNS; | |||||
auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS; | |||||
printf("%s %s:\n gemv: %f ms %f Gflops\nconv1x1: %f ms %f GFlops " | |||||
"speedup: " | |||||
"%f\n", | |||||
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), | |||||
matmul_used, computations / matmul_used, conv1x1_used, | |||||
computations / conv1x1_used, matmul_used / conv1x1_used); | |||||
} | |||||
} | |||||
#ifndef __ARM_FEATURE_DOTPROD | #ifndef __ARM_FEATURE_DOTPROD | ||||
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { | TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) { | ||||
std::vector<TestArg> conv_bias_1x1_args_nchw44 = | std::vector<TestArg> conv_bias_1x1_args_nchw44 = | ||||
@@ -180,12 +180,15 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( | |||||
for (size_t kernel : kernel_vec) | for (size_t kernel : kernel_vec) | ||||
for (size_t oc : {4, 12}) | for (size_t oc : {4, 12}) | ||||
for (size_t ic : {1, 3, 4, 12}) | for (size_t ic : {1, 3, 4, 12}) | ||||
for (size_t h : {3, 5, 12}) | |||||
for (size_t w : {7, 16, 23}) { | |||||
for (size_t h : {1, 3, 12}) | |||||
for (size_t w : {1, 16, 23}) { | |||||
for (size_t group = 1; | for (size_t group = 1; | ||||
group <= | group <= | ||||
std::min(std::min(oc, ic), 4_z); | std::min(std::min(oc, ic), 4_z); | ||||
++group) { | ++group) { | ||||
if (kernel != 1 && (h == 1 || w == 1)) { | |||||
continue; | |||||
} | |||||
pack(n, oc, ic, h, w, kernel, stride, | pack(n, oc, ic, h, w, kernel, stride, | ||||
group, nlmode, bias); | group, nlmode, bias); | ||||
} | } | ||||
@@ -1897,6 +1900,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | |||||
#elif MEGDNN_ARMV7 | #elif MEGDNN_ARMV7 | ||||
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24"); | check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24"); | ||||
#endif | #endif | ||||
std::vector<conv_bias::TestArg> gemv_args; | |||||
for (auto&& arg : args) | |||||
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
gemv_args.emplace_back(arg); | |||||
} | |||||
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | |||||
} | } | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) { | ||||
@@ -1932,7 +1941,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | |||||
#endif | #endif | ||||
std::vector<conv_bias::TestArg> gemv_args; | std::vector<conv_bias::TestArg> gemv_args; | ||||
for (auto&& arg : args) | for (auto&& arg : args) | ||||
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
gemv_args.emplace_back(arg); | gemv_args.emplace_back(arg); | ||||
} | } | ||||
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | ||||
@@ -2138,4 +2147,40 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) { | |||||
} | } | ||||
#endif | #endif | ||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false); | |||||
UniformIntRNG rng{-50, 50}; | |||||
float epsilon = 0.001; | |||||
std::vector<conv_bias::TestArg> gemv_args; | |||||
for (auto&& arg : args) | |||||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
gemv_args.emplace_back(arg); | |||||
} | |||||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), | |||||
"CONV1x1_GEMV"); | |||||
} | |||||
#ifdef __ARM_FEATURE_DOTPROD | |||||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) { | |||||
using namespace conv_bias; | |||||
std::vector<conv_bias::TestArg> args = | |||||
get_nchw44_conv_bias_args({1}, 1, true, false, false, false, true); | |||||
UniformIntRNG rng{-50, 50}; | |||||
float epsilon = 0.001; | |||||
std::vector<conv_bias::TestArg> gemv_args; | |||||
for (auto&& arg : args) | |||||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||||
gemv_args.emplace_back(arg); | |||||
} | |||||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, | |||||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||||
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), | |||||
"CONV1x1_GEMV"); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -156,7 +156,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | .set_dtype(2, dtype::QuantizedS32(6.25f)) | ||||
.execs({A, B, {}}); | .execs({A, B, {}}); | ||||
}; | }; | ||||
// N = 1 | // N = 1 | ||||
for (size_t M : {1, 10, 16, 33, 64}) | for (size_t M : {1, 10, 16, 33, 64}) | ||||
for (size_t K : {7, 512, 1024}) | for (size_t K : {7, 512, 1024}) | ||||
@@ -164,6 +164,70 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) { | |||||
run(M, K, N); | run(M, K, N); | ||||
} | } | ||||
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4) { | |||||
Checker<MatrixMul> checker(handle()); | |||||
using Param = MatrixMul::Param; | |||||
checker.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4")); | |||||
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | |||||
checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||||
auto run = [&](size_t M, size_t K, size_t N) { | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
TensorShape A, B; | |||||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
B = TensorShape{K / 4, 1, 4}; | |||||
checker.set_param(param) | |||||
.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
.execs({A, B, {}}); | |||||
}; | |||||
// N = 1 | |||||
for (size_t M : {4, 16, 128, 1024}) | |||||
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) | |||||
run(M, K, 1); | |||||
} | |||||
#if __ARM_FEATURE_DOTPROD | |||||
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) { | |||||
Checker<MatrixMul> checker(handle()); | |||||
using Param = MatrixMul::Param; | |||||
checker.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4_DOT")); | |||||
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127); | |||||
checker.set_rng(0, rng.get()).set_rng(1, rng.get()); | |||||
auto run = [&](size_t M, size_t K, size_t N) { | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4_DOT; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
TensorShape A, B; | |||||
A = TensorShape{M / 4, K / 4, 4, 4}; | |||||
B = TensorShape{K / 4, 1, 4}; | |||||
checker.set_param(param) | |||||
.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||||
.execs({A, B, {}}); | |||||
}; | |||||
// N = 1 | |||||
for (size_t M : {4, 16, 128, 1024}) | |||||
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024}) | |||||
run(M, K, 1); | |||||
} | |||||
#endif | |||||
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | TEST_F(ARM_COMMON, QINT8x8x32_GEVM) { | ||||
Checker<MatrixMul> checker(handle()); | Checker<MatrixMul> checker(handle()); | ||||
using Param = MatrixMul::Param; | using Param = MatrixMul::Param; | ||||
@@ -220,6 +284,31 @@ TEST_F(ARM_COMMON, FP32_GEVM) { | |||||
run(M, K, N); | run(M, K, N); | ||||
} | } | ||||
TEST_F(ARM_COMMON, FP32_GEMV_MK4) { | |||||
Checker<MatrixMul> checker(handle()); | |||||
using Param = MatrixMul::Param; | |||||
checker.set_before_exec_callback( | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4")); | |||||
checker.set_epsilon(1e-2); | |||||
auto run = [&](size_t M, size_t K) { | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
TensorShape A, B; | |||||
A = TensorShape{M/4, K/4, 4, 4}; | |||||
B = TensorShape{K/4, 1, 4}; | |||||
checker.set_param(param).execs({A, B, {}}); | |||||
}; | |||||
// N = 1 | |||||
for (size_t M : {4, 16, 128, 1024}) | |||||
for (size_t K : {4, 8, 12, 128, 256, 4096}) | |||||
run(M, K); | |||||
} | |||||
#if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | ||||
@@ -228,18 +317,16 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||||
benchmarker.set_times(exec_times); | benchmarker.set_times(exec_times); | ||||
auto run = [&](size_t M, size_t K, size_t N) { | auto run = [&](size_t M, size_t K, size_t N) { | ||||
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" | |||||
<< std::endl; | |||||
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N); | |||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()); | .set_dtype(1, dtype::Float32()); | ||||
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | ||||
auto computations = 2.f * M * K * N * 1e-6; | auto computations = 2.f * M * K * N * 1e-6; | ||||
auto perf = computations / time; | auto perf = computations / time; | ||||
std::cout << "gemv fp32, Performance is " << perf << " Gflops" | |||||
<< std::endl; | |||||
printf("gemv fp32, Performance is %f Gflops\n", perf); | |||||
}; | }; | ||||
std::cout << "warm up:\n"; | |||||
printf("warm up:\n"); | |||||
for (int i = 0; i < 50; i++) { | for (int i = 0; i < 50; i++) { | ||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()) | .set_dtype(1, dtype::Float32()) | ||||
@@ -253,6 +340,10 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||||
for (size_t K : {1024, 1536, 2048}) | for (size_t K : {1024, 1536, 2048}) | ||||
for (size_t N : {512, 1024}) | for (size_t N : {512, 1024}) | ||||
run(M, K, N); | run(M, K, N); | ||||
for (size_t M : {4, 64, 1024, 4096}) | |||||
for (size_t K : {128, 256, 1024, 4096}) | |||||
run(M, K, 1); | |||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | ||||
@@ -263,28 +354,25 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV")); | AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV")); | ||||
auto run = [&](size_t M, size_t K, size_t N) { | auto run = [&](size_t M, size_t K, size_t N) { | ||||
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" | |||||
<< std::endl; | |||||
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N); | |||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()) | .set_dtype(1, dtype::Float32()) | ||||
.set_dtype(2, dtype::Float32()); | .set_dtype(2, dtype::Float32()); | ||||
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | ||||
auto computations = 2 * M * K * N * 1e-6; | auto computations = 2 * M * K * N * 1e-6; | ||||
auto perf = computations / time; | auto perf = computations / time; | ||||
std::cout << "gemv fp32, Performance is " << perf << " Gflops" | |||||
<< std::endl; | |||||
printf("gemv fp32, Performance is %f Gflops\n", perf); | |||||
}; | }; | ||||
std::cout << "warm up:\n"; | |||||
printf("warm up:\n"); | |||||
for (int i = 0; i < 50; i++) { | for (int i = 0; i < 50; i++) { | ||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()) | .set_dtype(1, dtype::Float32()) | ||||
.set_dtype(2, dtype::Float32()) | |||||
.set_display(false) | .set_display(false) | ||||
.exec({{2, 1024}, {1024, 512}, {}}); | .exec({{2, 1024}, {1024, 512}, {}}); | ||||
benchmarker.set_display(true); | benchmarker.set_display(true); | ||||
} | } | ||||
// run gemv | // run gemv | ||||
run(12, 48, 1); | run(12, 48, 1); | ||||
run(48, 12, 1); | run(48, 12, 1); | ||||
@@ -298,6 +386,45 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | |||||
run(1024, 256, 1); | run(1024, 256, 1); | ||||
} | } | ||||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { | |||||
int exec_times = 10; | |||||
using Param = MatrixMul::Param; | |||||
Param param; | |||||
param.format = param::MatrixMul::Format::MK4; | |||||
param.transposeA = false; | |||||
param.transposeB = false; | |||||
Benchmarker<MatrixMul> benchmarker(handle()); | |||||
benchmarker.set_times(exec_times); | |||||
benchmarker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_param(param); | |||||
auto run = [&](size_t M, size_t K) { | |||||
printf("SGEMV_MK4: (%zu, %zu, %zu)\n", M, K, N); | |||||
TensorShape A, B; | |||||
A = TensorShape{M/4, K/4, 4, 4}; | |||||
B = TensorShape{K/4, 1, 4}; | |||||
auto time = benchmarker.exec({A, B, {}}) / exec_times; | |||||
auto computations = 2.f * M * K * 1e-6; | |||||
auto perf = computations / time; | |||||
printf("gemv mk4 fp32, Performance is %f Gflops\n", perf); | |||||
}; | |||||
printf("warm up:\n"); | |||||
for (int i = 0; i < 50; i++) { | |||||
benchmarker.set_dtype(0, dtype::Float32()) | |||||
.set_dtype(1, dtype::Float32()) | |||||
.set_dtype(2, dtype::Float32()) | |||||
.set_display(false) | |||||
.exec({{4, 256, 4, 4}, {256, 1, 4}, {}}); | |||||
} | |||||
// run gemv mk4 | |||||
for (size_t M : {4, 64, 1024, 4096}) | |||||
for (size_t K : {128, 1024, 4096}) | |||||
run(M, K); | |||||
} | |||||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | ||||
int exec_times = 50; | int exec_times = 50; | ||||
Benchmarker<MatrixMul> benchmarker(handle()); | Benchmarker<MatrixMul> benchmarker(handle()); | ||||
@@ -306,19 +433,17 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | |||||
AlgoChecker<MatrixMul>("ARM_COMMON_F16_GEMV")); | AlgoChecker<MatrixMul>("ARM_COMMON_F16_GEMV")); | ||||
auto run = [&](size_t M, size_t K, size_t N) { | auto run = [&](size_t M, size_t K, size_t N) { | ||||
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" | |||||
<< std::endl; | |||||
printf("SGEMV_FP16: (%zu, %zu, %zu)\n", M, K, N); | |||||
benchmarker.set_dtype(0, dtype::Float16()) | benchmarker.set_dtype(0, dtype::Float16()) | ||||
.set_dtype(1, dtype::Float16()) | .set_dtype(1, dtype::Float16()) | ||||
.set_dtype(2, dtype::Float16()); | .set_dtype(2, dtype::Float16()); | ||||
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | ||||
auto computations = 2 * M * K * N * 1e-6; | auto computations = 2 * M * K * N * 1e-6; | ||||
auto perf = computations / time; | auto perf = computations / time; | ||||
std::cout << "gemv fp16, Performance is " << perf << " Gflops" | |||||
<< std::endl; | |||||
printf("gemv fp16, Performance is %f Gflops\n", perf); | |||||
}; | }; | ||||
std::cout << "warm up:\n"; | |||||
printf("warm up:\n"); | |||||
for (int i = 0; i < 50; i++) { | for (int i = 0; i < 50; i++) { | ||||
benchmarker.set_dtype(0, dtype::Float16()) | benchmarker.set_dtype(0, dtype::Float16()) | ||||
.set_dtype(1, dtype::Float16()) | .set_dtype(1, dtype::Float16()) | ||||
@@ -343,17 +468,15 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { | |||||
float mod = 1000 * exec_times / 1e9; | float mod = 1000 * exec_times / 1e9; | ||||
auto run = [&](size_t M, size_t K, size_t N) { | auto run = [&](size_t M, size_t K, size_t N) { | ||||
float time = 1.f, perf = 1.f; | float time = 1.f, perf = 1.f; | ||||
std::cout << "SGEMM: (" << M << ", " << K << ", " << N << ")" | |||||
<< std::endl; | |||||
printf("SGEMM: (%zu, %zu, %zu)\n", M, K, N); | |||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()); | .set_dtype(1, dtype::Float32()); | ||||
time = benchmarker.exec({{M, K}, {K, N}, {}}); | time = benchmarker.exec({{M, K}, {K, N}, {}}); | ||||
perf = 2.f * M * K * N / time * mod; | perf = 2.f * M * K * N / time * mod; | ||||
std::cout << "gemm fp32, Performance is " << perf << " Gflops" | |||||
<< std::endl; | |||||
printf("gemm, Performance is %f Gflops\n", perf); | |||||
}; | }; | ||||
std::cout << "warm up:\n"; | |||||
printf("warm up:\n"); | |||||
for (int i = 0; i < 50; i++) { | for (int i = 0; i < 50; i++) { | ||||
benchmarker.set_dtype(0, dtype::Float32()) | benchmarker.set_dtype(0, dtype::Float32()) | ||||
.set_dtype(1, dtype::Float32()) | .set_dtype(1, dtype::Float32()) | ||||