GitOrigin-RevId: 5a3bfedd8a
tags/v0.5.0
@@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | |||
return false; | |||
if (param.src_type.enumv() != param.filter_type.enumv() && | |||
param.src_type.enumv() != DTypeEnum::Int8 && | |||
param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||
param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
param.src_type.enumv() != DTypeEnum::Float16 && | |||
#endif | |||
param.src_type.enumv() != DTypeEnum::Float32) { | |||
return false; | |||
} | |||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode | |||
//! is identity otherwise return false mean that 8x8x32 and 8x8x16 | |||
//! not support PostProcess | |||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||
(param.src_type.enumv() == DTypeEnum::Int8 && | |||
(param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32)) && | |||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
return false; | |||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) && | |||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||
return false; | |||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
} | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
@@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
return false; | |||
} | |||
if (param.src_type.enumv() != param.filter_type.enumv() && | |||
param.src_type.enumv() != DTypeEnum::Int8 && | |||
param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||
param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
param.src_type.enumv() != DTypeEnum::Float16 && | |||
#endif | |||
param.src_type.enumv() != DTypeEnum::Float32) { | |||
return false; | |||
} | |||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | |||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not | |||
//! support PostProcess | |||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||
((param.src_type.enumv() == DTypeEnum::Int8 && | |||
(param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32)) || | |||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) && | |||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||
return false; | |||
} | |||
} | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
@@ -188,6 +188,24 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle, | |||
} | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | |||
using namespace conv_bias; | |||
param::ConvBias cur_param; | |||
using NLMode = param::ConvBias::NonlineMode; | |||
std::vector<conv_bias::TestArg> args = get_conv_bias_args( | |||
{1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true); | |||
NormalRNG default_rng; | |||
Checker<ConvBias> checker(handle()); | |||
checker.set_dtype(0, dtype::Int8{}); | |||
checker.set_dtype(1, dtype::Int8{}); | |||
checker.set_dtype(2, dtype::Int16{}); | |||
checker.set_dtype(4, dtype::Int16{}); | |||
for (auto&& arg : args) { | |||
checker.set_param(arg.param).execs( | |||
{arg.src, arg.filter, arg.bias, {}, {}}); | |||
} | |||
} | |||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | |||
using namespace conv_bias; | |||
param::ConvBias cur_param; | |||
@@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { | |||
rewriter.get_var(typecvt->input(0))->owner_opr()); | |||
if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || | |||
typecvt->output(0)->dtype().enumv() != | |||
DTypeTrait<dtype::QuantizedS8>::enumv) | |||
DTypeTrait<dtype::QuantizedS8>::enumv || | |||
typecvt->input(0)->dtype().enumv() != | |||
DTypeTrait<dtype::QuantizedS32>::enumv) | |||
return nullptr; | |||
auto config = conv_bias->config(); | |||