GitOrigin-RevId: 97679e8526
release-0.6
@@ -193,7 +193,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
DEFAULT \
}
#define FOR_BIAS(_bias_mode) \
#define FOR_BIAS(_bias_mode, OH, OW) \
switch (_bias_mode) { \
case megdnn::BiasMode::NO_BIAS: \
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY); \
@@ -208,6 +208,10 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
} \
break; \
default: \
if (OH * OW == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
break; \
} \
megdnn_throw("quantized unsupported biasmode"); \
break; \
}
@@ -218,7 +222,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
FOR_BIAS(bias_mode);
//! when OH * OW == 1, the bias_mode will be BiasMode::BIAS, which is
//! wrong here; we handle this case in the default branch.
FOR_BIAS(bias_mode, OH, OW);
}
};
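Why the new default branch is safe: with OH * OW == 1 a full bias tensor of shape (N, OC, 1, 1) is laid out exactly like a per-channel broadcast bias, so the binary-broadcast kernel computes the same result. A minimal standalone sketch of that equivalence (hypothetical helper, not part of the patch):

#include <cstddef>
// With OH * OW == 1, each (n, oc) pair owns exactly one output element, so
// "full" bias indexing and channel-broadcast indexing coincide.
inline void add_bias_spatial_one(const float* conv_out, const float* bias,
                                 float* dst, size_t N, size_t OC) {
    for (size_t n = 0; n < N; ++n)
        for (size_t oc = 0; oc < OC; ++oc)
            dst[n * OC + oc] = conv_out[n * OC + oc] + bias[n * OC + oc];
}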
@@ -43,9 +43,9 @@ void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) {
size_t N = kern_param.N;
size_t K = kern_param.K;
size_t LDB = kern_param.LDB;
exec_gemm_int8_int8_int16(kern_param.A<dt_int8>(),
kern_param.B<dt_int8>(),
kern_param.C<dt_int16>(), M, K, N, LDB, w0, w1);
exec_gemm_int8_int8_int16(
kern_param.A<dt_int8>(), kern_param.B<dt_int8>(),
kern_param.C<dt_int16>(), M, K, N, LDB, w0, w1);
}
MIDOUT_END();
}
@@ -79,8 +79,7 @@ void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
arm_common::matmul::gemv_like_int8(Aptr, Bptr, Cptr, M, N, K, LDA, LDB,
LDC);
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
@@ -110,7 +109,7 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
} // anonymous namespace
@@ -140,25 +139,14 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
/* ===================== F32 Gevm algo ===================== */
namespace {
void gevm_fp32_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
arm_common::sgemm_sgemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
void gevm_int8_kern(const MatrixMulImpl::KernParam& kern_param) {
template <typename stype, typename dtype>
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<dt_int8>(),
Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
arm_common::matmul::gemv_like_int8(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
} // anonymous namespace
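The gevm kernels above reuse the gemv kernel by swapping operands: with M == 1, C = A * B is equivalent to C^T = B^T * A^T, an ordinary matrix-vector product. A naive sketch of that reduction (illustrative helper assuming row-major B with row stride LDB; not part of the patch):

#include <cstddef>
template <typename T>
void gevm_via_gemv(const T* A /*1xK*/, const T* B /*KxN*/, T* C /*1xN*/,
                   size_t N, size_t K, size_t LDB) {
    // Row n of B^T is column n of B; its elements are LDB apart in memory.
    for (size_t n = 0; n < N; ++n) {
        T acc = T(0);
        for (size_t k = 0; k < K; ++k)
            acc += A[k] * B[k * LDB + n];
        C[n] = acc;
    }
}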
bool MatrixMulImpl::AlgoGevm::usable(
@@ -170,8 +158,16 @@ bool MatrixMulImpl::AlgoGevm::usable(
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float32();
return (fp32_ok || can_be_treated_as_int8x8x32(kern_size_param)) &&
preferred(kern_size_param);
bool fp16_ok = false;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
fp16_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
kern_size_param.B_type == kern_size_param.A_type &&
kern_size_param.C_type == kern_size_param.A_type &&
kern_size_param.A_type == dtype::Float16();
#endif
bool int8_ok = can_be_treated_as_int8x8x32(kern_size_param);
return (fp32_ok || fp16_ok || int8_ok) && preferred(kern_size_param);
}
bool MatrixMulImpl::AlgoGevm::preferred(
@@ -183,11 +179,17 @@ bool MatrixMulImpl::AlgoGevm::preferred(
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern(
const KernSizeParam& kern_size_param) const {
if (kern_size_param.A_type == dtype::Float32()) {
return gevm_fp32_kern;
return gevm_like_kern<dt_float32, dt_float32>;
} else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) {
return gevm_int8_kern;
} else {
return gevm_like_kern<dt_int8, dt_int32>;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
else if (kern_size_param.A_type == dtype::Float16()) {
return gevm_like_kern<__fp16, __fp16>;
}
#endif
else {
megdnn_assert(
false, "no available kern got A_type: %s B_type: %s C_type: %s",
kern_size_param.A_type.name(), kern_size_param.B_type.name(),
@@ -205,10 +207,10 @@ void f16_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
Bptr = kern_param.B<dt_float16>();
auto Cptr = kern_param.C<dt_float16>();
MIDOUT_BEGIN(megdnn_arm_hgemv, void) {
arm_common::hgemv_exec(reinterpret_cast<const __fp16*>(Aptr),
reinterpret_cast<const __fp16*>(Bptr),
reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA,
LDB, LDC);
arm_common::gemv_like(reinterpret_cast<const __fp16*>(Aptr),
reinterpret_cast<const __fp16*>(Bptr),
reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA,
LDB, LDC);
}
MIDOUT_END();
}
@@ -96,11 +96,11 @@ void hgemv_naive_n(const __fp16* __restrict A, const __fp16* __restrict B,
}
} // namespace
void megdnn::arm_common::hgemv_exec(const __fp16* __restrict A,
const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
void megdnn::arm_common::gemv_like(const __fp16* __restrict A,
const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N,
size_t K, size_t Astride, size_t Bstride,
size_t Cstride) {
megdnn_assert((M <= 4) || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return hgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
@@ -16,13 +16,14 @@
namespace megdnn {
namespace arm_common {
void hgemv_exec(const __fp16* __restrict A, const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
bool is_hgemv_preferred(bool transposeA, bool transposeB, size_t M, size_t N,
size_t K, size_t /*LDA*/, size_t LDB, size_t /*LDC*/);
void gemv_like(const __fp16* __restrict A, const __fp16* __restrict B,
__fp16* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
@@ -42,40 +42,6 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
#define calculate(i) sum##i = vmlaq_f32(sum##i, a##i, b0);
#define vstore(i) C[(m + i) * Cstride] = vaddvq_f32(sum##i) + acc##i;
size_t m = 0;
for (; m + 4 <= M; m += 4) {
float acc0, acc1, acc2, acc3;
float32x4_t a0, a1, a2, a3, b0;
float32x4_t sum0, sum1, sum2, sum3;
UNROLL_OUT(vdupq_sum, 4)
size_t k = 0;
for (; k + 4 <= K; k += 4) {
UNROLL_OUT(loadA, 4)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(calculate, 4)
}
UNROLL_OUT(reset_acc, 4)
for (; k < K; ++k) {
UNROLL_OUT(acc_calu, 4)
}
UNROLL_OUT(vstore, 4)
}
for (; m + 2 <= M; m += 2) {
float acc0, acc1;
float32x4_t a0, a1, b0;
float32x4_t sum0, sum1;
UNROLL_OUT(vdupq_sum, 2)
size_t k = 0;
for (; k + 4 <= K; k += 4) {
UNROLL_OUT(loadA, 2)
UNROLL_OUT(loadB, 1)
UNROLL_OUT(calculate, 2)
}
UNROLL_OUT(reset_acc, 2)
for (; k < K; ++k) {
UNROLL_OUT(acc_calu, 2)
}
UNROLL_OUT(vstore, 2)
}
for (; m < M; m += 1) {
float acc0;
float32x4_t a0, b0;
@@ -107,9 +73,9 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B,
namespace megdnn {
namespace arm_common {
void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride) {
megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1));
if (N == 1) {
return sgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
@@ -20,9 +20,10 @@ bool is_sgemv_like_preferred(bool row_major, bool transposeA, bool transposeB,
size_t /* LDA */, size_t LDB, float beta,
size_t /* LDC */);
void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
void gemv_like(const float* __restrict A, const float* __restrict B,
float* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
@@ -172,9 +172,10 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
} // namespace
#endif
bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC) {
bool arm_common::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
size_t M, size_t N, size_t K,
size_t LDA, size_t LDB,
size_t LDC) {
MEGDNN_MARK_USED_VAR(LDA);
MEGDNN_MARK_USED_VAR(LDB);
MEGDNN_MARK_USED_VAR(LDC);
@@ -188,15 +189,16 @@ bool matmul::is_gemv_like_preferred_int8(bool transposeA, bool transposeB,
return N == 1 && LDB == 1;
}
void matmul::gemv_like_int8(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
void arm_common::gemv_like(const int8_t* __restrict A,
const int8_t* __restrict B, int32_t* __restrict C,
size_t M, size_t N, size_t K, size_t Astride,
size_t Bstride, size_t Cstride) {
megdnn_assert(N == 1);
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv) {
MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
midout_iv("INT8_gemv_like"_hash)) {
return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
} MIDOUT_END();
}
MIDOUT_END();
}
// vim: syntax=cpp.doxygen
@@ -15,16 +15,15 @@
namespace megdnn {
namespace arm_common {
namespace matmul {
bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M,
size_t N, size_t K, size_t LDA, size_t LDB,
size_t LDC);
void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace matmul
void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B,
int32_t* __restrict C, size_t M, size_t N, size_t K,
size_t Astride, size_t Bstride, size_t Cstride);
} // namespace arm_common
} // namespace megdnn
@@ -54,7 +54,7 @@ size_t ConvBiasImpl::AlgoConv1x1::get_workspace(
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
auto matmul_param =
get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
utils::get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
auto pack_mode = m_matmul_algo->packmode();
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
@@ -92,7 +92,6 @@ size_t ConvBiasImpl::AlgoConv1x1::get_workspace(
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t OC = param.filter_meta.ocpg;
@@ -102,7 +101,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
auto matmul_param =
get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
utils::get_matmul_kern_param(param, OH * OW, compt_oc_block_size);
WorkspaceBundle whole_bundle = {nullptr, {}};
WorkspaceBundle thread_bundle = {nullptr, {}};
WorkspaceBundle matmul_bundle = {nullptr, {}};
@@ -138,7 +137,7 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
}
//! get thread bundle
thread_bundle = get_thread_bundle(param, matmul_bundle.get_size(2),
thread_bundle = utils::get_thread_bundle(param, matmul_bundle.get_size(2),
compt_oc_block_size);
Conv1x1StrategyBase* conv1x1_strategy =
@@ -178,7 +177,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
}
}
ret_kern.push_back({kern_compt, {BATCH, GROUP, oc_blocks_per_group}});
return ret_kern;
}
@@ -201,8 +199,11 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1)
return false;
if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
@@ -211,6 +212,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! make sure the 8x8x16 and 8x8x32 bias mode is no-bias and the
//! nonlineMode is identity; otherwise return false, meaning 8x8x32
//! and 8x8x16 do not support PostProcess
@@ -233,7 +235,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
size_t OH = param.osz[0];
size_t OW = param.osz[1];
MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(
MatrixMulImpl::KernSizeParam matmul_param = utils::get_matmul_kern_param(
param, OH * OW, get_oc_tile_size_heuristic(param));
bool matmul_usable = m_matmul_algo->usable(matmul_param);
@@ -250,3 +253,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
MIDOUT_END();
return false;
}
bool ConvBiasImpl::AlgoConv1x1::is_preferred(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (OH * OW != 1) {
return true;
} else {
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
if (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int16) {
return true;
}
#elif MEGDNN_X86
size_t OC = param.filter_meta.ocpg;
if (OC > 2 || param.src_type.enumv() == DTypeEnum::Float32)
return true;
#endif
return false;
}
}
// vim: syntax=cpp.doxygen
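Rationale for the is_preferred change above: a 1x1 convolution whose output is a single pixel is better served by the plain gemv path in AlgoConv1x1Gemv (introduced below), so AlgoConv1x1 keeps only the OH * OW == 1 cases that the gemv algo does not cover: 8x8x16 on ARM, and OC > 2 or float32 on x86.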
@@ -41,9 +41,7 @@ public:
SmallVector<NCBKern> dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const override;
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override{
return true;
}
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override;
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
@@ -0,0 +1,448 @@
/**
* \file dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "megdnn/opr_param_defs.h"
#include "src/naive/convolution/helper.h"
#include "src/fallback/matrix_mul/gemv.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include "src/arm_common/matrix_mul/fp16/hgemv.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#endif
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_conv1x1_gemv)
using namespace megdnn;
using namespace fallback;
#if MEGDNN_X86
using namespace x86;
#endif
using namespace conv1x1;
namespace {
#if MEGDNN_X86
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("x86 conv1x1 gemv only supports format: NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<stype, btype>(A, B, C, M, N, K, LDA, LDB,
LDC);
}
};
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
template <typename stype, typename btype, param::ConvBias::Format F>
struct GemvLike {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
megdnn_throw("arm conv1x1 gemv only supports format: NCHW");
}
};
template <typename stype, typename btype>
struct GemvLike<stype, btype, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const stype* A, const stype* B, btype* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(A, B, C, M, N, K, LDA, LDB, LDC);
}
};
template <>
struct GemvLike<dt_int8, dt_int16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_int8* A, const dt_int8* B, dt_int16* C,
size_t M, size_t N, size_t K, size_t LDA,
size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::fallback::gemv_like<dt_int8, dt_int16>(A, B, C, M, N, K, LDA,
LDB, LDC);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct GemvLike<dt_float16, dt_float16, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_float16* A, const dt_float16* B,
dt_float16* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
MEGDNN_MARK_USED_VAR(src);
MEGDNN_MARK_USED_VAR(filter);
megdnn::arm_common::gemv_like(reinterpret_cast<const __fp16*>(A),
reinterpret_cast<const __fp16*>(B),
reinterpret_cast<__fp16*>(C), M, N, K,
LDA, LDB, LDC);
}
};
#endif
#endif
template <>
struct GemvLike<dt_uint8, dt_int32, param::ConvBias::Format::NCHW> {
inline static void do_gemv(const dt_uint8* A, const dt_uint8* B,
dt_int32* C, size_t M, size_t N, size_t K,
size_t LDA, size_t LDB, size_t LDC, DType src,
DType filter) {
uint8_t zp0 = src.param<dtype::Quantized8Asymm>().zero_point;
uint8_t zp1 = filter.param<dtype::Quantized8Asymm>().zero_point;
megdnn::fallback::gemv_like<dt_uint8, dt_int32>(A, B, C, M, N, K, LDA,
LDB, LDC, zp0, zp1);
}
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode,
param::ConvBias::Format format>
struct Conv1x1GemvWorker {
static void exec(WorkspaceBundle& whole_bundle,
WorkspaceBundle& thread_bundle, size_t oc_tile_size,
const ConvBiasImpl::NCBKernSizeParam& param,
const ConvBiasImpl::NCBKernParam& ncb_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) {
whole_bundle.set(ncb_param.workspace_ptr);
size_t OC = param.filter_meta.ocpg;
size_t IC = param.filter_meta.icpg;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t oc_tile_id_in_group = ncb_index.ndrange_id[2];
size_t thread_id = ncb_index.thread_id;
size_t oc_start = oc_tile_size * oc_tile_id_in_group;
size_t oc_end = oc_start + oc_tile_size;
oc_end = (oc_end <= OC ? oc_end : OC);
size_t numbers_of_ncb_filter_offset =
oc_tile_size * IC * oc_tile_id_in_group;
const src_ctype* Aptr = ncb_param.filter<src_ctype>(group_id) +
numbers_of_ncb_filter_offset;
const src_ctype* Bptr = ncb_param.src<src_ctype>(batch_id, group_id);
size_t thread_offset = thread_bundle.total_size_in_bytes() * thread_id;
size_t bytes_offset_of_matmul_dst_this_thread =
thread_offset + thread_bundle.get_size(0);
bias_ctype* matmul_temp_dst = reinterpret_cast<bias_ctype*>(
reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_matmul_dst_this_thread);
size_t numbers_of_ncb_dst_offset = oc_tile_size * oc_tile_id_in_group;
dst_ctype* conv_bias_dst =
ncb_param.dst<dst_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset;
bool is_dst_8bit =
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
bias_ctype* gemv_dst =
is_dst_8bit ? matmul_temp_dst
: reinterpret_cast<bias_ctype*>(conv_bias_dst);
GemvLike<src_ctype, bias_ctype, format>::do_gemv(
Aptr, Bptr, gemv_dst, oc_end - oc_start, 1, IC, IC, 1, 1,
ncb_param.filter_type, ncb_param.src_type);
//! do postprocess
void* bias_ptr = nullptr;
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) +
numbers_of_ncb_dst_offset));
} else {
bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
gemv_dst, bias_ptr, conv_bias_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
oc_end - oc_start, 1, 1, 1);
}
};
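Since OH * OW == 1, each tile above is one matrix-vector product, C(oc_end - oc_start, 1) = W(oc_tile, IC) * x(IC, 1); that is why do_gemv is called with M = oc_end - oc_start, N = 1, K = IC and strides (IC, 1, 1) for contiguous NCHW data.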
} // namespace
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
size_t oc_block_size_one_thread = div_ceil(OC, param.nr_threads);
return round_up<size_t>(oc_block_size_one_thread, 16);
}
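Worked example for the tile heuristic (illustrative numbers): OC = 100 on 4 threads gives div_ceil(100, 4) = 25, which rounds up to the next multiple of 16, so each thread gets a 32-channel tile.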
size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) {
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
auto thread_bundle =
utils::get_thread_bundle(param, 0, compt_oc_block_size);
return WorkspaceBundle{
nullptr,
{thread_bundle.total_size_in_bytes() * param.nr_threads}}
.total_size_in_bytes();
}
MIDOUT_END();
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
size_t OC = param.filter_meta.ocpg;
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
size_t GROUP = param.filter_meta.group;
size_t BATCH = param.n;
size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
//! get thread bundle
auto thread_bundle =
utils::get_thread_bundle(param, 0, compt_oc_block_size);
auto whole_bundle = WorkspaceBundle{
nullptr, {thread_bundle.total_size_in_bytes() * param.nr_threads}};
using conv1x1_gemv_kern =
std::function<void(WorkspaceBundle&, WorkspaceBundle&, size_t,
const ConvBiasImpl::NCBKernSizeParam&,
const ConvBiasImpl::NCBKernParam&,
const ConvBiasImpl::NCBKernIndex&)>;
conv1x1_gemv_kern conv1x1_gemv_worker = nullptr;
#define cb1(_format, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
conv1x1_gemv_worker = \
Conv1x1GemvWorker<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, _format>::exec; \
} \
} \
MIDOUT_END()
#define cb2(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
conv1x1_gemv_worker = \
Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, _format>::exec; \
} \
} \
MIDOUT_END()
switch (opr->param().format) {
case param::ConvBias::Format::NCHW:
cb1(param::ConvBias::Format::NCHW, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT"_hash);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb1(param::ConvBias::Format::NCHW, dt_float16, __fp16,
PostprocessMode::FLOAT, "NCHW::GEMV::FLOAT16_FP16"_hash);
#endif
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16,
dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::INT8x8x16_INT16"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QINT8x8x32_QINT8"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
"NCHW::GEMV::QUINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32,
dt_uint8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
break;
default:
megdnn_throw("Invalid Format");
break;
}
#undef cb1
#undef cb2
megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker");
auto kern_compt =
[compt_oc_block_size, param, conv1x1_gemv_worker, whole_bundle,
thread_bundle](
const ConvBiasImpl::NCBKernParam& ncb_param,
const ConvBiasImpl::NCBKernIndex& ncb_index) mutable {
conv1x1_gemv_worker(whole_bundle, thread_bundle,
compt_oc_block_size, param, ncb_param,
std::move(ncb_index));
};
ret_kern.push_back({kern_compt, {BATCH, GROUP, oc_blocks_per_group}});
return ret_kern;
}
bool ConvBiasImpl::AlgoConv1x1Gemv::usable(ConvBiasImpl* opr,
const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::usable"_hash)) {
//! whether 1x1
size_t FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t PH = param.filter_meta.padding[0],
PW = param.filter_meta.padding[1];
size_t SH = param.filter_meta.stride[0],
SW = param.filter_meta.stride[1];
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) {
return false;
}
//! whether gemv
size_t OH = param.osz[0];
size_t OW = param.osz[1];
if (OH * OW != 1) {
return false;
}
//! int16 x int16 -> int32 has not even naive support in gemv
if ((param.src_type.enumv() == param.filter_type.enumv() &&
param.src_type.enumv() == DTypeEnum::Int16) &&
param.dst_type.enumv() == DTypeEnum::Int32) {
return false;
}
//! make sure the 8x8x16 and 8x8x32 bias mode is no-bias and the
//! nonlineMode is identity; otherwise return false, meaning 8x8x32
//! and 8x8x16 do not support PostProcess
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
//! only a few dtypes are supported
if (param.src_type.enumv() != param.filter_type.enumv()) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
bool is_param_ok =
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT;
bool is_format_and_dtype_ok = false;
#if MEGDNN_X86
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! x86 supports all dtypes in NCHW
is_format_and_dtype_ok = true;
}
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
//! add NCHW44 and NCHW44_DOT support in the future
if (opr->param().format == param::ConvBias::Format::NCHW) {
//! NCHW format supports all dtypes
is_format_and_dtype_ok = true;
}
#endif
return is_param_ok && is_format_and_dtype_ok;
}
MIDOUT_END();
return false;
}
bool ConvBiasImpl::AlgoConv1x1Gemv::is_preferred(
ConvBiasImpl*, const NCBKernSizeParam& param) const {
size_t OC = param.filter_meta.ocpg;
if (OC <= 2 && param.src_type.enumv() != DTypeEnum::Float32)
return true;
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
//! maybe add support for QuantizedAsym in the future
return (param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
param.dst_type.enumv() == DTypeEnum::Int32) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) ||
#if !MEGDNN_DISABLE_FLOAT16
(param.src_type.enumv() == DTypeEnum::Float16 &&
param.filter_type.enumv() == DTypeEnum::Float16 &&
param.dst_type.enumv() == DTypeEnum::Float16) ||
#endif
(param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32);
#else
return false;
#endif
}
// vim: syntax=cpp.doxygen
@@ -0,0 +1,47 @@
/**
* \file dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/thin/small_vector.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace fallback {
class ConvBiasImpl::AlgoConv1x1Gemv final : public AlgoBase {
public:
AlgoConv1x1Gemv() = default;
bool is_reproducible() const override { return true; }
const char* name() const override {
return "CONV1x1_GEMV";
}
bool usable(ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const override;
size_t get_workspace(ConvBiasImpl*,
const NCBKernSizeParam& param) const override;
SmallVector<NCBKern> dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const override;
bool is_preferred(ConvBiasImpl*, const NCBKernSizeParam&) const override;
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
};
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -11,30 +11,12 @@
#pragma once
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace {
//! get_thread_bundle
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param,
size_t matmul_c_size, size_t oc_tile_size) {
//! for some cases, matmul result need temp space to store
size_t OH = param.osz[0];
size_t OW = param.osz[1];
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t matmul_dst_bytes_per_thread =
is_dst_8bit ? oc_tile_size * OH * OW * sizeof(param.bias_type) : 0;
return WorkspaceBundle{nullptr,
{matmul_c_size, matmul_dst_bytes_per_thread}};
}
} // anonymous namespace
template <MatrixMulImpl::AlgoBase::PackMode pack_mode>
class Conv1x1Kerns {
public:
@@ -51,7 +33,7 @@ public:
//! matmul_param records a matmul with M = oc_tile_size, K = IC, N = OH
//! * OW this does not bother packb bytes
auto matmul_bundle = matmul_algo->get_bundle(matmul_param);
auto thread_bundle = get_thread_bundle(param, matmul_bundle.get_size(2),
auto thread_bundle = utils::get_thread_bundle(param, matmul_bundle.get_size(2),
oc_tile_size);
//! size per thread
@@ -86,7 +68,7 @@ public:
const MatrixMulImpl::AlgoBase* matmul_algo,
size_t oc_tile_size) {
size_t matmul_size = matmul_algo->get_workspace(matmul_param);
auto thread_bundle = get_thread_bundle(param, matmul_size, oc_tile_size);
auto thread_bundle = utils::get_thread_bundle(param, matmul_size, oc_tile_size);
//! size per thread
size_t all_threads_bytes =
thread_bundle.total_size_in_bytes() * param.nr_threads;
@@ -10,8 +10,8 @@
* implied.
*/
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_strategy.h"
#include <unordered_map>
#include "midout.h"
@@ -20,53 +20,7 @@ MIDOUT_DECL(megdnn_fallback_conv1x1_factory_strategy)
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace {
struct StrategyHashParam {
ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
MatrixMulImpl::AlgoBase::PackMode packmode;
};
struct StrategyHashParamHash {
std::size_t operator()(const StrategyHashParam& sparam) const {
constexpr size_t base = 1; //! avoid a zero hash key
std::size_t result =
static_cast<std::size_t>(sparam.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(sparam.param.dst_type.enumv()) +
base)
<< 3);
result = result ^
((static_cast<std::size_t>(sparam.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(sparam.param.bias_type.enumv()) +
base)
<< 9);
result = result ^
((static_cast<std::size_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.packmode) + base) << 15);
return result;
};
};
struct StrategyHashParamEqual {
bool operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool flags = true;
flags = param1.param.src_type == param2.param.src_type && flags;
flags = param1.param.filter_type == param2.param.filter_type && flags;
flags = param1.param.bias_type == param2.param.bias_type && flags;
flags = param1.param.dst_type == param2.param.dst_type && flags;
flags = param1.format == param2.format && flags;
flags = param1.packmode == param2.packmode && flags;
return flags;
};
};
//! NOTE: must keep consistency with can_make_conv1x1_strategy when you
//! modify this function
std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
@@ -176,39 +130,14 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
megdnn_throw("Invalid Data Type");
return nullptr;
}
class StrategyDelegationStorage {
public:
Conv1x1StrategyBase* get(const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
MEGDNN_LOCK_GUARD(m_mtx);
StrategyHashParam sparam;
sparam.param = param;
sparam.format = format;
sparam.packmode = pack_mode;
if (m_map_strategies.find(sparam) == m_map_strategies.end()) {
auto strategy = create_conv1x1_strategy(param, pack_mode, format);
m_map_strategies[sparam] = std::move(strategy);
}
return m_map_strategies[sparam].get();
}
private:
std::mutex m_mtx;
std::unordered_map<StrategyHashParam, std::unique_ptr<Conv1x1StrategyBase>,
StrategyHashParamHash, StrategyHashParamEqual>
m_map_strategies;
};
} // anonymous namespace
Conv1x1StrategyBase* Conv1x1Factory::make_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
static StrategyDelegationStorage storage;
return storage.get(param, pack_mode, format);
static utils::StrategyDelegationStorage<Conv1x1StrategyBase> storage;
return storage.get(param, pack_mode, format, create_conv1x1_strategy);
}
bool Conv1x1Factory::can_make_conv1x1_strategy(
@@ -277,3 +206,5 @@ bool Conv1x1Factory::can_make_conv1x1_strategy(
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -13,6 +13,8 @@
#pragma once
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
@@ -27,44 +29,6 @@ namespace conv1x1 {
using namespace x86;
#endif
namespace {
//! get_matmul_kern_param
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) {
size_t M = m;
size_t N = n;
size_t K = param.filter_meta.icpg; //! K = IC
size_t LDA = K, LDB = N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT;
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
M,
N,
K,
LDA * pack_c_size,
LDB * pack_c_size,
LDC * pack_c_size,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
format};
}
} // namespace
class Conv1x1StrategyBase {
public:
virtual void packA(WorkspaceBundle& whole_bundle,
@@ -134,7 +98,7 @@ public:
size_t IC = param.filter_meta.icpg;
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
utils::get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
size_t bytes_offset_of_a_panel =
group_id * packa_bytes_per_group +
@@ -176,8 +140,7 @@ public:
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, OC);
utils::get_matmul_kern_param(param, OH * OW, OC);
rep(batch, BATCH) {
rep(g, GROUP) {
@@ -238,7 +201,7 @@ public:
MatrixMulImpl::KernParam matmul_kern_param;
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
utils::get_matmul_kern_param(param, OH * OW, oc_end - oc_start);
size_t bytes_offset_of_a_panel =
group_id * packa_bytes_per_group +
@@ -328,7 +291,6 @@ public:
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format);
};
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
@@ -0,0 +1,75 @@
/**
* \file dnn/src/fallback/conv_bias/conv1x1/conv1x1_utils.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/conv1x1/conv1x1_utils.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace utils {
//! get_thread_bundle
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param,
size_t matmul_c_size, size_t oc_tile_size) {
//! for some cases, the matmul result needs temp space to store
size_t OH = param.osz[0];
size_t OW = param.osz[1];
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t matmul_dst_bytes_per_thread =
is_dst_8bit ? oc_tile_size * OH * OW * sizeof(param.bias_type) : 0;
return WorkspaceBundle{nullptr,
{matmul_c_size, matmul_dst_bytes_per_thread}};
}
//! get_matmul_kern_param
MatrixMulImpl::KernSizeParam get_matmul_kern_param(
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m) {
size_t M = m;
size_t N = n;
size_t K = param.filter_meta.icpg; //! K = IC
size_t LDA = K, LDB = N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT;
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
M,
N,
K,
LDA * pack_c_size,
LDB * pack_c_size,
LDC * pack_c_size,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
format};
}
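Concrete reading of the mapping above (illustrative shapes): a 1x1 convolution with IC = 64 and OH = OW = 14 becomes a matmul with K = IC = 64, N = OH * OW = 196, and M equal to the OC tile size; for the NCHW44 formats each leading dimension is additionally scaled by the pack size of 4.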
} // namespace utils
} // namespace conv1x1
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
@@ -0,0 +1,102 @@
/**
* \file dnn/src/fallback/conv_bias/conv1x1/conv1x1_utils.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <unordered_map>
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace fallback {
namespace conv1x1 {
namespace utils {
struct StrategyHashKey {
ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
MatrixMulImpl::AlgoBase::PackMode packmode;
};
struct StrategyHasher {
std::size_t operator()(const StrategyHashKey& key) const {
constexpr size_t base = 1; //! avoid a zero hash key
std::size_t result =
static_cast<std::size_t>(key.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(key.param.dst_type.enumv()) + base)
<< 3);
result = result ^
((static_cast<std::size_t>(key.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(key.param.bias_type.enumv()) + base)
<< 9);
result = result ^ ((static_cast<std::size_t>(key.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(key.packmode) + base) << 15);
return result;
}
};
struct StrategyHashKeyEqual {
bool operator()(const StrategyHashKey& key1,
const StrategyHashKey& key2) const {
return key1.param.src_type == key2.param.src_type &&
key1.param.filter_type == key2.param.filter_type &&
key1.param.bias_type == key2.param.bias_type &&
key1.param.dst_type == key2.param.dst_type &&
key1.format == key2.format && key1.packmode == key2.packmode;
}
};
template <typename T>
class StrategyDelegationStorage {
using creator = std::function<std::unique_ptr<T>(
const ConvBiasImpl::NCBKernSizeParam&,
MatrixMulImpl::AlgoBase::PackMode, param::ConvBias::Format)>;
public:
T* get(const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format, creator Fun) {
MEGDNN_LOCK_GUARD(m_mtx);
StrategyHashKey key;
key.param = param;
key.format = format;
key.packmode = pack_mode;
if (m_map_strategies.find(key) == m_map_strategies.end()) {
auto strategy = Fun(param, pack_mode, format);
m_map_strategies[key] = std::move(strategy);
}
return m_map_strategies[key].get();
}
private:
std::mutex m_mtx;
std::unordered_map<StrategyHashKey, std::unique_ptr<T>, StrategyHasher,
StrategyHashKeyEqual>
m_map_strategies;
};
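Usage mirrors Conv1x1Factory::make_conv1x1_strategy in conv1x1_strategy.cpp above:

static utils::StrategyDelegationStorage<Conv1x1StrategyBase> storage;
Conv1x1StrategyBase* strategy =
        storage.get(param, pack_mode, format, create_conv1x1_strategy);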
//! get_thread_bundle | |||
WorkspaceBundle get_thread_bundle(const ConvBiasImpl::NCBKernSizeParam& param, | |||
size_t matmul_c_size, size_t oc_tile_size); | |||
//! get_matmul_kern_param | |||
MatrixMulImpl::KernSizeParam get_matmul_kern_param( | |||
const ConvBiasImpl::NCBKernSizeParam& param, size_t n, size_t m); | |||
} // namespace utils | |||
} // namespace conv1x1 | |||
} // namespace fallback | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -647,8 +647,11 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
return false; | |||
} | |||
if (param.src_type.enumv() != param.filter_type.enumv() && | |||
param.src_type.enumv() != DTypeEnum::Int8 && | |||
if(param.src_type.enumv() != param.filter_type.enumv()) { | |||
return false; | |||
} | |||
if (param.src_type.enumv() != DTypeEnum::Int8 && | |||
param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||
param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
@@ -16,6 +16,7 @@ | |||
#include "src/common/utils.h" | |||
#include "src/fallback/conv_bias/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos.h" | |||
#include "src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h" | |||
#include "src/fallback/conv_bias/im2col/algos.h" | |||
#include "src/fallback/conv_bias/opr_impl.h" | |||
#include "src/naive/convolution/algorithms.h" | |||
@@ -53,6 +54,10 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
public: | |||
AlgoPack() { | |||
refhold.emplace_back(new AlgoConv1x1Gemv()); | |||
all_algos.emplace_back(refhold.back().get()); | |||
static CpuOprDelegationStorage<> storage; | |||
auto matmul_opr = storage.get<MatrixMul>(); | |||
auto&& matmul_algos = | |||
@@ -259,6 +259,7 @@ private: | |||
class AlgoNaive; | |||
class AlgoIm2col; | |||
class AlgoConv1x1; | |||
class AlgoConv1x1Gemv; | |||
class AlgoWinogradF32; | |||
class AlgoWinogradF32_4x4; | |||
class AlgoWinogradQS8; | |||
@@ -11,6 +11,7 @@ | |||
#include "src/fallback/matrix_mul/algos.h" | |||
#include "src/fallback/matrix_mul/gemm_impl.h" | |||
#include "src/fallback/matrix_mul/gemv.h" | |||
#include "src/fallback/matrix_mul/generic_strategy.h" | |||
#include "midout.h" | |||
@@ -71,39 +72,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | |||
float); | |||
/* ===================== gemv algo ===================== */ | |||
namespace { | |||
template <typename itype, typename otype, bool have_zp = false> | |||
void gemm_gemv_like(const MatrixMulImpl::KernParam& kern_param) { | |||
const itype* A = kern_param.A<itype>(); | |||
const itype* B = kern_param.B<itype>(); | |||
uint8_t zp0, zp1; | |||
if (have_zp) { | |||
zp0 = kern_param.A_type.param<dtype::Quantized8Asymm>().zero_point; | |||
zp1 = kern_param.B_type.param<dtype::Quantized8Asymm>().zero_point; | |||
} | |||
otype* C = kern_param.C<otype>(); | |||
for (size_t m = 0; m < kern_param.M; ++m) { | |||
memset(C + m * kern_param.LDC, 0, sizeof(otype) * kern_param.N); | |||
for (size_t k = 0; k < kern_param.K; ++k) | |||
for (size_t n = 0; n < kern_param.N; ++n) { | |||
if (!have_zp) | |||
C[m * kern_param.LDC + n] += | |||
static_cast<otype>(A[m * kern_param.LDA + k]) * | |||
static_cast<otype>(B[k * kern_param.LDB + n]); | |||
else { | |||
C[m * kern_param.LDC + n] += | |||
(static_cast<otype>(A[m * kern_param.LDA + k]) - | |||
static_cast<otype>(zp0)) * | |||
(static_cast<otype>(B[k * kern_param.LDB + n]) - | |||
static_cast<otype>(zp1)); | |||
} | |||
} | |||
} | |||
} | |||
} // anonymous namespace | |||
bool MatrixMulImpl::AlgoGemv::usable( | |||
const KernSizeParam& kern_size_param) const { | |||
return !kern_size_param.trA && !kern_size_param.trB && | |||
@@ -0,0 +1,71 @@ | |||
/** | |||
* \file dnn/src/fallback/matrix_mul/gemv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
namespace megdnn { | |||
namespace fallback{ | |||
template <typename itype, typename otype> | |||
void gemv_like(const itype* A, const itype* B, otype* C, size_t M, size_t N, | |||
size_t K, size_t LDA, size_t LDB, size_t LDC) { | |||
for (size_t m = 0; m < M; ++m) { | |||
memset(C + m * LDC, 0, sizeof(otype) * N); | |||
for (size_t k = 0; k < K; ++k) | |||
for (size_t n = 0; n < N; ++n) { | |||
C[m * LDC + n] += static_cast<otype>(A[m * LDA + k]) * | |||
static_cast<otype>(B[k * LDB + n]); | |||
} | |||
} | |||
} | |||
template <typename itype, typename otype> | |||
void gemv_like(const itype* A, const itype* B, otype* C, size_t M, size_t N, | |||
size_t K, size_t LDA, size_t LDB, size_t LDC, uint8_t zp0, | |||
uint8_t zp1) { | |||
for (size_t m = 0; m < M; ++m) { | |||
memset(C + m * LDC, 0, sizeof(otype) * N); | |||
for (size_t k = 0; k < K; ++k) | |||
for (size_t n = 0; n < N; ++n) { | |||
C[m * LDC + n] += (static_cast<otype>(A[m * LDA + k]) - | |||
static_cast<otype>(zp0)) * | |||
(static_cast<otype>(B[k * LDB + n]) - | |||
static_cast<otype>(zp1)); | |||
} | |||
} | |||
} | |||
template <typename itype, typename otype, bool have_zp = false> | |||
void gemm_gemv_like(const MatrixMulImpl::KernParam& kern_param) { | |||
const itype* A = kern_param.A<itype>(); | |||
const itype* B = kern_param.B<itype>(); | |||
otype* C = kern_param.C<otype>(); | |||
size_t M = kern_param.M; | |||
size_t N = kern_param.N; | |||
size_t K = kern_param.K; | |||
size_t LDA = kern_param.LDA; | |||
size_t LDB = kern_param.LDB; | |||
size_t LDC = kern_param.LDC; | |||
    if (have_zp) { | |||
        uint8_t zp0 = | |||
                kern_param.A_type.param<dtype::Quantized8Asymm>().zero_point; | |||
        uint8_t zp1 = | |||
                kern_param.B_type.param<dtype::Quantized8Asymm>().zero_point; | |||
        gemv_like<itype, otype>(A, B, C, M, N, K, LDA, LDB, LDC, zp0, zp1); | |||
    } else { | |||
        gemv_like<itype, otype>(A, B, C, M, N, K, LDA, LDB, LDC); | |||
    } | |||
} | |||
} // namespace fallback | |||
} // namespace megdnn | |||
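For orientation, a minimal sketch of calling the new fallback kernel directly; the buffers, shapes, and main() driver below are illustrative only and assume gemv.h is reachable on the include path:

// hypothetical standalone driver for gemv_like: a 2x3 by 3x1 product with
// int32 accumulation, all strides contiguous
#include <cstdint>
#include <cstdio>
#include <vector>
#include "src/fallback/matrix_mul/gemv.h"

int main() {
    std::vector<int8_t> A{1, 2, 3, 4, 5, 6};  // M x K = 2x3, LDA = 3
    std::vector<int8_t> B{1, 1, 1};           // K x N = 3x1, LDB = 1
    std::vector<int32_t> C(2);                // M x N = 2x1, LDC = 1
    megdnn::fallback::gemv_like<int8_t, int32_t>(
            A.data(), B.data(), C.data(), 2, 1, 3, 3, 1, 1);
    std::printf("%d %d\n", C[0], C[1]);  // prints: 6 15
    return 0;
}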
@@ -1861,6 +1861,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) { | |||
#elif MEGDNN_ARMV7 | |||
check_conv_bias(args, handle(), "CONV1x1:ARMV7_F32:48"); | |||
#endif | |||
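    //! a 1x1 convolution on a 1x1 feature map degenerates to GEMV; collect | |||
    //! those cases and check the dedicated CONV1x1_GEMV algorithm on them | |||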
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) { | |||
@@ -1905,16 +1911,23 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) { | |||
dtype::Float16{}, dtype::Float16{}, dtype::Float16{}, | |||
"CONV1x1:AARCH32_F16_K4X16X1:24"); | |||
#endif | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
        if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV"); | |||
} | |||
#endif | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \ | |||
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
std::vector<conv_bias::TestArg> args = | |||
get_conv_bias_1x1_args(false, false, true, true); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \ | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \ | |||
dtype::QuantizedS8(60.25f), name); | |||
#if MEGDNN_AARCH64 | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -1928,17 +1941,27 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) { | |||
cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48"); | |||
#endif | |||
#undef cb | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, | |||
dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f), | |||
dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), | |||
"CONV1x1_GEMV"); | |||
} | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | |||
NormalRNG rng(128.f); | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_1x1_args(false, false, true, true), \ | |||
handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), \ | |||
UniformIntRNG rng{-50, 50}; | |||
std::vector<conv_bias::TestArg> args = | |||
get_conv_bias_1x1_args(false, false, true, true); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), \ | |||
dtype::Quantized8Asymm(50.3f, (uint8_t)120), name); | |||
float epsilon = 0.001; | |||
#if MEGDNN_AARCH64 | |||
@@ -1952,17 +1975,29 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) { | |||
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48"); | |||
#endif | |||
#undef cb | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), | |||
dtype::QuantizedS32(1.2 * 1.3), | |||
dtype::Quantized8Asymm(50.3f, (uint8_t)120), | |||
"CONV1x1_GEMV"); | |||
} | |||
#endif | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) { | |||
UniformIntRNG rng{-50, 50}; | |||
NormalRNG rng(128.f); | |||
float epsilon = 0.001; | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \ | |||
epsilon, dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, \ | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \ | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \ | |||
dtype::QuantizedS32(1.2 * 1.3), {}, name); | |||
#if MEGDNN_AARCH64 | |||
@@ -1978,15 +2013,25 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) { | |||
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24"); | |||
#endif | |||
#undef cb | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, | |||
dtype::Quantized8Asymm(1.2f, (uint8_t)125), | |||
dtype::Quantized8Asymm(1.3f, (uint8_t)129), | |||
dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
UniformIntRNG rng{-50, 50}; | |||
float epsilon = 0.001; | |||
#define cb(name) \ | |||
checker_conv_bias(get_conv_bias_1x1_args(true, true), handle(), &rng, \ | |||
epsilon, dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, \ | |||
dtype::Int16{}, name); | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true); | |||
#define cb(name) \ | |||
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \ | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name); | |||
#if MEGDNN_AARCH64 | |||
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24"); | |||
@@ -1997,6 +2042,16 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) { | |||
#endif | |||
cb("CONV1x1:ARM_COMMON_INT8X8X16:48"); | |||
#undef cb | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
checker_conv_bias(gemv_args, handle(), &rng, epsilon, dtype::Int8{}, | |||
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, | |||
"CONV1x1_GEMV"); | |||
} | |||
#endif | |||
@@ -2024,6 +2079,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) { | |||
cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48"); | |||
#endif | |||
#undef cb | |||
std::vector<conv_bias::TestArg> gemv_args; | |||
for (auto&& arg : args) | |||
if (arg.src.shape[2] == 1 && arg.src.shape[3] == 1) { | |||
gemv_args.emplace_back(arg); | |||
} | |||
checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV"); | |||
} | |||
#ifndef __ARM_FEATURE_DOTPROD | |||
@@ -254,6 +254,50 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV) { | |||
for (size_t N : {512, 1024}) | |||
run(M, K, N); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP32) { | |||
int exec_times = 50; | |||
Benchmarker<MatrixMul> benchmarker(handle()); | |||
benchmarker.set_times(exec_times); | |||
benchmarker.set_before_exec_callback( | |||
AlgoChecker<MatrixMul>("ARM_COMMON_F32_GEMV")); | |||
auto run = [&](size_t M, size_t K, size_t N) { | |||
std::cout << "SGEMV: (" << M << ", " << K << ", " << N << ")" | |||
<< std::endl; | |||
benchmarker.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Float32()); | |||
auto time = benchmarker.exec({{M, K}, {K, N}, {}}) / exec_times; | |||
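        //! 2 * M * K * N flops scaled by 1e-6 gives Mflop; exec() reports | |||
        //! time in ms, so Mflop / ms equals Gflop/s | |||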
auto computations = 2 * M * K * N * 1e-6; | |||
auto perf = computations / time; | |||
std::cout << "gemv fp32, Performance is " << perf << " Gflops" | |||
<< std::endl; | |||
}; | |||
std::cout << "warm up:\n"; | |||
for (int i = 0; i < 50; i++) { | |||
benchmarker.set_dtype(0, dtype::Float32()) | |||
.set_dtype(1, dtype::Float32()) | |||
.set_dtype(2, dtype::Float32()) | |||
.set_display(false) | |||
.exec({{2, 1024}, {1024, 512}, {}}); | |||
benchmarker.set_display(true); | |||
} | |||
    // run gemv: N == 1 keeps B a single column, so these shapes exercise | |||
    // the pure GEMV path across a range of M and K | |||
run(12, 48, 1); | |||
run(48, 12, 1); | |||
run(32, 128, 1); | |||
run(128, 32, 1); | |||
run(64, 256, 1); | |||
run(256, 64, 1); | |||
run(128, 512, 1); | |||
run(512, 128, 1); | |||
run(256, 1024, 1); | |||
run(1024, 256, 1); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | |||
int exec_times = 50; | |||
Benchmarker<MatrixMul> benchmarker(handle()); | |||
@@ -290,6 +334,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_FP16) { | |||
for (size_t N : {512, 1024}) | |||
run(M, K, N); | |||
} | |||
TEST_F(ARM_COMMON, BENCHMARK_SGEMM) { | |||
int exec_times = 10; | |||
Benchmarker<MatrixMul> benchmarker(handle()); | |||
@@ -1081,7 +1081,7 @@ std::vector<megdnn::test::conv_bias::TestArg> get_conv_bias_1x1_args( | |||
for (size_t n : {1, 2}) | |||
for (size_t oc : {1, 9, 33}) | |||
for (size_t ic : {1, 16, 64}) | |||
for (size_t size : {7, 14, 28}) | |||
for (size_t size : {1, 7, 14, 28}) | |||
for (auto nlmode : nonlinemode) | |||
for (auto convmode : convmodes) { | |||
pack(n, oc, ic, size, size, 1, nlmode, convmode); | |||
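(With size == 1 added, get_conv_bias_1x1_args also emits OH == OW == 1 cases, which are exactly the shapes the CONV1x1_GEMV checks above filter for.)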