GitOrigin-RevId: 191334bd96
tags/v1.0.0-rc1
@@ -116,9 +116,9 @@ if(${MGE_ARCH} STREQUAL "AUTO") | |||
endif() | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | |||
option(MGB_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON) | |||
if(MGB_ENABLE_CPUINFO) | |||
message("-- Enable cpuinfo runtime check.") | |||
option(MGE_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON) | |||
if(MGE_ENABLE_CPUINFO) | |||
message("-- Enable cpuinfo runtime check and little kernel optimize.") | |||
add_definitions(-DMGB_ENABLE_CPUINFO_CHECK) | |||
include(cmake/cpuinfo.cmake) | |||
endif() | |||
@@ -53,7 +53,7 @@ add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES}) | |||
target_link_libraries(megdnn PUBLIC opr_param_defs) | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | |||
if(MGB_ENABLE_CPUINFO) | |||
if(MGE_ENABLE_CPUINFO) | |||
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) | |||
endif() | |||
endif() | |||
@@ -49,13 +49,13 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||
return wbundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | |||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -58,6 +58,7 @@ size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||
@@ -118,6 +119,7 @@ size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern( | |||
@@ -177,6 +179,7 @@ size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( | |||
@@ -237,6 +240,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern( | |||
@@ -313,6 +317,7 @@ size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( | |||
@@ -352,6 +357,7 @@ size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( | |||
@@ -431,6 +437,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern( | |||
@@ -501,6 +508,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern( | |||
@@ -573,6 +581,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern( | |||
@@ -635,6 +644,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern( | |||
@@ -696,6 +706,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern( | |||
@@ -762,6 +773,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern( | |||
@@ -828,6 +840,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern( | |||
@@ -905,6 +918,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | |||
@@ -981,6 +995,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( | |||
@@ -1051,6 +1066,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern( | |||
@@ -1092,6 +1108,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( | |||
@@ -1172,6 +1189,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( | |||
@@ -1277,6 +1295,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern( | |||
@@ -160,7 +160,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1, | |||
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -199,7 +199,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, | |||
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -33,29 +33,39 @@ bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, | |||
return direct_int8_stride1::can_conv_direct_stride1_int8(param); | |||
} | |||
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | |||
const NCBKernSizeParam& param) const { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
auto IC = fm.icpg; | |||
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || | |||
((FH == 3 || FH == 5 || FH == 7) && | |||
(OC <= 16 || (IC <= 4 && OC <= 32)))) && | |||
param.bias_mode != BiasMode::BIAS; | |||
return preferred; | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
auto OC = fm.ocpg; | |||
auto IC = fm.icpg; | |||
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || | |||
((FH == 3 || FH == 5 || FH == 7) && | |||
(OC <= 16 || (IC <= 4 && OC <= 32)))) && | |||
param.bias_mode != BiasMode::BIAS; | |||
return preferred; | |||
} | |||
MIDOUT_END(); | |||
} | |||
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_int8_stride1::get_kimpls(param, large_group); | |||
} | |||
@@ -72,15 +82,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable( | |||
size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = channel_wise_nchw44::stride1::get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8ChanWiseStride1NCHW44::get_workspace"_hash)) { | |||
auto bundle = channel_wise_nchw44::stride1::get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8ChanWiseStride1NCHW44"_hash)) { | |||
midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) { | |||
return channel_wise_nchw44::stride1::get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -96,15 +111,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable( | |||
size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
auto bundle = channel_wise_nchw44::stride2::get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) { | |||
auto bundle = channel_wise_nchw44::stride2::get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8ChanWiseStride2NCHW44"_hash)) { | |||
midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) { | |||
return channel_wise_nchw44::stride2::get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -119,15 +139,21 @@ bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_int8_stride2::get_kimpls(param, large_group); | |||
} | |||
@@ -144,15 +170,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_int8_stride1::get_kimpls(param, large_group); | |||
} | |||
@@ -168,15 +200,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_int8_stride2::get_kimpls(param, large_group); | |||
} | |||
@@ -188,37 +226,45 @@ ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | |||
/* ======================= AlgoS8WinogradF23_8x8 ======================== */ | |||
bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( | |||
const NCBKernSizeParam& param, | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_2x3_8x8_s8; | |||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>( | |||
strategy, m_tile_size, param) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
m_matmul_algo->packmode() == PackMode::NO_PACK && | |||
((param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8) || | |||
(param.filter_meta.format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
param.output_block_size == 2 && | |||
param.winograd_matmul_format == param::MatrixMul::Format::MK8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) && | |||
!param.filter_meta.should_flip && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) { | |||
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) | |||
return false; | |||
using Strategy = winograd::winograd_2x3_8x8_s8; | |||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||
auto&& matmul_param = | |||
megdnn::winograd::ConvBias<Strategy, | |||
param::MatrixMul::Format::MK8>( | |||
strategy, m_tile_size, param) | |||
.get_matmul_kern_param(param); | |||
return m_matmul_algo->usable(matmul_param) && | |||
m_matmul_algo->packmode() == PackMode::NO_PACK && | |||
((param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS8) || | |||
(param.filter_meta.format == | |||
param::ConvBias::Format::NCHW_WINOGRAD && | |||
param.output_block_size == 2 && | |||
param.winograd_matmul_format == | |||
param::MatrixMul::Format::MK8 && | |||
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) && | |||
!param.filter_meta.should_flip && | |||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||
param.filter_meta.spatial[0] == 3) && | |||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||
param.filter_meta.stride[0] == 1) && | |||
(param.filter_meta.dilation[0] == | |||
param.filter_meta.dilation[1] && | |||
param.filter_meta.dilation[0] == 1) && | |||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||
param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||
} | |||
MIDOUT_END(); | |||
return false; | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8, | |||
@@ -202,14 +202,19 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | |||
size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||
midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) { | |||
midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) { | |||
auto fm = param.filter_meta; | |||
size_t BATCH = param.n; | |||
size_t GROUP = fm.group; | |||
@@ -223,7 +223,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( | |||
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, | |||
midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -232,7 +232,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( | |||
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, | |||
midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -183,7 +183,12 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, | |||
midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -83,7 +83,8 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, | |||
/* ===================== direct algo ===================== */ | |||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Direct::usable"_hash)) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.bias_mode == BiasMode::NO_BIAS && | |||
@@ -122,7 +123,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( | |||
} | |||
size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) { | |||
auto bundle = get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -287,7 +289,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 2) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -297,7 +300,8 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | |||
/* ===================== stride-2 algo ===================== */ | |||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2::usable"_hash)) { | |||
auto&& fm = param.filter_meta; | |||
auto FH = fm.spatial[0]; | |||
return param.bias_mode == BiasMode::NO_BIAS && | |||
@@ -337,7 +341,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( | |||
} | |||
size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) { | |||
auto bundle = get_bundle(param); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
@@ -501,7 +506,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 2) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) { | |||
return get_kimpls(param); | |||
} | |||
MIDOUT_END(); | |||
@@ -510,7 +516,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( | |||
bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( | |||
const NCBKernSizeParam& param, | |||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) { | |||
return param.bias_mode == BiasMode::NO_BIAS && | |||
param.nonlineMode == NonlineMode::IDENTITY && | |||
param.nr_threads == 1_z && | |||
@@ -522,7 +529,8 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( | |||
size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) { | |||
return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( | |||
param); | |||
} | |||
@@ -535,7 +543,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
// return {conv_bias::conv_int8x8x16_stride2_flt2,true}; | |||
auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 2) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) { | |||
auto ncb_param = param; | |||
ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]); | |||
ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]); | |||
@@ -573,18 +582,25 @@ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable( | |||
size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
size_t stride_h = param.filter_meta.stride[0]; | |||
size_t stride_w = param.filter_meta.stride[1]; | |||
megdnn_assert(stride_h == stride_w); | |||
if (stride_h == 1) { | |||
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param) | |||
.total_size_in_bytes(); | |||
} else if (stride_h == 2) { | |||
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param) | |||
.total_size_in_bytes(); | |||
} else { | |||
return 0; | |||
MIDOUT_BEGIN( | |||
megdnn_arm_common_conv_bias_int8816_kimpl, | |||
midout_iv( | |||
"AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) { | |||
size_t stride_h = param.filter_meta.stride[0]; | |||
size_t stride_w = param.filter_meta.stride[1]; | |||
megdnn_assert(stride_h == stride_w); | |||
if (stride_h == 1) { | |||
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param) | |||
.total_size_in_bytes(); | |||
} else if (stride_h == 2) { | |||
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param) | |||
.total_size_in_bytes(); | |||
} else { | |||
return 0; | |||
} | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -13,11 +13,8 @@ | |||
#include "src/common/utils.h" | |||
#include <cstring> | |||
#include "midout.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_filter) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
using namespace conv_bias; | |||
@@ -12,9 +12,7 @@ | |||
#include "src/common/utils.h" | |||
#include <cstring> | |||
#include "midout.h" | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
MIDOUT_DECL(megdnn_arm_common_conv_bias_s2_filter) | |||
#pragma GCC diagnostic ignored "-Wunused-parameter" | |||
@@ -229,7 +229,12 @@ bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( | |||
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
return get_bundle(param).total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, | |||
midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) { | |||
return get_bundle(param).total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -32,15 +32,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_quint8_stride1::get_kimpls(param, large_group); | |||
} | |||
@@ -57,15 +63,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride2::usable( | |||
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | |||
const NCBKernSizeParam& param) const { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||
return bundle.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_quint8_stride2::get_kimpls(param, large_group); | |||
} | |||
@@ -81,15 +93,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace(
        const NCBKernSizeParam& param) const {
    // Workspace query for the dot-product quint8 stride-1 direct conv.
    // Tag fixed: the pasted "AlgoQU8DirectStride1" tag belonged to the
    // non-dot algo and made the two indistinguishable in midout traces.
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
                 midout_iv("AlgoDotU8DirectStride1::get_workspace"_hash)) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        auto bundle =
                direct_dotprod_quint8_stride1::get_bundle(param, large_group);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); | |||
} | |||
@@ -105,15 +123,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | |||
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace(
        const NCBKernSizeParam& param) const {
    // Workspace query for the dot-product quint8 stride-2 direct conv.
    // Tag fixed from the copy-pasted "AlgoQU8DirectStride1".
    MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
                 midout_iv("AlgoDotU8DirectStride2::get_workspace"_hash)) {
        bool large_group = param.filter_meta.group >= param.nr_threads;
        auto bundle =
                direct_dotprod_quint8_stride2::get_bundle(param, large_group);
        return bundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
SmallVector<ConvBiasImpl::NCBKern> | |||
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | |||
const NCBKernSizeParam& param) const { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { | |||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); | |||
} | |||
@@ -32,13 +32,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Removed the leftover pre-refactor early return that made the
    // midout-instrumented path below unreachable.
    MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
                 midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) {
        return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    // Leftover duplicate `return` before MIDOUT_BEGIN removed; the kernel
    // selection is now actually covered by the midout tag.
    MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
                 midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) {
        return deconv::stride1_int8x8x32_dot;
    }
    MIDOUT_END();
    return {};
}
/* ===================== direct stride 2 algo ===================== */ | |||
@@ -49,13 +59,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Removed the stale pre-refactor early return that shadowed the
    // instrumented path.
    MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
                 midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) {
        return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    // Stale duplicate `return` before MIDOUT_BEGIN removed.
    MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
                 midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) {
        return deconv::stride2_int8x8x32_dot;
    }
    MIDOUT_END();
    return {};
}
#endif | |||
@@ -33,13 +33,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Removed the stale pre-refactor early return so the midout-tagged path
    // is reached.
    MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
                 midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) {
        return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    // Stale duplicate `return` before MIDOUT_BEGIN removed.
    MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
                 midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) {
        return deconv::stride1_quint8_dot;
    }
    MIDOUT_END();
    return {};
}
/* ===================== direct stride 2 algo ===================== */ | |||
@@ -50,13 +60,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Removed the stale pre-refactor early return.
    MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
                 midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) {
        return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    // Stale duplicate `return` before MIDOUT_BEGIN removed.
    MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
                 midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) {
        return deconv::stride2_quint8_dot;
    }
    MIDOUT_END();
    return {};
}
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -18,6 +18,8 @@ | |||
MIDOUT_DECL(megdnn_arm_hgemv) | |||
MIDOUT_DECL(megdnn_arm_exec_int8816) | |||
MIDOUT_DECL(megdnn_arm_exec_int8832) | |||
MIDOUT_DECL(megdnn_arm_exec_fp32) | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
@@ -63,8 +65,13 @@ bool MatrixMulImpl::AlgoInt8x8x16::usable( | |||
size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace(
        const KernSizeParam& kern_size_param) const {
    // Removed the stale pre-refactor body that returned before the
    // midout-instrumented block could run.
    MIDOUT_BEGIN(megdnn_arm_exec_int8816,
                 midout_iv("AlgoInt8x8x16::get_workspace"_hash)) {
        auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
        return wbundle.total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | |||
@@ -75,11 +82,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | |||
/* ===================== Int8x8x32 Gemv algo ===================== */ | |||
namespace { | |||
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
    // int8 x int8 -> int32 GEMV; duplicate pre-refactor body removed so the
    // kernel runs exactly once, inside its midout tag.
    MIDOUT_BEGIN(megdnn_arm_exec_int8832,
                 midout_iv("int8x8x32_gemv_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||
@@ -104,11 +115,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( | |||
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */ | |||
namespace { | |||
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
    // MK4-layout int8 GEMV; duplicate pre-refactor body removed.
    MIDOUT_BEGIN(megdnn_arm_exec_int8832,
                 midout_iv("int8x8x32_gemv_mk4_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||
@@ -147,11 +162,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | |||
namespace { | |||
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
    // MK4_DOT-layout int8 GEMV; duplicate pre-refactor body removed.
    MIDOUT_BEGIN(megdnn_arm_exec_int8832,
                 midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_int8>(),
                   Bptr = kern_param.B<dt_int8>();
        auto Cptr = kern_param.C<dt_int32>();
        gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||
@@ -189,12 +208,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern( | |||
/* ===================== F32 Gemv algo ===================== */ | |||
namespace { | |||
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
    // float32 GEMV; duplicate pre-refactor body removed.
    MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_float32>(),
                   Bptr = kern_param.B<dt_float32>();
        auto Cptr = kern_param.C<dt_float32>();
        gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||
@@ -225,12 +248,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||
/* ================== F32 Gemv MK4 algo ================== */ | |||
namespace { | |||
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
    // MK4-layout float32 GEMV; duplicate pre-refactor body removed.
    MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) {
        auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
        auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
        const auto Aptr = kern_param.A<dt_float32>(),
                   Bptr = kern_param.B<dt_float32>();
        auto Cptr = kern_param.C<dt_float32>();
        gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
    }
    MIDOUT_END();
}
} // anonymous namespace | |||
@@ -266,11 +293,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||
namespace { | |||
template <typename stype, typename dtype> | |||
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDB = kern_param.LDB; | |||
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>(); | |||
auto Cptr = kern_param.C<dtype>(); | |||
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, | |||
midout_iv("gevm_like_kern"_hash)) { | |||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||
auto LDB = kern_param.LDB; | |||
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>(); | |||
auto Cptr = kern_param.C<dtype>(); | |||
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||
} | |||
MIDOUT_END(); | |||
} | |||
} // anonymous namespace | |||
@@ -75,6 +75,7 @@ size_t MatrixMulImpl::AlgoF32::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( | |||
@@ -141,6 +142,7 @@ size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern( | |||
@@ -202,6 +204,7 @@ size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( | |||
@@ -265,6 +268,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern( | |||
@@ -326,6 +330,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern( | |||
@@ -386,6 +391,7 @@ size_t MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( | |||
@@ -445,6 +451,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern( | |||
@@ -510,6 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern( | |||
@@ -577,6 +585,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern( | |||
@@ -642,6 +651,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern( | |||
@@ -702,6 +712,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern( | |||
@@ -764,6 +775,7 @@ size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern( | |||
@@ -830,6 +842,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( | |||
@@ -894,6 +907,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern( | |||
@@ -929,6 +943,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( | |||
@@ -986,6 +1001,7 @@ size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( | |||
@@ -1066,6 +1082,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace( | |||
.get_workspace_size(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern( | |||
@@ -261,6 +261,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( | |||
.total_size_in_bytes(); | |||
} | |||
MIDOUT_END(); | |||
return 0; | |||
} | |||
SmallVector<ConvBiasImpl::NCBKern> | |||
@@ -175,44 +175,54 @@ bool ConvolutionImpl::AlgoFallback::usable( | |||
} | |||
size_t ConvolutionImpl::AlgoFallback::get_workspace(
        const NCBKernSizeParam& param) const {
    // Per-thread workspace: one flipped-filter scratch buffer per worker is
    // only needed when the filter must be transposed (convolution vs
    // cross-correlation). Duplicate signature/body lines from the
    // pre-refactor version removed.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoFallback::get_workspace"_hash)) {
        auto FH = param.filter_meta.spatial[0],
             FW = param.filter_meta.spatial[1];
        size_t nr_threads = param.nr_threads;
        if (param.filter_meta.should_flip) {
            // need transpose filter
            return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
                           .total_size_in_bytes() *
                   nr_threads;
        } else {
            return 0;
        }
    }
    MIDOUT_END();
    return 0;
}
SmallVector<ConvolutionImpl::NCBKern>
ConvolutionImpl::AlgoFallback::dispatch_kern(
        const NCBKernSizeParam& param) const {
    // Builds one naive per-(group, batch) conv kernel; each worker thread
    // addresses its own slice of the workspace. Duplicate pre-refactor body
    // removed and a fall-through `return {}` added after MIDOUT_END so the
    // function returns on every path when midout pruning is active.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoFallback::dispatch_kern"_hash)) {
        size_t group = param.filter_meta.group;
        size_t N = param.n;
        size_t nr_threads = param.nr_threads;
        size_t workspace_per_thread = get_workspace(param) / nr_threads;
        auto kern_fallback = [workspace_per_thread](
                                     const NCBKernParam& p,
                                     const NCBKernIndex& ncb_index) {
            UNPACK_CONV_F32_NCB_KERN_SIZES(p);
            size_t batch_id = ncb_index.ndrange_id[1];
            size_t group_id = ncb_index.ndrange_id[0];
            MEGDNN_MARK_USED_VAR(N);
            auto src = p.src<float>(batch_id, group_id),
                 filter = p.filter<float>(group_id);
            auto dst = p.dst<float>(batch_id, group_id);
            size_t thread_id = ncb_index.thread_id;
            // offset into this thread's private workspace slice
            void* workspace_ptr = reinterpret_cast<void*>(
                    reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
                    workspace_per_thread * thread_id);
            convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC,
                                  FH, FW, OH, OW, OC, PH, PW, SH, SW,
                                  !p.filter_meta.should_flip);
        };
        return {{kern_fallback, {group, N, 1_z}}};
    }
    MIDOUT_END();
    return {};
}
/* ===================== naive algo ===================== */ | |||
@@ -339,22 +349,36 @@ WorkspaceBundle ConvolutionImpl::AlgoDefault::get_bundle( | |||
size_t ConvolutionImpl::AlgoDefault::get_workspace(
        const NCBKernSizeParam& param) const {
    // Removed the stale pre-refactor early return.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoDefault::get_workspace"_hash)) {
        return get_bundle(param).total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
size_t ConvolutionImpl::AlgoDefault::get_preprocess_workspace(
        const NCBKernSizeParam& param) const {
    // Delegates to the wrapped conv-bias algorithm. A fall-through
    // `return 0` is added after MIDOUT_END for consistency with every other
    // get_workspace in this change set (avoids a missing-return path when
    // the midout region is compiled out).
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoDefault::get_preprocess_workspace"_hash)) {
        ::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
                init_conv_bias_param(param);
        return m_algorithm->get_preprocess_workspace(conv_bias_param);
    }
    MIDOUT_END();
    return 0;
}
SmallVector<TensorLayout>
ConvolutionImpl::AlgoDefault::deduce_preprocessed_filter_layout(
        const NCBKernSizeParam& param) const {
    // Delegates to the wrapped conv-bias algorithm. Duplicate pre-refactor
    // lines removed; `return {}` added after MIDOUT_END so every path
    // returns when the midout region is compiled out.
    MIDOUT_BEGIN(
            megdnn_fallback_conv,
            midout_iv("AlgoDefault::deduce_preprocessed_filter_layout"_hash)) {
        ::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
                init_conv_bias_param(param);
        return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param);
    }
    MIDOUT_END();
    return {};
}
//! Return the implement preprocess kernel | |||
@@ -450,19 +474,29 @@ bool ConvolutionBackwardDataImpl::AlgoDirect::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Scratch for one transposed filter, only when flipping is required.
    // Duplicate pre-refactor body removed.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoDirect::get_workspace"_hash)) {
        auto FH = param.filter_meta.spatial[0],
             FW = param.filter_meta.spatial[1];
        if (param.filter_meta.should_flip) {
            // need transpose filter
            return FH * FW * sizeof(float);
        } else {
            return 0;
        }
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
    // Duplicate pre-refactor `return` removed; `return {}` added after
    // MIDOUT_END so the function returns on every path, matching the other
    // dispatch_kern implementations in this change set.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoDirect::dispatch_kern"_hash)) {
        return kern_direct;
    }
    MIDOUT_END();
    return {};
}
/* ===================== Matrix mul algo ===================== */ | |||
@@ -477,35 +511,48 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable( | |||
size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
    // Removed the stale pre-refactor early return.
    MIDOUT_BEGIN(megdnn_fallback_conv,
                 midout_iv("AlgoMatrixMul::get_workspace"_hash)) {
        return get_bundle(param).total_size_in_bytes();
    }
    MIDOUT_END();
    return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern(
        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
// Dispatch on dtype; each branch carries its own midout tag. The old
// two-argument `cb` macro and the MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT
// invocation from the pre-refactor version are removed — only the tagged
// form remains.
#define cb(dt, midout_tag)                                              \
    do {                                                                \
        if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) {       \
            MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
                using ctype = DTypeTrait<dt>::ctype;                    \
                return kern_matmul<ctype, ctype, ctype>;                \
            }                                                           \
            MIDOUT_END();                                               \
        }                                                               \
    } while (0);
    cb(dtype::Float32, "FLOAT"_hash);
    MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash));
    MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash));
#undef cb
#define cb(dt_src, dt_dst, midout_tag)                                  \
    do {                                                                \
        if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv &&     \
            param.filter_type.enumv() == DTypeTrait<dt_src>::enumv &&   \
            param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) {     \
            MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
                return kern_matmul<DTypeTrait<dt_src>::ctype,           \
                                   DTypeTrait<dt_src>::ctype,           \
                                   DTypeTrait<dt_dst>::ctype>;          \
            }                                                           \
            MIDOUT_END();                                               \
        }                                                               \
    } while (0)
    cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash);
    cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash);
    cb(dtype::Quantized8Asymm, dtype::QuantizedS32, "QUINT8x8x32"_hash);
    megdnn_throw("unsupported data type on matrix mul");
#undef cb
}
@@ -24,7 +24,6 @@ | |||
#include <cstring> | |||
MIDOUT_DECL(megdnn_fb_conv_float) | |||
MIDOUT_DECL(megdnn_fb_convbwd_float) | |||
using namespace megdnn; | |||
@@ -53,13 +53,20 @@ bool MatrixMulImpl::AlgoF32K8x12x1::usable( | |||
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
        const KernSizeParam& kern_size_param) const {
    // Workspace for the interleaved 8x12 sgemm; duplicate pre-refactor body
    // removed so only the midout-instrumented path remains.
    MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
                 midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
        auto M = kern_size_param.M, N = kern_size_param.N,
             K = kern_size_param.K;
        matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
                                              kern_size_param.B_type,
                                              kern_size_param.C_type);
        return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
                       M, N, K, kern_size_param.trA, kern_size_param.trB,
                       strategy)
                .get_workspace_size();
    }
    MIDOUT_END();
    return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||
@@ -23,10 +23,16 @@ namespace naive { | |||
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
                                                    const TensorLayout& B,
                                                    const TensorLayout&) {
    // Quantized4Asymm inputs are unpacked to uint8 scratch buffers first,
    // hence the extra workspace; all other dtypes need none. Duplicate
    // pre-refactor lines removed and a fall-through `return 0` added after
    // MIDOUT_END so every path returns when the region is compiled out.
    MIDOUT_BEGIN(
            megdnn_naive_matmul,
            midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
        if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
            return (A.span().dist_elem() + B.span().dist_elem()) *
                   sizeof(uint8_t);
        }
        return 0;
    }
    MIDOUT_END();
    return 0;
}
template <bool TA, bool TB> | |||
@@ -127,7 +133,8 @@ void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, | |||
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, | |||
_megdnn_workspace workspace) { | |||
MIDOUT_BEGIN(megdnn_naive_matmul) { | |||
MIDOUT_BEGIN(megdnn_naive_matmul, | |||
midout_iv("MatrixMulForwardImpl::exec"_hash)) { | |||
check_exec(A.layout, B.layout, C.layout, workspace.size); | |||
auto p = param(); | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p)); | |||
@@ -17,6 +17,7 @@ | |||
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL | |||
#cmakedefine01 MGB_ENABLE_LOGGING | |||
#cmakedefine01 MGB_ENABLE_GRAD | |||
#cmakedefine01 MGB_ENABLE_CPUINFO | |||
#cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME | |||
#cmakedefine01 MGB_BUILD_SLIM_SERVING | |||
#cmakedefine01 MGB_ENABLE_EXCEPTION | |||
@@ -80,6 +81,16 @@ | |||
#define MGB_ENABLE_GRAD 1 | |||
#endif | |||
// whether to enable cpuinfo | |||
#ifndef MGB_ENABLE_CPUINFO | |||
#define MGB_ENABLE_CPUINFO 1 | |||
#endif | |||
#ifdef IOS | |||
#undef MGB_ENABLE_CPUINFO | |||
#define MGB_ENABLE_CPUINFO 0 | |||
#endif | |||
// whether to include actual class name in mgb::Typeinfo object; if this is | |||
// disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work. | |||
#ifndef MGB_VERBOSE_TYPEINFO_NAME | |||