GitOrigin-RevId: f8b6d7a1b7
release-0.6
@@ -210,27 +210,33 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
DEFAULT \
}
#define FOR_BIAS(_bias_mode, OH, OW) \
switch (_bias_mode) { \
case megdnn::BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \
break; \
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, \
"Only support nchw44 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
} \
break; \
default: \
if (OH * OW == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
break; \
} \
megdnn_throw("quantized unsupported biasmode"); \
break; \
#define FOR_BIAS(_bias_mode, OH, OW) \
switch (_bias_mode) { \
case megdnn::BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \
break; \
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, \
"Only support nchw44 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
} \
break; \
default: \
if (OH * OW == 1) { \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
megdnn_assert(pack_oc_size == 4, \
"Only support nchw44 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
} \
break; \
} \
megdnn_throw("quantized unsupported biasmode"); \
break; \
}
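// The rewritten default branch above makes the single-pixel (OH * OW == 1)
// bias path layout-aware: instead of always using the plain broadcast kernel,
// it now mirrors BROADCAST_CHANNEL_BIAS and selects the _NCHW44 broadcast
// variant when pack_oc_size == 4, so NCHW44 conv_bias no longer falls through
// to the "quantized unsupported biasmode" throw for 1x1 spatial outputs.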
template <typename opctype, typename opdtype>
@@ -101,6 +101,91 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
return int8x8x32_gemv_kern;
}
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */
namespace {
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable(
const KernSizeParam& kern_size_param) const {
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;
bool is_dtype_ok =
kern_size_param.A_type == kern_size_param.B_type &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
const KernSizeParam&) const {
return int8x8x32_gemv_mk4_kern;
}
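// A minimal scalar sketch (not part of the diff) of the product the MK4 GEMV
// computes; ref_gemv_mk4 is a hypothetical helper for illustration only.
// A is packed as (M/4, K/4, 4, 4) with each 4x4 tile stored depth-major
// (tile index = k_inner * 4 + m_inner, matching the vld4q_s8 de-interleave in
// the kernel), and B as (K/4, 1, 4), i.e. a plain contiguous K-vector.
static inline void ref_gemv_mk4(const int8_t* A, const int8_t* B, int32_t* C,
                                size_t M, size_t K, size_t Astride,
                                size_t Cstride) {
    for (size_t m = 0; m < M; m += 4) {
        const int8_t* blk = A + (m / 4) * Astride;  // one 4-row block of A
        int32_t acc[4] = {0, 0, 0, 0};
        for (size_t k = 0; k < K; ++k) {
            for (size_t r = 0; r < 4; ++r) {
                // tile (k / 4) starts every 16 bytes; element (r, k % 4)
                // sits at (k % 4) * 4 + r inside the tile
                acc[r] += int32_t(blk[(k / 4) * 16 + (k % 4) * 4 + r]) *
                          int32_t(B[k]);
            }
        }
        for (size_t r = 0; r < 4; ++r) {
            C[r] = acc[r];
        }
        C += Cstride;  // Cstride == 4 for a contiguous (M/4, 1, 4) output
    }
}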
#if __ARM_FEATURE_DOTPROD
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
const KernSizeParam& kern_size_param) const {
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;
bool is_dtype_ok =
kern_size_param.A_type == kern_size_param.B_type &&
(kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
(kern_size_param.C_type.enumv() == DTypeEnum::Int32 ||
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32);
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4_DOT &&
is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB &&
M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4;
}
bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
const KernSizeParam&) const {
return int8x8x32_gemv_mk4_dot_kern;
}
#endif
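// The two algos above share the MK4 shape gates (N == 1, LDB == 4, M and K
// multiples of 4); the _DOT variant additionally requires Format::MK4_DOT
// and is compiled only under __ARM_FEATURE_DOTPROD.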
/* ===================== F32 Gemv algo ===================== */
namespace {
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -137,6 +222,46 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
return f32_gemv_kern;
}
/* ================== F32 Gemv MK4 algo ================== */
namespace {
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
bool MatrixMulImpl::AlgoF32GemvMK4::usable(
const KernSizeParam& kern_size_param) const {
// enumerate the M, N, K, only usable when preferred
auto M = kern_size_param.M;
auto N = kern_size_param.N;
auto K = kern_size_param.K;
auto LDB = kern_size_param.LDB;
return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::MK4 &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA &&
!kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 &&
LDB == 4;
}
bool MatrixMulImpl::AlgoF32GemvMK4::preferred(
const KernSizeParam& kern_size_param) const {
MEGDNN_MARK_USED_VAR(kern_size_param);
return true;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
const KernSizeParam&) const {
return f32_gemv_mk4_kern;
}
/* ===================== F32 Gevm algo ===================== */
namespace {
template <typename stype, typename dtype>
@@ -43,6 +43,36 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#if __ARM_FEATURE_DOTPROD
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
};
#endif
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
protected:
~AlgoF32Gemv() = default;
@@ -60,6 +90,20 @@ public:
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
};
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
public:
@@ -87,10 +131,9 @@ public:
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4)
};
} // namespace arm_common
} // namespace megdnn
@@ -13,11 +13,11 @@
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include <cstddef>
#include "include/megdnn/oprs.h"
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
#include "src/common/utils.h"
#include "midout.h"
MIDOUT_DECL(megdnn_fp32_sgemv)
using namespace megdnn;
@@ -68,18 +68,10 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
#if !defined(__aarch64__)
#undef vaddvq_f32
#endif
} // namespace
namespace megdnn {
namespace arm_common {
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
void sgemv_naive_m(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
size_t m = 0;
for (; m + 4 <= M; m += 4) {
size_t k = 0;
@@ -762,6 +754,85 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
}
}
}
void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 &&
K % PACK_SIZE == 0);
auto Aptr = A;
auto Cptr = C;
size_t m = 0;
while (m < M) {
auto Aptr0 = Aptr;
auto Cptr0 = Cptr;
float32x4_t c[4];
#define INIT(step) c[step] = vdupq_n_f32(0.0f);
UNROLL_CALL_RAW(4, INIT)
#undef INIT
auto Bptr = B;
size_t k = 0;
while (k < K) {
float32x4_t b = vld1q_f32(Bptr);
float32x4x2_t a[2];
#define LOAD_A(step) a[step] = vld1q_f32_x2(Aptr0 + step * 8);
UNROLL_CALL_RAW(2, LOAD_A)
#undef LOAD_A
#define COMPT(step) \
c[step] = vfmaq_laneq_f32(c[step], a[step / 2].val[step % 2], b, step % 4);
UNROLL_CALL_RAW(4, COMPT)
#undef COMPT
Bptr += Bstride;
Aptr0 += PACK_SIZE * PACK_SIZE;
k += PACK_SIZE;
}
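// c[j] holds the partial sums for depths with k % 4 == j; fold them pairwise
// (c[0] += c[2], c[1] += c[3], then c[0] += c[1]) before storing the block.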
#define ADD_C(step, stride) c[step] = vaddq_f32(c[step], c[step + stride]);
UNROLL_CALL_RAW(2, ADD_C, 2)
UNROLL_CALL_RAW(1, ADD_C, 1)
#undef ADD_C
vst1q_f32(Cptr0, c[0]);
Aptr += Astride;
Cptr += Cstride;
m += PACK_SIZE;
}
}
} // namespace
namespace megdnn {
namespace arm_common {
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_N"_hash)) {
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
} else {
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_M"_hash)) {
return sgemv_naive_m(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
}
}
void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1 && Bstride == 4);
MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW44_N"_hash)) {
return sgemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
}
} // namespace arm_common
} // namespace megdnn
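// Hedged usage sketch (not part of the diff): driving the new fp32 MK4 GEMV
// directly, assuming the exec_sgemv.h declaration above is in scope. Shapes
// follow AlgoF32GemvMK4::usable(): N == 1, Bstride == 4, M and K multiples of
// 4; for a contiguous MK4 matrix, Astride is 4 * K (elements per 4-row block)
// and Cstride is 4 (one float32x4 block per 4-row group).
#include <cstddef>
#include <vector>
void example_f32_gemv_mk4() {
    const size_t M = 8, K = 8;
    std::vector<float> A(M * K, 1.f);  // packed as (M/4, K/4, 4, 4)
    std::vector<float> B(K, 1.f);      // packed as (K/4, 1, 4)
    std::vector<float> C(M, 0.f);      // packed as (M/4, 1, 4)
    megdnn::arm_common::gemv_like_mk4(A.data(), B.data(), C.data(), M, 1, K,
                                      4 * K, 4, 4);  // each C lane becomes 8
}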
@@ -24,6 +24,9 @@ void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like_mk4(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
@@ -10,8 +10,8 @@
*/
#include <cstddef>
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/common/utils.h"
#include "megdnn/oprs.h"
@@ -95,6 +95,80 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C[m * Cstride] = acc0;
}
}
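// The MK4 kernel below de-interleaves each 4x4 tile with vld4q_s8 so that
// a.val[r] holds row r of the block for 16 consecutive depths; products are
// widened to int16 (at most two per lane, via vmull_s8 + vmlal_high_s8)
// and then reduced to int32 with vaddlvq_s16 every 16 depths, with 8-wide
// and scalar tail loops covering the remaining depths.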
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(N == 1 && Bstride == 4);
auto Aptr = A;
size_t m = 0;
for (; m < M; m += PACK_SIZE) {
auto Bptr = B;
auto Aptr0 = Aptr;
int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
size_t k = 0;
for (; k + 16 <= K; k += 16) {
int8x16x4_t a = vld4q_s8(Aptr0);
int8x16_t b = vld1q_s8(Bptr);
int16x8_t c[4];
c[0] = vmull_s8(vget_low_s8(a.val[0]), vget_low_s8(b));
c[1] = vmull_s8(vget_low_s8(a.val[1]), vget_low_s8(b));
c[2] = vmull_s8(vget_low_s8(a.val[2]), vget_low_s8(b));
c[3] = vmull_s8(vget_low_s8(a.val[3]), vget_low_s8(b));
c[0] = vmlal_high_s8(c[0], a.val[0], b);
c[1] = vmlal_high_s8(c[1], a.val[1], b);
c[2] = vmlal_high_s8(c[2], a.val[2], b);
c[3] = vmlal_high_s8(c[3], a.val[3], b);
acc0 += vaddlvq_s16(c[0]);
acc1 += vaddlvq_s16(c[1]);
acc2 += vaddlvq_s16(c[2]);
acc3 += vaddlvq_s16(c[3]);
Bptr += 16;
Aptr0 += PACK_SIZE * 16;
}
for (; k + 8 <= K; k += 8) {
int8x8x4_t a = vld4_s8(Aptr0);
int8x8_t b = vld1_s8(Bptr);
int16x8_t c[4];
c[0] = vmull_s8(a.val[0], b);
c[1] = vmull_s8(a.val[1], b);
c[2] = vmull_s8(a.val[2], b);
c[3] = vmull_s8(a.val[3], b);
acc0 += vaddlvq_s16(c[0]);
acc1 += vaddlvq_s16(c[1]);
acc2 += vaddlvq_s16(c[2]);
acc3 += vaddlvq_s16(c[3]);
Bptr += 8;
Aptr0 += PACK_SIZE * 8;
}
for (; k < K; ++k) {
acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
Aptr0 += 4;
}
C[0] = acc0;
C[1] = acc1;
C[2] = acc2;
C[3] = acc3;
Aptr += Astride;
C += Cstride;
}
}
} // namespace
#endif
@@ -169,6 +243,139 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3];
}
}
void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(N == 1 && Bstride == 4);
auto Aptr = A;
size_t m = 0;
for (; m < M; m += PACK_SIZE) {
auto Bptr = B;
auto Aptr0 = Aptr;
int32_t acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0;
size_t k = 0;
if (k + 16 <= K) {
int32x4_t acc_neon[4];
acc_neon[0] = vdupq_n_s32(0);
acc_neon[1] = vdupq_n_s32(0);
acc_neon[2] = vdupq_n_s32(0);
acc_neon[3] = vdupq_n_s32(0);
for (; k + 16 <= K; k += 16) {
int8x16x4_t a = vld4q_s8(Aptr0);
int8x16_t b = vld1q_s8(Bptr);
acc_neon[0] = vdotq_s32(acc_neon[0], a.val[0], b);
acc_neon[1] = vdotq_s32(acc_neon[1], a.val[1], b);
acc_neon[2] = vdotq_s32(acc_neon[2], a.val[2], b);
acc_neon[3] = vdotq_s32(acc_neon[3], a.val[3], b);
Bptr += 16;
Aptr0 += PACK_SIZE * 16;
}
acc0 = vaddvq_s32(acc_neon[0]);
acc1 = vaddvq_s32(acc_neon[1]);
acc2 = vaddvq_s32(acc_neon[2]);
acc3 = vaddvq_s32(acc_neon[3]);
}
if (k + 8 <= K) {
int32x2_t acc_neon[4];
acc_neon[0] = vdup_n_s32(0);
acc_neon[1] = vdup_n_s32(0);
acc_neon[2] = vdup_n_s32(0);
acc_neon[3] = vdup_n_s32(0);
int8x8x4_t a = vld4_s8(Aptr0);
int8x8_t b = vld1_s8(Bptr);
acc_neon[0] = vdot_s32(acc_neon[0], a.val[0], b);
acc_neon[1] = vdot_s32(acc_neon[1], a.val[1], b);
acc_neon[2] = vdot_s32(acc_neon[2], a.val[2], b);
acc_neon[3] = vdot_s32(acc_neon[3], a.val[3], b);
Bptr += 8;
Aptr0 += PACK_SIZE * 8;
k += 8;
acc0 += vaddv_s32(acc_neon[0]);
acc1 += vaddv_s32(acc_neon[1]);
acc2 += vaddv_s32(acc_neon[2]);
acc3 += vaddv_s32(acc_neon[3]);
}
for (; k < K; ++k) {
acc0 += static_cast<int32_t>(*(Aptr0 + 0)) * B[k];
acc1 += static_cast<int32_t>(*(Aptr0 + 1)) * B[k];
acc2 += static_cast<int32_t>(*(Aptr0 + 2)) * B[k];
acc3 += static_cast<int32_t>(*(Aptr0 + 3)) * B[k];
Aptr0 += 4;
}
C[0] = acc0;
C[1] = acc1;
C[2] = acc2;
C[3] = acc3;
Aptr += Astride;
C += Cstride;
}
}
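// gemv_naive_n_mk4_dot below relies on the MK4_DOT tile layout, in which the
// four bytes of one output row are contiguous ([m_inner][k_inner]); each
// vdotq_laneq_s32 then dots one whole 4x4 tile of A against one 4-depth lane
// of B, so four sdot instructions retire 16 depths for all four rows.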
void gemv_naive_n_mk4_dot(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
constexpr size_t PACK_SIZE = 4;
megdnn_assert(N == 1 && Bstride == 4);
auto Aptr = A;
size_t m = 0;
for (; m < M; m += PACK_SIZE) {
auto Bptr = B;
auto Aptr0 = Aptr;
size_t k = 0;
int32x4_t acc_neon;
acc_neon = vdupq_n_s32(0);
for (; k + 16 <= K; k += 16) {
int8x16_t a0 = vld1q_s8(Aptr0);
int8x16_t a1 = vld1q_s8(Aptr0 + 16);
int8x16_t a2 = vld1q_s8(Aptr0 + 32);
int8x16_t a3 = vld1q_s8(Aptr0 + 48);
int8x16_t b = vld1q_s8(Bptr);
acc_neon = vdotq_laneq_s32(acc_neon, a0, b, 0);
acc_neon = vdotq_laneq_s32(acc_neon, a1, b, 1);
acc_neon = vdotq_laneq_s32(acc_neon, a2, b, 2);
acc_neon = vdotq_laneq_s32(acc_neon, a3, b, 3);
Bptr += 16;
Aptr0 += PACK_SIZE * 16;
}
if (k + 8 <= K) {
int8x16_t a0 = vld1q_s8(Aptr0);
int8x16_t a1 = vld1q_s8(Aptr0 + 16);
int8x8_t b = vld1_s8(Bptr);
acc_neon = vdotq_lane_s32(acc_neon, a0, b, 0);
acc_neon = vdotq_lane_s32(acc_neon, a1, b, 1);
Bptr += 8;
Aptr0 += PACK_SIZE * 8;
k += 8;
}
if (k + 4 <= K) {
int8x16_t a = vld1q_s8(Aptr0);
int32_t tmp = *(reinterpret_cast<const int32_t*>(Bptr));
int8x8_t b = vreinterpret_s8_s32(vdup_n_s32(tmp));
acc_neon = vdotq_lane_s32(acc_neon, a, b, 0);
}
vst1q_s32(C, acc_neon);
Aptr += Astride;
C += Cstride;
}
}
} // namespace
#endif
@@ -201,4 +408,33 @@ void arm_common::gemv_like(const int8_t* __restrict A,
MIDOUT_END();
}
void arm_common::gemv_like_mk4(const int8_t* __restrict A,
const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
megdnn_assert(N == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
midout_iv("INT8_gemv_like_mk4"_hash)) {
return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
}
MIDOUT_END();
}
#if __ARM_FEATURE_DOTPROD
void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
megdnn_assert(N == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
midout_iv("INT8_gemv_like_mk4_dot"_hash)) {
return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride,
Cstride);
}
MIDOUT_END();
}
#endif
// vim: syntax=cpp.doxygen
@@ -24,6 +24,16 @@ void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
#if __ARM_FEATURE_DOTPROD
void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
#endif
} // namespace arm_common
} // namespace megdnn
@@ -28,14 +28,24 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF16Gemv f16gemv;
#endif
AlgoInt8x8x32Gemv int8x8x32_gemv;
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
#endif
AlgoGevm gevm;
AlgoF32GemvMK4 f32_gemv_mk4;
public:
AlgoPack() {
all_algos.emplace_back(&int8x8x16);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&f16gemv);
#endif
#if __ARM_FEATURE_DOTPROD
all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
#endif
all_algos.emplace_back(&int8x8x32_gemv);
all_algos.emplace_back(&int8x8x32_gemv_mk4);
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
}
SmallVector<AlgoBase*> all_algos;
@@ -25,12 +25,17 @@ public:
protected:
static void* const sm_arm_common_algo_type;
class AlgoInt8x8x32Gemv; // Arm_common Int 8x8x32 Gemv
class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoGevm; // Arm_common Gemv(support int8 and fp32)
class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv
class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44
class AlgoGevm; // Arm_common Gevm(support int8 and fp32)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoF16Gemv;
#endif
#if __ARM_FEATURE_DOTPROD
class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT
#endif
class AlgoInt8x8x16; // Arm_common Int 8x8x16
class AlgoPack;
};
@@ -407,6 +407,16 @@ __ai int32_t vaddv_s32(int32x2_t a) {
return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
}
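// The armv7 fallbacks below emulate the AArch64 across-vector adds that the
// new kernels use for their reductions. The integer version is exact; the
// float version fixes a left-to-right summation order, so it may round
// differently from a native vaddvq_f32 (the tests' epsilon absorbs this).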
__ai int32_t vaddvq_s32(int32x4_t a) {
return vgetq_lane_s32(a, 0) + vgetq_lane_s32(a, 1) +
vgetq_lane_s32(a, 2) + vgetq_lane_s32(a, 3);
}
__ai float32_t vaddvq_f32(float32x4_t a) {
return vgetq_lane_f32(a, 0) + vgetq_lane_f32(a, 1) +
vgetq_lane_f32(a, 2) + vgetq_lane_f32(a, 3);
}
#endif // MEGDNN_ARMV7
//! pack vmovl_low_xx() on armv7 and armv8
@@ -42,14 +42,27 @@ using namespace conv1x1;
namespace {
#if MEGDNN_X86
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("x86 conv1x1 gemv only supports format : NCHW");
MEGDNN_MARK_USED_VAR(A);
MEGDNN_MARK_USED_VAR(B);
MEGDNN_MARK_USED_VAR(C);
MEGDNN_MARK_USED_VAR(M);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(K);
MEGDNN_MARK_USED_VAR(LDA);
MEGDNN_MARK_USED_VAR(LDB);
MEGDNN_MARK_USED_VAR(LDC);
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn_assert(false,
"unsupported conv1x1 gemv : \nsrc_type : "
"%s\nfilter_type : %s\n",
src.name(), filter.name());
}
};
@@ -66,39 +79,29 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
}
};
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("arm conv1x1 gemv only supports format : NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
}
};
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
template <>
struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
struct GemvLike<dt_float32, dt_float32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_float32* A, const dt_float32* B,
dt_float32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<dt_int8, dt_int16>(A, B, C, M, N, K, LDA,
LDB, LDC);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
@@ -118,21 +121,47 @@ struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> {
}
};
#endif
#endif
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
struct GemvLike<dt_int8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int32* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like_mk4(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
#if __ARM_FEATURE_DOTPROD
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like_mk4_dot(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
#endif
#endif
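// With the NCHW44 / NCHW44_DOT specializations above, Conv1x1GemvWorker picks
// the mk4 kernels purely through the Format template parameter; the matching
// pack_size-aware strides are supplied at the do_gemv call site below.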
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode,
@@ -185,19 +214,18 @@ struct Conv1x1GemvWorker {
is_dst_8bit ? matmul_temp_dst
: reinterpret_cast<bias_ctype*>(conv_bias_dst);
size_t pack_size = megdnn::fallback::pack_size(format);
GemvLike<src_ctype, bias_ctype, format>::do_gemv(
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1,
ncb_param.filter_type, ncb_param.src_type);
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC * pack_size,
pack_size, pack_size, ncb_param.filter_type,
ncb_param.src_type);
//! do postprocess
void* bias_ptr = nullptr;
if (param.bias_mode == megdnn::BiasMode::BIAS) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS) {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset));
} else {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
@@ -211,9 +239,13 @@ struct Conv1x1GemvWorker {
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
return round_up<size_t>(oc_block_size_one_thread, 16);
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::get_oc_tile"_hash)) {
size_t OC = param.filter_meta.ocpg;
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
return round_up<size_t>(oc_block_size_one_thread, 16);
}
MIDOUT_END();
}
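// Worked example of the heuristic above (numbers assumed for illustration):
// OC = 100 on nr_threads = 4 gives div_ceil(100, 4) = 25, rounded up to the
// next multiple of 16, i.e. 32, so the OC axis splits into tiles of at most
// 32 channels each.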
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
@@ -286,6 +318,11 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
#else
#if !MEGDNN_DISABLE_FLOAT16
cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
#endif
#endif
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
@@ -311,6 +348,37 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
break;
case param::ConvBias::Format::NCHW44:
cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash);
cb2(param::ConvBias::Format::NCHW44, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW44, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW44::GEMV::QINT8x8x32_QINT8"_hash);
break;
case param::ConvBias::Format::NCHW44_DOT:
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT8"_hash);
break;
default:
megdnn_throw("Invalid Format");
break;
@@ -338,6 +406,16 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
#if MEGDNN_X86
if (opr->param().format != param::ConvBias::Format::NCHW)
return false;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44 &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT)
return false;
#endif
//! whether 1x1
size_t FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
@@ -390,59 +468,43 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
bool is_param_ok =
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
bool is_format_and_dtype_ok = false;
#if MEGDNN_X86
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! x86 supports all dtypes in NCHW
is_format_and_dtype_ok = true;
}
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
//! add NCHW44 and NCHW44_DOT support in the future
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! NCHW format supports all dtype
is_format_and_dtype_ok = true;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
if (opr->param().format == param::ConvBias::Format::NCHW44) {
if (param.src_type.enumv() != DTypeEnum::Float32 &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
} else if (opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8) {
return false;
}
}
#endif
return is_param_ok && is_format_and_dtype_ok;
return (param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
}
MIDOUT_END();
return false;
}
bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32)
return true;
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::is_preferred"_hash)) {
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
//! maybe add support for QuantizedAsym in the future
return (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) ||
#if !MEGDNN_DISABLE_FLOAT16
(param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16) ||
if (opr->param().format == param::ConvBias::Format::NCHW &&
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return false;
}
#endif
(param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32);
#else
return true;
}
MIDOUT_END();
return false;
#endif
}
// vim: syntax=cpp.doxygen
@@ -2036,7 +2036,6 @@ void benchmark_conv1x1(const char* matmul_algo_name, Handle* handle,
RUNS;
auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;
printf("\n%s: ", matmul_algo_name);
printf("%s %s:\n matmul: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
"speedup: "
"%f\n",
@@ -2120,6 +2119,82 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_INT8x8x16) {
#endif
}
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args;
param::ConvBias conv_param;
conv_param.stride_h = 1;
conv_param.stride_w = 1;
conv_param.pad_h = 0;
conv_param.pad_w = 0;
conv_param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY;
auto run = [&](size_t M, size_t K) {
args.emplace_back(conv_param, TensorShape{1, K, 1, 1},
TensorShape{M, K, 1, 1}, TensorShape{});
};
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K);
constexpr size_t RUNS = 50;
param::MatrixMul param;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmark_matmul(handle());
benchmark_matmul.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
benchmark_matmul.set_times(RUNS)
.set_dtype(0, dtype::Float32{})
.set_dtype(1, dtype::Float32{})
.set_dtype(2, dtype::Float32{})
.set_param(param)
.set_display(false);
Benchmarker<ConvBias> benchmark_conv1x1(handle());
benchmark_conv1x1.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>("CONV1x1_GEMV"));
benchmark_conv1x1.set_times(RUNS)
.set_dtype(0, dtype::Float32{})
.set_dtype(1, dtype::Float32{})
.set_dtype(2, dtype::Float32{})
.set_dtype(4, dtype::Float32{})
.set_display(false);
std::cout << "warm up:\n";
for (int i = 0; i < 50; i++) {
benchmark_matmul.exec({{1, 1024}, {1024, 512}, {}});
benchmark_matmul.set_display(true);
}
for (auto&& arg : args) {
size_t IC = arg.src[1];
size_t OH = arg.src[2];
size_t OW = arg.src[3];
size_t OC = arg.filter[0];
size_t M = OC;
size_t K = IC;
size_t N = OH * OW;
float computations = M * N * K * 2.f / (1024 * 1024 * 1024) * 1e3;
TensorShape A, B;
A = TensorShape{M, K};
B = TensorShape{K, N};
auto conv1x1_used = benchmark_conv1x1.set_param(arg.param).exec(
{arg.src, arg.filter, arg.bias, {}, {}}) /
RUNS;
auto matmul_used = benchmark_matmul.exec({A, B, {}}) / RUNS;
printf("%s %s:\n gemv: %f ms %f Gflops\nconv1x1: %f ms %f GFlops "
"speedup: "
"%f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
matmul_used, computations / matmul_used, conv1x1_used,
computations / conv1x1_used, matmul_used / conv1x1_used);
}
}
#ifndef __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
std::vector<TestArg> conv_bias_1x1_args_nchw44 =
@@ -180,12 +180,15 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12})
for (size_t ic : {1, 3, 4, 12})
for (size_t h : {3, 5, 12})
for (size_t w : {7, 16, 23}) {
for (size_t h : {1, 3, 12})
for (size_t w : {1, 16, 23}) {
for (size_t group = 1;
group <=
std::min(std::min(oc, ic), 4_z);
++group) {
if (kernel != 1 && (h == 1 || w == 1)) {
continue;
}
pack(n, oc, ic, h, w, kernel, stride,
group, nlmode, bias);
}
@@ -1897,6 +1900,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
#elif MEGDNN_ARMV7
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
@@ -1932,7 +1941,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
#endif
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if(arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
@@ -2138,4 +2147,40 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false);
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon,
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
"CONV1x1_GEMV");
}
#ifdef __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false, false, true);
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> gemv_args;
for (auto&& arg : args)
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) {
gemv_args.emplace_back(arg);
}
checker_conv_bias(gemv_args, handle(), &rng, epsilon,
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f),
"CONV1x1_GEMV");
}
#endif
// vim: syntax=cpp.doxygen
@@ -156,7 +156,7 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV) {
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// N = 1
for (size_t M : {1, 10, 16, 33, 64})
for (size_t K : {7, 512, 1024})
@@ -164,6 +164,70 @@
run(M, K, N);
}
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4"));
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
auto run = [&](size_t M, size_t K, size_t N) {
Param param;
param.format = param::MatrixMul::Format::MK4;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param)
.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
run(M, K, 1);
}
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"));
std::unique_ptr<RNG> rng = std::make_unique<UniformIntRNG>(-127, 127);
checker.set_rng(0, rng.get()).set_rng(1, rng.get());
auto run = [&](size_t M, size_t K, size_t N) {
Param param;
param.format = param::MatrixMul::Format::MK4_DOT;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param)
.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.execs({A, B, {}});
};
// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 16, 20, 24, 256, 1024})
run(M, K, 1);
}
#endif
TEST_F(ARM_COMMON, QINT8x8x32_GEVM) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
@@ -220,6 +284,31 @@ TEST_F(ARM_COMMON, FP32_GEVM) {
run(M, K, N);
}
TEST_F(ARM_COMMON, FP32_GEMV_MK4) {
Checker<MatrixMul> checker(handle());
using Param = MatrixMul::Param;
checker.set_before_exec_callback(
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV_MK4"));
checker.set_epsilon(1e-2);
auto run = [&](size_t M, size_t K) {
Param param;
param.format = param::MatrixMul::Format::MK4;
param.transposeA = false;
param.transposeB = false;
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
checker.set_param(param).execs({A, B, {}});
};
// N = 1
for (size_t M : {4, 16, 128, 1024})
for (size_t K : {4, 8, 12, 128, 256, 4096})
run(M, K);
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
@@ -228,18 +317,16 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
benchmarker.set_times(exec_times);
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2.f * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp32, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
@@ -253,6 +340,10 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) {
for (size_t K : {1024, 1536, 2048})
for (size_t N : {512, 1024})
run(M, K, N);
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 256, 1024, 4096})
run(M, K, 1);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
@@ -263,28 +354,25 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV"));
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2 * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp32, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_display(false)
.exec({{2, 1024}, {1024, 512}, {}});
benchmarker.set_display(true);
}
// run gemv
run(12, 48, 1);
run(48, 12, 1);
@@ -298,6 +386,45 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) {
run(1024, 256, 1);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) {
int exec_times = 10;
using Param = MatrixMul::Param;
Param param;
param.format = param::MatrixMul::Format::MK4;
param.transposeA = false;
param.transposeB = false;
Benchmarker<MatrixMul> benchmarker(handle());
benchmarker.set_times(exec_times);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_param(param);
auto run = [&](size_t M, size_t K) {
printf("SGEMV_MK4: (%zu, %zu, 1)\n", M, K);
TensorShape A, B;
A = TensorShape{M / 4, K / 4, 4, 4};
B = TensorShape{K / 4, 1, 4};
auto time = benchmarker.exec({A, B, {}}) / exec_times;
auto computations = 2.f * M * K * 1e-6;
auto perf = computations / time;
printf("gemv mk4 fp32, Performance is %f Gflops\n", perf);
};
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_display(false)
.exec({{4, 256, 4, 4}, {256, 1, 4}, {}});
}
// run gemv mk4
for (size_t M : {4, 64, 1024, 4096})
for (size_t K : {128, 1024, 4096})
run(M, K);
}
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
int exec_times = 50;
Benchmarker<MatrixMul> benchmarker(handle());
@@ -306,19 +433,17 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) {
AlgoChecker<MatrixMul>("ARM_COMMON_F16_GEMV"));
auto run = [&](size_t M, size_t K, size_t N) {
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMV_FP16: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16());
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times;
auto computations = 2 * M * K * N * 1e-6;
auto perf = computations / time;
std::cout << "gemv fp16, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemv fp16, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
@@ -343,17 +468,15 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMM) {
float mod = 1000 * exec_times / 1e9;
auto run = [&](size_t M, size_t K, size_t N) {
float time = 1.f, perf = 1.f;
std::cout << "SGEMM: (" << M << ", " << K << ", " << N << ")"
<< std::endl;
printf("SGEMM: (%zu, %zu, %zu)\n", M, K, N);
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32());
time = benchmarker.exec({{M, K}, {K, N}, {}});
perf = 2.f * M * K * N / time * mod;
std::cout << "gemm fp32, Performance is " << perf << " Gflops"
<< std::endl;
printf("gemm, Performance is %f Gflops\n", perf);
};
std::cout << "warm up:\n";
printf("warm up:\n");
for (int i = 0; i < 50; i++) {
benchmarker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())