Browse Source

fix(dnn): midout at where neccessary in megdnn

GitOrigin-RevId: 191334bd96
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
bca00f2e22
26 changed files with 533 additions and 238 deletions
  1. +3
    -3
      CMakeLists.txt
  2. +1
    -1
      dnn/src/CMakeLists.txt
  3. +2
    -2
      dnn/src/aarch64/conv_bias/fp16/algos.cpp
  4. +19
    -0
      dnn/src/aarch64/matrix_mul/algos.cpp
  5. +6
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp
  6. +6
    -1
      dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp
  7. +108
    -62
      dnn/src/arm_common/conv_bias/int8/algos.cpp
  8. +7
    -2
      dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
  9. +6
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp
  10. +6
    -1
      dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp
  11. +6
    -1
      dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
  12. +36
    -20
      dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp
  13. +0
    -3
      dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp
  14. +0
    -2
      dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp
  15. +6
    -1
      dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp
  16. +40
    -16
      dnn/src/arm_common/conv_bias/quint8/algos.cpp
  17. +24
    -4
      dnn/src/arm_common/convolution/int8x8x32/algos.cpp
  18. +24
    -4
      dnn/src/arm_common/convolution/quint8/algos.cpp
  19. +65
    -34
      dnn/src/arm_common/matrix_mul/algos.cpp
  20. +17
    -0
      dnn/src/armv7/matrix_mul/algos.cpp
  21. +1
    -0
      dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
  22. +114
    -67
      dnn/src/fallback/convolution/algos.cpp
  23. +0
    -1
      dnn/src/fallback/convolution/opr_impl.cpp
  24. +14
    -7
      dnn/src/fallback/matrix_mul/algos.cpp
  25. +11
    -4
      dnn/src/naive/matrix_mul/opr_impl.cpp
  26. +11
    -0
      src/megbrain_build_config.h.in

+ 3
- 3
CMakeLists.txt View File

@@ -116,9 +116,9 @@ if(${MGE_ARCH} STREQUAL "AUTO")
endif()

if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64")
option(MGB_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON)
if(MGB_ENABLE_CPUINFO)
message("-- Enable cpuinfo runtime check.")
option(MGE_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON)
if(MGE_ENABLE_CPUINFO)
message("-- Enable cpuinfo runtime check and little kernel optimize.")
add_definitions(-DMGB_ENABLE_CPUINFO_CHECK)
include(cmake/cpuinfo.cmake)
endif()


+ 1
- 1
dnn/src/CMakeLists.txt View File

@@ -53,7 +53,7 @@ add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES})
target_link_libraries(megdnn PUBLIC opr_param_defs)

if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64")
if(MGB_ENABLE_CPUINFO)
if(MGE_ENABLE_CPUINFO)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>)
endif()
endif()


+ 2
- 2
dnn/src/aarch64/conv_bias/fp16/algos.cpp View File

@@ -49,13 +49,13 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return false;
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();


+ 19
- 0
dnn/src/aarch64/matrix_mul/algos.cpp View File

@@ -58,6 +58,7 @@ size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
@@ -118,6 +119,7 @@ size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
@@ -177,6 +179,7 @@ size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
@@ -237,6 +240,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern(
@@ -313,6 +317,7 @@ size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
@@ -352,6 +357,7 @@ size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
@@ -431,6 +437,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern(
@@ -501,6 +508,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
@@ -573,6 +581,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern(
@@ -635,6 +644,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern(
@@ -696,6 +706,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern(
@@ -762,6 +773,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern(
@@ -828,6 +840,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern(
@@ -905,6 +918,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
@@ -981,6 +995,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
@@ -1051,6 +1066,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern(
@@ -1092,6 +1108,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
@@ -1172,6 +1189,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
@@ -1277,6 +1295,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern(


+ 6
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp View File

@@ -160,7 +160,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1,
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 6
- 1
dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp View File

@@ -199,7 +199,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(

size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44,
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 108
- 62
dnn/src/arm_common/conv_bias/int8/algos.cpp View File

@@ -33,29 +33,39 @@ bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param,
return direct_int8_stride1::can_conv_direct_stride1_int8(param);
}
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred(
const NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) ||
((FH == 3 || FH == 5 || FH == 7) &&
(OC <= 16 || (IC <= 4 && OC <= 32)))) &&
param.bias_mode != BiasMode::BIAS;
return preferred;
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) ||
((FH == 3 || FH == 5 || FH == 7) &&
(OC <= 16 || (IC <= 4 && OC <= 32)))) &&
param.bias_mode != BiasMode::BIAS;
return preferred;
}
MIDOUT_END();
}

size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride1::get_kimpls(param, large_group);
}
@@ -72,15 +82,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable(

size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = channel_wise_nchw44::stride1::get_bundle(param);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride1NCHW44::get_workspace"_hash)) {
auto bundle = channel_wise_nchw44::stride1::get_bundle(param);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride1NCHW44"_hash)) {
midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) {
return channel_wise_nchw44::stride1::get_kimpls(param);
}
MIDOUT_END();
@@ -96,15 +111,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable(

size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = channel_wise_nchw44::stride2::get_bundle(param);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) {
auto bundle = channel_wise_nchw44::stride2::get_bundle(param);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride2NCHW44"_hash)) {
midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) {
return channel_wise_nchw44::stride2::get_kimpls(param);
}
MIDOUT_END();
@@ -119,15 +139,21 @@ bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride2::get_kimpls(param, large_group);
}
@@ -144,15 +170,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride1::get_kimpls(param, large_group);
}
@@ -168,15 +200,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride2::get_kimpls(param, large_group);
}
@@ -188,37 +226,45 @@ ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns(
/* ======================= AlgoS8WinogradF23_8x8 ======================== */

bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
const NCBKernSizeParam& param,
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
param.output_block_size == 2 &&
param.winograd_matmul_format == param::MatrixMul::Format::MK8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8;
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
param.output_block_size == 2 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::MK8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8;
}
MIDOUT_END();
return false;
}

MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8,


+ 7
- 2
dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp View File

@@ -202,14 +202,19 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred(

size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) {
midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) {
auto fm = param.filter_meta;
size_t BATCH = param.n;
size_t GROUP = fm.group;


+ 6
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp View File

@@ -223,7 +223,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred(

size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44,
midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 6
- 1
dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp View File

@@ -232,7 +232,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(

size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44,
midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 6
- 1
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp View File

@@ -183,7 +183,12 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(

size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot,
midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 36
- 20
dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp View File

@@ -83,7 +83,8 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW,
/* ===================== direct algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::usable"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.bias_mode == BiasMode::NO_BIAS &&
@@ -122,7 +123,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle(
}
size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) {
auto bundle = get_bundle(param);
return bundle.total_size_in_bytes();
}
@@ -287,7 +289,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) {
return get_kimpls(param);
}
MIDOUT_END();
@@ -297,7 +300,8 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
/* ===================== stride-2 algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::usable"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.bias_mode == BiasMode::NO_BIAS &&
@@ -337,7 +341,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle(
}
size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) {
auto bundle = get_bundle(param);
return bundle.total_size_in_bytes();
}
@@ -501,7 +506,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) {
return get_kimpls(param);
}
MIDOUT_END();
@@ -510,7 +516,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns(
bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) {
return param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
param.nr_threads == 1_z &&
@@ -522,7 +529,8 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable(

size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) {
return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2(
param);
}
@@ -535,7 +543,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns(
const NCBKernSizeParam& param) const {
// return {conv_bias::conv_int8x8x16_stride2_flt2,true};
auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) {
auto ncb_param = param;
ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]);
ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]);
@@ -573,18 +582,25 @@ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable(

size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
megdnn_assert(stride_h == stride_w);
if (stride_h == 1) {
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
.total_size_in_bytes();
} else if (stride_h == 2) {
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
.total_size_in_bytes();
} else {
return 0;
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv(
"AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
megdnn_assert(stride_h == stride_w);
if (stride_h == 1) {
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
.total_size_in_bytes();
} else if (stride_h == 2) {
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
.total_size_in_bytes();
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 0
- 3
dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp View File

@@ -13,11 +13,8 @@
#include "src/common/utils.h"

#include <cstring>
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"

MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_filter)

using namespace megdnn;
using namespace arm_common;
using namespace conv_bias;


+ 0
- 2
dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp View File

@@ -12,9 +12,7 @@
#include "src/common/utils.h"

#include <cstring>
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_s2_filter)

#pragma GCC diagnostic ignored "-Wunused-parameter"



+ 6
- 1
dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp View File

@@ -229,7 +229,12 @@ bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(

size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44,
midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 40
- 16
dnn/src/arm_common/conv_bias/quint8/algos.cpp View File

@@ -32,15 +32,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride1::get_kimpls(param, large_group);
}
@@ -57,15 +63,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride2::usable(

size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride2::get_kimpls(param, large_group);
}
@@ -81,15 +93,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group);
}
@@ -105,15 +123,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param,

size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group);
}


+ 24
- 4
dnn/src/arm_common/convolution/int8x8x32/algos.cpp View File

@@ -32,13 +32,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(

size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride1_int8x8x32_dot;
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_int8x8x32_dot;
}
MIDOUT_END();
return {};
}

/* ===================== direct stride 2 algo ===================== */
@@ -49,13 +59,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(

size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride2_int8x8x32_dot;
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_int8x8x32_dot;
}
MIDOUT_END();
return {};
}

#endif


+ 24
- 4
dnn/src/arm_common/convolution/quint8/algos.cpp View File

@@ -33,13 +33,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(

size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride1_quint8_dot;
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_quint8_dot;
}
MIDOUT_END();
return {};
}

/* ===================== direct stride 2 algo ===================== */
@@ -50,13 +60,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(

size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride2_quint8_dot;
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_quint8_dot;
}
MIDOUT_END();
return {};
}
#endif
// vim: syntax=cpp.doxygen

+ 65
- 34
dnn/src/arm_common/matrix_mul/algos.cpp View File

@@ -18,6 +18,8 @@

MIDOUT_DECL(megdnn_arm_hgemv)
MIDOUT_DECL(megdnn_arm_exec_int8816)
MIDOUT_DECL(megdnn_arm_exec_int8832)
MIDOUT_DECL(megdnn_arm_exec_fp32)

using namespace megdnn;
using namespace arm_common;
@@ -63,8 +65,13 @@ bool MatrixMulImpl::AlgoInt8x8x16::usable(

size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace(
const KernSizeParam& kern_size_param) const {
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
return wbundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_exec_int8816,
midout_iv("AlgoInt8x8x16::get_workspace"_hash)) {
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
@@ -75,11 +82,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
/* ===================== Int8x8x32 Gemv algo ===================== */
namespace {
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

@@ -104,11 +115,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */
namespace {
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

@@ -147,11 +162,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

@@ -189,12 +208,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
/* ===================== F32 Gemv algo ===================== */
namespace {
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("f32_gemv_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

@@ -225,12 +248,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
/* ================== F32 Gemv MK4 algo ================== */
namespace {
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("f32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace

@@ -266,11 +293,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
namespace {
template <typename stype, typename dtype>
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("gevm_like_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
MIDOUT_END();
}
} // anonymous namespace



+ 17
- 0
dnn/src/armv7/matrix_mul/algos.cpp View File

@@ -75,6 +75,7 @@ size_t MatrixMulImpl::AlgoF32::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(
@@ -141,6 +142,7 @@ size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
@@ -202,6 +204,7 @@ size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern(
@@ -265,6 +268,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern(
@@ -326,6 +330,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern(
@@ -386,6 +391,7 @@ size_t MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern(
@@ -445,6 +451,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern(
@@ -510,6 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern(
@@ -577,6 +585,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern(
@@ -642,6 +651,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern(
@@ -702,6 +712,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern(
@@ -764,6 +775,7 @@ size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern(
@@ -830,6 +842,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern(
@@ -894,6 +907,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern(
@@ -929,6 +943,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern(
@@ -986,6 +1001,7 @@ size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern(
@@ -1066,6 +1082,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern(


+ 1
- 0
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp View File

@@ -261,6 +261,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

SmallVector<ConvBiasImpl::NCBKern>


+ 114
- 67
dnn/src/fallback/convolution/algos.cpp View File

@@ -175,44 +175,54 @@ bool ConvolutionImpl::AlgoFallback::usable(
}

size_t ConvolutionImpl::AlgoFallback::get_workspace(
const NCBKernSizeParam& param) const {
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
size_t nr_threads = param.nr_threads;
if (param.filter_meta.should_flip) {
// need transpose filter
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
.total_size_in_bytes() *
nr_threads;
} else {
return 0;
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoFallback::get_workspace"_hash)) {
auto FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t nr_threads = param.nr_threads;
if (param.filter_meta.should_flip) {
// need transpose filter
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
.total_size_in_bytes() *
nr_threads;
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}

SmallVector<ConvolutionImpl::NCBKern>
ConvolutionImpl::AlgoFallback::dispatch_kern(
const NCBKernSizeParam& param) const {
size_t group = param.filter_meta.group;
size_t N = param.n;
size_t nr_threads = param.nr_threads;
size_t workspace_per_thread = get_workspace( param) / nr_threads;
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
workspace_per_thread * thread_id);
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH,
FW, OH, OW, OC, PH, PW, SH, SW,
!p.filter_meta.should_flip);
};
return {{kern_fallback, {group, N, 1_z}}};
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoFallback::dispatch_kern"_hash)) {
size_t group = param.filter_meta.group;
size_t N = param.n;
size_t nr_threads = param.nr_threads;
size_t workspace_per_thread = get_workspace( param) / nr_threads;
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
workspace_per_thread * thread_id);
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH,
FW, OH, OW, OC, PH, PW, SH, SW,
!p.filter_meta.should_flip);
};
return {{kern_fallback, {group, N, 1_z}}};
}
MIDOUT_END();
}

/* ===================== naive algo ===================== */
@@ -339,22 +349,36 @@ WorkspaceBundle ConvolutionImpl::AlgoDefault::get_bundle(

size_t ConvolutionImpl::AlgoDefault::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDefault::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

size_t ConvolutionImpl::AlgoDefault::get_preprocess_workspace(
const NCBKernSizeParam& param) const {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->get_preprocess_workspace(conv_bias_param);
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDefault::get_preprocess_workspace"_hash)) {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->get_preprocess_workspace(conv_bias_param);
}
MIDOUT_END();
}

SmallVector<TensorLayout>
ConvolutionImpl::AlgoDefault::deduce_preprocessed_filter_layout(
const NCBKernSizeParam& param) const {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param( param);
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param);
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_fallback_conv,
midout_iv("AlgoDefault::deduce_preprocessed_filter_layout"_hash)) {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param);
}
MIDOUT_END();
}

//! Return the implement preprocess kernel
@@ -450,19 +474,29 @@ bool ConvolutionBackwardDataImpl::AlgoDirect::usable(

size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
if (param.filter_meta.should_flip) {
// need transpose filter
return FH * FW * sizeof(float);
} else {
return 0;
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDirect::get_workspace"_hash)) {
auto FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
if (param.filter_meta.should_flip) {
// need transpose filter
return FH * FW * sizeof(float);
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return kern_direct;
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDirect::dispatch_kern"_hash)) {
return kern_direct;
}
MIDOUT_END();
}

/* ===================== Matrix mul algo ===================== */
@@ -477,35 +511,48 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable(

size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoMatrixMul::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}

ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
#define cb(dt) \
do { \
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \
using ctype = DTypeTrait<dt>::ctype; \
return kern_matmul<ctype, ctype, ctype>; \
} \
#define cb(dt, midout_tag) \
do { \
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
using ctype = DTypeTrait<dt>::ctype; \
return kern_matmul<ctype, ctype, ctype>; \
} \
MIDOUT_END(); \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
cb(dtype::Float32, "FLOAT"_hash);
MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash));
MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash));
#undef cb

#define cb(dt_src, dt_dst) \
do { \
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \
return kern_matmul<DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_dst>::ctype>; \
} \
#define cb(dt_src, dt_dst, midout_tag) \
do { \
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
return kern_matmul<DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_dst>::ctype>; \
} \
MIDOUT_END(); \
} \
} while (0)
cb(dtype::Int8, dtype::Int32);
cb(dtype::QuantizedS8, dtype::QuantizedS32);
cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash);
cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash);
cb(dtype::Quantized8Asymm, dtype::QuantizedS32, "QUINT8x8x32"_hash);
megdnn_throw("unsupported data type on matrix mul");
#undef cb
}


+ 0
- 1
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -24,7 +24,6 @@

#include <cstring>

MIDOUT_DECL(megdnn_fb_conv_float)
MIDOUT_DECL(megdnn_fb_convbwd_float)

using namespace megdnn;


+ 14
- 7
dnn/src/fallback/matrix_mul/algos.cpp View File

@@ -53,13 +53,20 @@ bool MatrixMulImpl::AlgoF32K8x12x1::usable(

size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
const KernSizeParam& kern_size_param) const {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
kern_size_param.B_type,
kern_size_param.C_type);
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
M, N, K, kern_size_param.trA, kern_size_param.trB, strategy)
.get_workspace_size();
MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
kern_size_param.B_type,
kern_size_param.C_type);
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
M, N, K, kern_size_param.trA, kern_size_param.trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}

MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(


+ 11
- 4
dnn/src/naive/matrix_mul/opr_impl.cpp View File

@@ -23,10 +23,16 @@ namespace naive {
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout&) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t);
MIDOUT_BEGIN(
megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return (A.span().dist_elem() + B.span().dist_elem()) *
sizeof(uint8_t);
}
return 0;
}
return 0;
MIDOUT_END();
}

template <bool TA, bool TB>
@@ -127,7 +133,8 @@ void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A,
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_matmul) {
MIDOUT_BEGIN(megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::exec"_hash)) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto p = param();
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p));


+ 11
- 0
src/megbrain_build_config.h.in View File

@@ -17,6 +17,7 @@
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL
#cmakedefine01 MGB_ENABLE_LOGGING
#cmakedefine01 MGB_ENABLE_GRAD
#cmakedefine01 MGB_ENABLE_CPUINFO
#cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME
#cmakedefine01 MGB_BUILD_SLIM_SERVING
#cmakedefine01 MGB_ENABLE_EXCEPTION
@@ -80,6 +81,16 @@
#define MGB_ENABLE_GRAD 1
#endif

// whether to enable cpuinfo
#ifndef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 1
#endif

#ifdef IOS
#undef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 0
#endif

// whether to include actual class name in mgb::Typeinfo object; if this is
// disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work.
#ifndef MGB_VERBOSE_TYPEINFO_NAME


Loading…
Cancel
Save