Browse Source

fix(dnn/fallback): fix conv1x1/im2col usable and fuse-conv-bias get fp32xfp32-->qint8 error

GitOrigin-RevId: 5a3bfedd8a
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
4d35397bdf
4 changed files with 56 additions and 26 deletions
  1. +18
    -15
      dnn/src/fallback/conv_bias/conv1x1/algos.cpp
  2. +17
    -10
      dnn/src/fallback/conv_bias/im2col/algos.cpp
  3. +18
    -0
      dnn/test/fallback/conv_bias.cpp
  4. +3
    -1
      src/gopt/impl/inference.cpp

+ 18
- 15
dnn/src/fallback/conv_bias/conv1x1/algos.cpp View File

@@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1)
return false;

if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! 8x8x16 and 8x8x32 do not support post-processing: the bias mode must
//! be NO_BIAS and the nonlinearity must be IDENTITY; otherwise return
//! false.
if (param.src_type.enumv() == param.filter_type.enumv() &&
(param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32)) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false;

if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false;
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}

if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {


+ 17
- 10
dnn/src/fallback/conv_bias/im2col/algos.cpp View File

@@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable(
return false;
}

if (param.src_type.enumv() != param.filter_type.enumv() &&
param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
#if !MEGDNN_DISABLE_FLOAT16
param.src_type.enumv() != DTypeEnum::Float16 &&
#endif
param.src_type.enumv() != DTypeEnum::Float32) {
return false;
}
//! 8x8x16 and 8x8x32 do not support post-processing: the bias mode must
//! be NO_BIAS and the nonlinearity must be IDENTITY; otherwise return
//! false.
if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32)) ||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) &&
param.bias_mode != megdnn::BiasMode::NO_BIAS &&
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {


+ 18
- 0
dnn/test/fallback/conv_bias.cpp View File

@@ -188,6 +188,24 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
}
}

// Runs the ConvBias Checker over a generated list of test arguments with
// Int8 dtypes in tensor slots 0 and 1 and Int16 dtypes in slots 2 and 4.
// NOTE(review): nothing in this body restricts which algorithm gets selected,
// so the "IM2COL" in the test name is not enforced here -- confirm whether an
// algorithm-check callback was intended.
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) {
using namespace conv_bias;
// NOTE(review): cur_param is never referenced again in this test body.
param::ConvBias cur_param;
using NLMode = param::ConvBias::NonlineMode;
// NOTE(review): the meaning of each positional argument depends on the
// signature of get_conv_bias_args, which is not visible here; presumably
// kernel sizes {1, 3}, padding {0}, nonlinearities {IDENTITY, RELU} and
// stride {1}, followed by two boolean switches -- confirm against the helper.
std::vector<conv_bias::TestArg> args = get_conv_bias_args(
{1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true);
// NOTE(review): default_rng is constructed but never handed to the checker
// in this test body.
NormalRNG default_rng;
Checker<ConvBias> checker(handle());
// Slots 0, 1 and 2 line up with the src, filter and bias entries of the
// execs() tensor list below; slot 4 is the last entry of that list
// (presumably the output -- confirm). Slot 3 is given no explicit dtype.
checker.set_dtype(0, dtype::Int8{});
checker.set_dtype(1, dtype::Int8{});
checker.set_dtype(2, dtype::Int16{});
checker.set_dtype(4, dtype::Int16{});
// Execute every generated case with its own param; the two trailing tensors
// are passed as empty layouts.
for (auto&& arg : args) {
checker.set_param(arg.param).execs(
{arg.src, arg.filter, arg.bias, {}, {}});
}
}

TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) {
using namespace conv_bias;
param::ConvBias cur_param;


+ 3
- 1
src/gopt/impl/inference.cpp View File

@@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
rewriter.get_var(typecvt->input(0))->owner_opr());
if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 ||
typecvt->output(0)->dtype().enumv() !=
DTypeTrait<dtype::QuantizedS8>::enumv)
DTypeTrait<dtype::QuantizedS8>::enumv ||
typecvt->input(0)->dtype().enumv() !=
DTypeTrait<dtype::QuantizedS32>::enumv)
return nullptr;

auto config = conv_bias->config();


Loading…
Cancel
Save