GitOrigin-RevId: 191334bd96
tags/v1.0.0-rc1
@@ -116,9 +116,9 @@ if(${MGE_ARCH} STREQUAL "AUTO") | |||||
endif() | endif() | ||||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | ||||
option(MGB_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON) | |||||
if(MGB_ENABLE_CPUINFO) | |||||
message("-- Enable cpuinfo runtime check.") | |||||
option(MGE_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON) | |||||
if(MGE_ENABLE_CPUINFO) | |||||
message("-- Enable cpuinfo runtime check and little kernel optimize.") | |||||
add_definitions(-DMGB_ENABLE_CPUINFO_CHECK) | add_definitions(-DMGB_ENABLE_CPUINFO_CHECK) | ||||
include(cmake/cpuinfo.cmake) | include(cmake/cpuinfo.cmake) | ||||
endif() | endif() | ||||
@@ -53,7 +53,7 @@ add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES}) | |||||
target_link_libraries(megdnn PUBLIC opr_param_defs) | target_link_libraries(megdnn PUBLIC opr_param_defs) | ||||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | ||||
if(MGB_ENABLE_CPUINFO) | |||||
if(MGE_ENABLE_CPUINFO) | |||||
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) | target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) | ||||
endif() | endif() | ||||
endif() | endif() | ||||
@@ -49,13 +49,13 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||||
return wbundle.total_size_in_bytes(); | return wbundle.total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return false; | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | |||||
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -58,6 +58,7 @@ size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | ||||
@@ -118,6 +119,7 @@ size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern( | ||||
@@ -177,6 +179,7 @@ size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( | ||||
@@ -237,6 +240,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern( | ||||
@@ -313,6 +317,7 @@ size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( | ||||
@@ -352,6 +357,7 @@ size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( | ||||
@@ -431,6 +437,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern( | ||||
@@ -501,6 +508,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern( | ||||
@@ -573,6 +581,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern( | ||||
@@ -635,6 +644,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern( | ||||
@@ -696,6 +706,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern( | ||||
@@ -762,6 +773,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern( | ||||
@@ -828,6 +840,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern( | ||||
@@ -905,6 +918,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | ||||
@@ -981,6 +995,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( | ||||
@@ -1051,6 +1066,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern( | ||||
@@ -1092,6 +1108,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( | ||||
@@ -1172,6 +1189,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( | ||||
@@ -1277,6 +1295,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern( | ||||
@@ -160,7 +160,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1, | |||||
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -199,7 +199,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, | |||||
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -33,29 +33,39 @@ bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
return direct_int8_stride1::can_conv_direct_stride1_int8(param); | return direct_int8_stride1::can_conv_direct_stride1_int8(param); | ||||
} | } | ||||
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( | ||||
const NCBKernSizeParam& param) const { | |||||
auto&& fm = param.filter_meta; | |||||
auto FH = fm.spatial[0]; | |||||
auto OC = fm.ocpg; | |||||
auto IC = fm.icpg; | |||||
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || | |||||
((FH == 3 || FH == 5 || FH == 7) && | |||||
(OC <= 16 || (IC <= 4 && OC <= 32)))) && | |||||
param.bias_mode != BiasMode::BIAS; | |||||
return preferred; | |||||
const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) { | |||||
auto&& fm = param.filter_meta; | |||||
auto FH = fm.spatial[0]; | |||||
auto OC = fm.ocpg; | |||||
auto IC = fm.icpg; | |||||
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || | |||||
((FH == 3 || FH == 5 || FH == 7) && | |||||
(OC <= 16 || (IC <= 4 && OC <= 32)))) && | |||||
param.bias_mode != BiasMode::BIAS; | |||||
return preferred; | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_int8_stride1::get_kimpls(param, large_group); | return direct_int8_stride1::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -72,15 +82,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( | size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = channel_wise_nchw44::stride1::get_bundle(param); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8ChanWiseStride1NCHW44::get_workspace"_hash)) { | |||||
auto bundle = channel_wise_nchw44::stride1::get_bundle(param); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( | ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | ||||
midout_iv("AlgoS8ChanWiseStride1NCHW44"_hash)) { | |||||
midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) { | |||||
return channel_wise_nchw44::stride1::get_kimpls(param); | return channel_wise_nchw44::stride1::get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -96,15 +111,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( | size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
auto bundle = channel_wise_nchw44::stride2::get_bundle(param); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) { | |||||
auto bundle = channel_wise_nchw44::stride2::get_bundle(param); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( | ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | ||||
midout_iv("AlgoS8ChanWiseStride2NCHW44"_hash)) { | |||||
midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) { | |||||
return channel_wise_nchw44::stride2::get_kimpls(param); | return channel_wise_nchw44::stride2::get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -119,15 +139,21 @@ bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_int8_stride2::get_kimpls(param, large_group); | return direct_int8_stride2::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -144,15 +170,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_dotprod_int8_stride1::get_kimpls(param, large_group); | return direct_dotprod_int8_stride1::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -168,15 +200,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_dotprod_int8_stride2::get_kimpls(param, large_group); | return direct_dotprod_int8_stride2::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -188,37 +226,45 @@ ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( | |||||
/* ======================= AlgoS8WinogradF23_8x8 ======================== */ | /* ======================= AlgoS8WinogradF23_8x8 ======================== */ | ||||
bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( | bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( | ||||
const NCBKernSizeParam& param, | |||||
const NCBKernSizeParam& param, | |||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) | |||||
return false; | |||||
using Strategy = winograd::winograd_2x3_8x8_s8; | |||||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||||
auto&& matmul_param = | |||||
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>( | |||||
strategy, m_tile_size, param) | |||||
.get_matmul_kern_param(param); | |||||
return m_matmul_algo->usable(matmul_param) && | |||||
m_matmul_algo->packmode() == PackMode::NO_PACK && | |||||
((param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8) || | |||||
(param.filter_meta.format == | |||||
param::ConvBias::Format::NCHW_WINOGRAD && | |||||
param.output_block_size == 2 && | |||||
param.winograd_matmul_format == param::MatrixMul::Format::MK8 && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) && | |||||
!param.filter_meta.should_flip && | |||||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||||
param.filter_meta.spatial[0] == 3) && | |||||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||||
param.filter_meta.stride[0] == 1) && | |||||
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && | |||||
param.filter_meta.dilation[0] == 1) && | |||||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||||
param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) { | |||||
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) | |||||
return false; | |||||
using Strategy = winograd::winograd_2x3_8x8_s8; | |||||
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; | |||||
Strategy strategy(param.src_type, param.filter_type, param.dst_type); | |||||
auto&& matmul_param = | |||||
megdnn::winograd::ConvBias<Strategy, | |||||
param::MatrixMul::Format::MK8>( | |||||
strategy, m_tile_size, param) | |||||
.get_matmul_kern_param(param); | |||||
return m_matmul_algo->usable(matmul_param) && | |||||
m_matmul_algo->packmode() == PackMode::NO_PACK && | |||||
((param.filter_meta.format == param::ConvBias::Format::NCHW && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS8) || | |||||
(param.filter_meta.format == | |||||
param::ConvBias::Format::NCHW_WINOGRAD && | |||||
param.output_block_size == 2 && | |||||
param.winograd_matmul_format == | |||||
param::MatrixMul::Format::MK8 && | |||||
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) && | |||||
!param.filter_meta.should_flip && | |||||
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && | |||||
param.filter_meta.spatial[0] == 3) && | |||||
(param.filter_meta.stride[0] == param.filter_meta.stride[1] && | |||||
param.filter_meta.stride[0] == 1) && | |||||
(param.filter_meta.dilation[0] == | |||||
param.filter_meta.dilation[1] && | |||||
param.filter_meta.dilation[0] == 1) && | |||||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && | |||||
param.src_type.enumv() == DTypeEnum::QuantizedS8 && | |||||
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS8; | |||||
} | |||||
MIDOUT_END(); | |||||
return false; | |||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8, | MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8, | ||||
@@ -202,14 +202,19 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( | |||||
size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( | size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | |||||
midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, | ||||
midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) { | |||||
midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) { | |||||
auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
size_t BATCH = param.n; | size_t BATCH = param.n; | ||||
size_t GROUP = fm.group; | size_t GROUP = fm.group; | ||||
@@ -223,7 +223,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( | |||||
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, | |||||
midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -232,7 +232,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( | |||||
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, | |||||
midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -183,7 +183,12 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, | |||||
midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -83,7 +83,8 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, | |||||
/* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Direct::usable"_hash)) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.bias_mode == BiasMode::NO_BIAS && | return param.bias_mode == BiasMode::NO_BIAS && | ||||
@@ -122,7 +123,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( | size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) { | |||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -287,7 +289,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( | |||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 2) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -297,7 +300,8 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( | |||||
/* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, | bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const { | AlgoSelectionStrategy) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2::usable"_hash)) { | |||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
return param.bias_mode == BiasMode::NO_BIAS && | return param.bias_mode == BiasMode::NO_BIAS && | ||||
@@ -337,7 +341,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( | |||||
} | } | ||||
size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( | size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) { | |||||
auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
return bundle.total_size_in_bytes(); | return bundle.total_size_in_bytes(); | ||||
} | } | ||||
@@ -501,7 +506,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( | |||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( | ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 2) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) { | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
@@ -510,7 +516,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( | |||||
bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( | bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( | ||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy /*algo_selection_strategy*/) const { | AlgoSelectionStrategy /*algo_selection_strategy*/) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) { | |||||
return param.bias_mode == BiasMode::NO_BIAS && | return param.bias_mode == BiasMode::NO_BIAS && | ||||
param.nonlineMode == NonlineMode::IDENTITY && | param.nonlineMode == NonlineMode::IDENTITY && | ||||
param.nr_threads == 1_z && | param.nr_threads == 1_z && | ||||
@@ -522,7 +529,8 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( | |||||
size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace( | size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) { | |||||
return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( | return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( | ||||
param); | param); | ||||
} | } | ||||
@@ -535,7 +543,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns( | |||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
// return {conv_bias::conv_int8x8x16_stride2_flt2,true}; | // return {conv_bias::conv_int8x8x16_stride2_flt2,true}; | ||||
auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) { | auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 2) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) { | |||||
auto ncb_param = param; | auto ncb_param = param; | ||||
ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]); | ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]); | ||||
ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]); | ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]); | ||||
@@ -573,18 +582,25 @@ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace( | size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
size_t stride_h = param.filter_meta.stride[0]; | |||||
size_t stride_w = param.filter_meta.stride[1]; | |||||
megdnn_assert(stride_h == stride_w); | |||||
if (stride_h == 1) { | |||||
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param) | |||||
.total_size_in_bytes(); | |||||
} else if (stride_h == 2) { | |||||
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param) | |||||
.total_size_in_bytes(); | |||||
} else { | |||||
return 0; | |||||
MIDOUT_BEGIN( | |||||
megdnn_arm_common_conv_bias_int8816_kimpl, | |||||
midout_iv( | |||||
"AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) { | |||||
size_t stride_h = param.filter_meta.stride[0]; | |||||
size_t stride_w = param.filter_meta.stride[1]; | |||||
megdnn_assert(stride_h == stride_w); | |||||
if (stride_h == 1) { | |||||
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param) | |||||
.total_size_in_bytes(); | |||||
} else if (stride_h == 2) { | |||||
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param) | |||||
.total_size_in_bytes(); | |||||
} else { | |||||
return 0; | |||||
} | |||||
} | } | ||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -13,11 +13,8 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include <cstring> | #include <cstring> | ||||
#include "midout.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_filter) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | using namespace arm_common; | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
@@ -12,9 +12,7 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include <cstring> | #include <cstring> | ||||
#include "midout.h" | |||||
#include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
MIDOUT_DECL(megdnn_arm_common_conv_bias_s2_filter) | |||||
#pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
@@ -229,7 +229,12 @@ bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( | |||||
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( | size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, | |||||
midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -32,15 +32,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_quint8_stride1::get_kimpls(param, large_group); | return direct_quint8_stride1::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -57,15 +63,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride2::usable( | |||||
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_quint8_stride2::get_kimpls(param, large_group); | return direct_quint8_stride2::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -81,15 +93,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( | size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); | return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -105,15 +123,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, | |||||
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( | size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | |||||
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); | |||||
return bundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) { | |||||
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, | |||||
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { | |||||
bool large_group = param.filter_meta.group >= param.nr_threads; | bool large_group = param.filter_meta.group >= param.nr_threads; | ||||
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); | return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); | ||||
} | } | ||||
@@ -32,13 +32,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param); | |||||
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, | |||||
midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) { | |||||
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ||||
return deconv::stride1_int8x8x32_dot; | |||||
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, | |||||
midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) { | |||||
return deconv::stride1_int8x8x32_dot; | |||||
} | |||||
MIDOUT_END(); | |||||
return {}; | |||||
} | } | ||||
/* ===================== direct stride 2 algo ===================== */ | /* ===================== direct stride 2 algo ===================== */ | ||||
@@ -49,13 +59,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param); | |||||
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, | |||||
midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) { | |||||
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ||||
return deconv::stride2_int8x8x32_dot; | |||||
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, | |||||
midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) { | |||||
return deconv::stride2_int8x8x32_dot; | |||||
} | |||||
MIDOUT_END(); | |||||
return {}; | |||||
} | } | ||||
#endif | #endif | ||||
@@ -33,13 +33,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); | |||||
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, | |||||
midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) { | |||||
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ||||
return deconv::stride1_quint8_dot; | |||||
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, | |||||
midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) { | |||||
return deconv::stride1_quint8_dot; | |||||
} | |||||
MIDOUT_END(); | |||||
return {}; | |||||
} | } | ||||
/* ===================== direct stride 2 algo ===================== */ | /* ===================== direct stride 2 algo ===================== */ | ||||
@@ -50,13 +60,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); | |||||
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, | |||||
midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) { | |||||
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ||||
return deconv::stride2_quint8_dot; | |||||
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, | |||||
midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) { | |||||
return deconv::stride2_quint8_dot; | |||||
} | |||||
MIDOUT_END(); | |||||
return {}; | |||||
} | } | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -18,6 +18,8 @@ | |||||
MIDOUT_DECL(megdnn_arm_hgemv) | MIDOUT_DECL(megdnn_arm_hgemv) | ||||
MIDOUT_DECL(megdnn_arm_exec_int8816) | MIDOUT_DECL(megdnn_arm_exec_int8816) | ||||
MIDOUT_DECL(megdnn_arm_exec_int8832) | |||||
MIDOUT_DECL(megdnn_arm_exec_fp32) | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace arm_common; | using namespace arm_common; | ||||
@@ -63,8 +65,13 @@ bool MatrixMulImpl::AlgoInt8x8x16::usable( | |||||
size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace( | size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace( | ||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param); | |||||
return wbundle.total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_int8816, | |||||
midout_iv("AlgoInt8x8x16::get_workspace"_hash)) { | |||||
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param); | |||||
return wbundle.total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | ||||
@@ -75,11 +82,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( | |||||
/* ===================== Int8x8x32 Gemv algo ===================== */ | /* ===================== Int8x8x32 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_int8832, | |||||
midout_iv("int8x8x32_gemv_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -104,11 +115,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( | |||||
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */ | /* ===================== Int8x8x32 Gemv MK4 algo ===================== */ | ||||
namespace { | namespace { | ||||
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_int8832, | |||||
midout_iv("int8x8x32_gemv_mk4_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -147,11 +162,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( | |||||
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ | ||||
namespace { | namespace { | ||||
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_int8832, | |||||
midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>(); | |||||
auto Cptr = kern_param.C<dt_int32>(); | |||||
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -189,12 +208,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern( | |||||
/* ===================== F32 Gemv algo ===================== */ | /* ===================== F32 Gemv algo ===================== */ | ||||
namespace { | namespace { | ||||
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, | |||||
midout_iv("f32_gemv_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -225,12 +248,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( | |||||
/* ================== F32 Gemv MK4 algo ================== */ | /* ================== F32 Gemv MK4 algo ================== */ | ||||
namespace { | namespace { | ||||
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, | |||||
midout_iv("f32_gemv_mk4_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; | |||||
const auto Aptr = kern_param.A<dt_float32>(), | |||||
Bptr = kern_param.B<dt_float32>(); | |||||
auto Cptr = kern_param.C<dt_float32>(); | |||||
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -266,11 +293,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( | |||||
namespace { | namespace { | ||||
template <typename stype, typename dtype> | template <typename stype, typename dtype> | ||||
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) { | void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDB = kern_param.LDB; | |||||
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>(); | |||||
auto Cptr = kern_param.C<dtype>(); | |||||
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||||
MIDOUT_BEGIN(megdnn_arm_exec_fp32, | |||||
midout_iv("gevm_like_kern"_hash)) { | |||||
auto M = kern_param.M, N = kern_param.N, K = kern_param.K; | |||||
auto LDB = kern_param.LDB; | |||||
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>(); | |||||
auto Cptr = kern_param.C<dtype>(); | |||||
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
} // anonymous namespace | } // anonymous namespace | ||||
@@ -75,6 +75,7 @@ size_t MatrixMulImpl::AlgoF32::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( | ||||
@@ -141,6 +142,7 @@ size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern( | ||||
@@ -202,6 +204,7 @@ size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( | ||||
@@ -265,6 +268,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern( | ||||
@@ -326,6 +330,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern( | ||||
@@ -386,6 +391,7 @@ size_t MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( | ||||
@@ -445,6 +451,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern( | ||||
@@ -510,6 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern( | ||||
@@ -577,6 +585,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern( | ||||
@@ -642,6 +651,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern( | ||||
@@ -702,6 +712,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern( | ||||
@@ -764,6 +775,7 @@ size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern( | ||||
@@ -830,6 +842,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( | ||||
@@ -894,6 +907,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern( | ||||
@@ -929,6 +943,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( | ||||
@@ -986,6 +1001,7 @@ size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( | ||||
@@ -1066,6 +1082,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace( | |||||
.get_workspace_size(); | .get_workspace_size(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern( | ||||
@@ -261,6 +261,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( | |||||
.total_size_in_bytes(); | .total_size_in_bytes(); | ||||
} | } | ||||
MIDOUT_END(); | MIDOUT_END(); | ||||
return 0; | |||||
} | } | ||||
SmallVector<ConvBiasImpl::NCBKern> | SmallVector<ConvBiasImpl::NCBKern> | ||||
@@ -175,44 +175,54 @@ bool ConvolutionImpl::AlgoFallback::usable( | |||||
} | } | ||||
size_t ConvolutionImpl::AlgoFallback::get_workspace( | size_t ConvolutionImpl::AlgoFallback::get_workspace( | ||||
const NCBKernSizeParam& param) const { | |||||
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1]; | |||||
size_t nr_threads = param.nr_threads; | |||||
if (param.filter_meta.should_flip) { | |||||
// need transpose filter | |||||
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}} | |||||
.total_size_in_bytes() * | |||||
nr_threads; | |||||
} else { | |||||
return 0; | |||||
const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoFallback::get_workspace"_hash)) { | |||||
auto FH = param.filter_meta.spatial[0], | |||||
FW = param.filter_meta.spatial[1]; | |||||
size_t nr_threads = param.nr_threads; | |||||
if (param.filter_meta.should_flip) { | |||||
// need transpose filter | |||||
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}} | |||||
.total_size_in_bytes() * | |||||
nr_threads; | |||||
} else { | |||||
return 0; | |||||
} | |||||
} | } | ||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
SmallVector<ConvolutionImpl::NCBKern> | SmallVector<ConvolutionImpl::NCBKern> | ||||
ConvolutionImpl::AlgoFallback::dispatch_kern( | ConvolutionImpl::AlgoFallback::dispatch_kern( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
size_t group = param.filter_meta.group; | |||||
size_t N = param.n; | |||||
size_t nr_threads = param.nr_threads; | |||||
size_t workspace_per_thread = get_workspace( param) / nr_threads; | |||||
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p, | |||||
const NCBKernIndex& ncb_index) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(p); | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
auto src = p.src<float>(batch_id, group_id), | |||||
filter = p.filter<float>(group_id); | |||||
auto dst = p.dst<float>(batch_id, group_id); | |||||
size_t thread_id = ncb_index.thread_id; | |||||
void* workspace_ptr = reinterpret_cast<void*>( | |||||
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) + | |||||
workspace_per_thread * thread_id); | |||||
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH, | |||||
FW, OH, OW, OC, PH, PW, SH, SW, | |||||
!p.filter_meta.should_flip); | |||||
}; | |||||
return {{kern_fallback, {group, N, 1_z}}}; | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoFallback::dispatch_kern"_hash)) { | |||||
size_t group = param.filter_meta.group; | |||||
size_t N = param.n; | |||||
size_t nr_threads = param.nr_threads; | |||||
size_t workspace_per_thread = get_workspace( param) / nr_threads; | |||||
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p, | |||||
const NCBKernIndex& ncb_index) { | |||||
UNPACK_CONV_F32_NCB_KERN_SIZES(p); | |||||
size_t batch_id = ncb_index.ndrange_id[1]; | |||||
size_t group_id = ncb_index.ndrange_id[0]; | |||||
MEGDNN_MARK_USED_VAR(N); | |||||
auto src = p.src<float>(batch_id, group_id), | |||||
filter = p.filter<float>(group_id); | |||||
auto dst = p.dst<float>(batch_id, group_id); | |||||
size_t thread_id = ncb_index.thread_id; | |||||
void* workspace_ptr = reinterpret_cast<void*>( | |||||
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) + | |||||
workspace_per_thread * thread_id); | |||||
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH, | |||||
FW, OH, OW, OC, PH, PW, SH, SW, | |||||
!p.filter_meta.should_flip); | |||||
}; | |||||
return {{kern_fallback, {group, N, 1_z}}}; | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
/* ===================== naive algo ===================== */ | /* ===================== naive algo ===================== */ | ||||
@@ -339,22 +349,36 @@ WorkspaceBundle ConvolutionImpl::AlgoDefault::get_bundle( | |||||
size_t ConvolutionImpl::AlgoDefault::get_workspace( | size_t ConvolutionImpl::AlgoDefault::get_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoDefault::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
size_t ConvolutionImpl::AlgoDefault::get_preprocess_workspace( | size_t ConvolutionImpl::AlgoDefault::get_preprocess_workspace( | ||||
const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
::ConvBiasImpl::NCBKernSizeParam conv_bias_param = | |||||
init_conv_bias_param(param); | |||||
return m_algorithm->get_preprocess_workspace(conv_bias_param); | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoDefault::get_preprocess_workspace"_hash)) { | |||||
::ConvBiasImpl::NCBKernSizeParam conv_bias_param = | |||||
init_conv_bias_param(param); | |||||
return m_algorithm->get_preprocess_workspace(conv_bias_param); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
SmallVector<TensorLayout> | SmallVector<TensorLayout> | ||||
ConvolutionImpl::AlgoDefault::deduce_preprocessed_filter_layout( | ConvolutionImpl::AlgoDefault::deduce_preprocessed_filter_layout( | ||||
const NCBKernSizeParam& param) const { | |||||
::ConvBiasImpl::NCBKernSizeParam conv_bias_param = | |||||
init_conv_bias_param( param); | |||||
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param); | |||||
const NCBKernSizeParam& param) const { | |||||
MIDOUT_BEGIN( | |||||
megdnn_fallback_conv, | |||||
midout_iv("AlgoDefault::deduce_preprocessed_filter_layout"_hash)) { | |||||
::ConvBiasImpl::NCBKernSizeParam conv_bias_param = | |||||
init_conv_bias_param(param); | |||||
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param); | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
//! Return the implement preprocess kernel | //! Return the implement preprocess kernel | ||||
@@ -450,19 +474,29 @@ bool ConvolutionBackwardDataImpl::AlgoDirect::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1]; | |||||
if (param.filter_meta.should_flip) { | |||||
// need transpose filter | |||||
return FH * FW * sizeof(float); | |||||
} else { | |||||
return 0; | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoDirect::get_workspace"_hash)) { | |||||
auto FH = param.filter_meta.spatial[0], | |||||
FW = param.filter_meta.spatial[1]; | |||||
if (param.filter_meta.should_flip) { | |||||
// need transpose filter | |||||
return FH * FW * sizeof(float); | |||||
} else { | |||||
return 0; | |||||
} | |||||
} | } | ||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | ||||
return kern_direct; | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoDirect::dispatch_kern"_hash)) { | |||||
return kern_direct; | |||||
} | |||||
MIDOUT_END(); | |||||
} | } | ||||
/* ===================== Matrix mul algo ===================== */ | /* ===================== Matrix mul algo ===================== */ | ||||
@@ -477,35 +511,48 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable( | |||||
size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
return get_bundle(param).total_size_in_bytes(); | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
midout_iv("AlgoMatrixMul::get_workspace"_hash)) { | |||||
return get_bundle(param).total_size_in_bytes(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern( | ||||
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
#define cb(dt) \ | |||||
do { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
return kern_matmul<ctype, ctype, ctype>; \ | |||||
} \ | |||||
#define cb(dt, midout_tag) \ | |||||
do { \ | |||||
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \ | |||||
using ctype = DTypeTrait<dt>::ctype; \ | |||||
return kern_matmul<ctype, ctype, ctype>; \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
} \ | |||||
} while (0); | } while (0); | ||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
cb(dtype::Float32, "FLOAT"_hash); | |||||
MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash)); | |||||
MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash)); | |||||
#undef cb | #undef cb | ||||
#define cb(dt_src, dt_dst) \ | |||||
do { \ | |||||
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
return kern_matmul<DTypeTrait<dt_src>::ctype, \ | |||||
DTypeTrait<dt_src>::ctype, \ | |||||
DTypeTrait<dt_dst>::ctype>; \ | |||||
} \ | |||||
#define cb(dt_src, dt_dst, midout_tag) \ | |||||
do { \ | |||||
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \ | |||||
return kern_matmul<DTypeTrait<dt_src>::ctype, \ | |||||
DTypeTrait<dt_src>::ctype, \ | |||||
DTypeTrait<dt_dst>::ctype>; \ | |||||
} \ | |||||
MIDOUT_END(); \ | |||||
} \ | |||||
} while (0) | } while (0) | ||||
cb(dtype::Int8, dtype::Int32); | |||||
cb(dtype::QuantizedS8, dtype::QuantizedS32); | |||||
cb(dtype::Quantized8Asymm, dtype::QuantizedS32); | |||||
cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash); | |||||
cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash); | |||||
cb(dtype::Quantized8Asymm, dtype::QuantizedS32, "QUINT8x8x32"_hash); | |||||
megdnn_throw("unsupported data type on matrix mul"); | megdnn_throw("unsupported data type on matrix mul"); | ||||
#undef cb | #undef cb | ||||
} | } | ||||
@@ -24,7 +24,6 @@ | |||||
#include <cstring> | #include <cstring> | ||||
MIDOUT_DECL(megdnn_fb_conv_float) | |||||
MIDOUT_DECL(megdnn_fb_convbwd_float) | MIDOUT_DECL(megdnn_fb_convbwd_float) | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -53,13 +53,20 @@ bool MatrixMulImpl::AlgoF32K8x12x1::usable( | |||||
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( | size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( | ||||
const KernSizeParam& kern_size_param) const { | const KernSizeParam& kern_size_param) const { | ||||
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; | |||||
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type, | |||||
kern_size_param.B_type, | |||||
kern_size_param.C_type); | |||||
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>( | |||||
M, N, K, kern_size_param.trA, kern_size_param.trB, strategy) | |||||
.get_workspace_size(); | |||||
MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern, | |||||
midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) { | |||||
auto M = kern_size_param.M, N = kern_size_param.N, | |||||
K = kern_size_param.K; | |||||
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type, | |||||
kern_size_param.B_type, | |||||
kern_size_param.C_type); | |||||
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>( | |||||
M, N, K, kern_size_param.trA, kern_size_param.trB, | |||||
strategy) | |||||
.get_workspace_size(); | |||||
} | |||||
MIDOUT_END(); | |||||
return 0; | |||||
} | } | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | ||||
@@ -23,10 +23,16 @@ namespace naive { | |||||
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, | ||||
const TensorLayout& B, | const TensorLayout& B, | ||||
const TensorLayout&) { | const TensorLayout&) { | ||||
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t); | |||||
MIDOUT_BEGIN( | |||||
megdnn_naive_matmul, | |||||
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) { | |||||
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { | |||||
return (A.span().dist_elem() + B.span().dist_elem()) * | |||||
sizeof(uint8_t); | |||||
} | |||||
return 0; | |||||
} | } | ||||
return 0; | |||||
MIDOUT_END(); | |||||
} | } | ||||
template <bool TA, bool TB> | template <bool TA, bool TB> | ||||
@@ -127,7 +133,8 @@ void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A, | |||||
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | ||||
_megdnn_tensor_out C, | _megdnn_tensor_out C, | ||||
_megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
MIDOUT_BEGIN(megdnn_naive_matmul) { | |||||
MIDOUT_BEGIN(megdnn_naive_matmul, | |||||
midout_iv("MatrixMulForwardImpl::exec"_hash)) { | |||||
check_exec(A.layout, B.layout, C.layout, workspace.size); | check_exec(A.layout, B.layout, C.layout, workspace.size); | ||||
auto p = param(); | auto p = param(); | ||||
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p)); | MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p)); | ||||
@@ -17,6 +17,7 @@ | |||||
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL | #cmakedefine01 MGB_ENABLE_DEBUG_UTIL | ||||
#cmakedefine01 MGB_ENABLE_LOGGING | #cmakedefine01 MGB_ENABLE_LOGGING | ||||
#cmakedefine01 MGB_ENABLE_GRAD | #cmakedefine01 MGB_ENABLE_GRAD | ||||
#cmakedefine01 MGB_ENABLE_CPUINFO | |||||
#cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME | #cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME | ||||
#cmakedefine01 MGB_BUILD_SLIM_SERVING | #cmakedefine01 MGB_BUILD_SLIM_SERVING | ||||
#cmakedefine01 MGB_ENABLE_EXCEPTION | #cmakedefine01 MGB_ENABLE_EXCEPTION | ||||
@@ -80,6 +81,16 @@ | |||||
#define MGB_ENABLE_GRAD 1 | #define MGB_ENABLE_GRAD 1 | ||||
#endif | #endif | ||||
// whether to enable cpuinfo | |||||
#ifndef MGB_ENABLE_CPUINFO | |||||
#define MGB_ENABLE_CPUINFO 1 | |||||
#endif | |||||
#ifdef IOS | |||||
#undef MGB_ENABLE_CPUINFO | |||||
#define MGB_ENABLE_CPUINFO 0 | |||||
#endif | |||||
// whether to include actual class name in mgb::Typeinfo object; if this is | // whether to include actual class name in mgb::Typeinfo object; if this is | ||||
// disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work. | // disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work. | ||||
#ifndef MGB_VERBOSE_TYPEINFO_NAME | #ifndef MGB_VERBOSE_TYPEINFO_NAME | ||||