GitOrigin-RevId: 5a3bfedd8a
tags/v0.5.0
@@ -201,24 +201,27 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, | |||||
if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | if (FH != 1 || FW != 1 || PH || PW || SH != 1 || SW != 1) | ||||
return false; | return false; | ||||
if (param.src_type.enumv() != param.filter_type.enumv() && | |||||
param.src_type.enumv() != DTypeEnum::Int8 && | |||||
param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||||
param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
param.src_type.enumv() != DTypeEnum::Float16 && | |||||
#endif | |||||
param.src_type.enumv() != DTypeEnum::Float32) { | |||||
return false; | |||||
} | |||||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode | //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode | ||||
//! is identity otherwise return false mean that 8x8x32 and 8x8x16 | //! is identity otherwise return false mean that 8x8x32 and 8x8x16 | ||||
//! not support PostProcess | //! not support PostProcess | ||||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||||
(param.src_type.enumv() == DTypeEnum::Int8 && | |||||
(param.dst_type.enumv() == DTypeEnum::Int16 || | |||||
param.dst_type.enumv() == DTypeEnum::Int32)) && | |||||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||||
return false; | |||||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) && | |||||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) | |||||
return false; | |||||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | |||||
} | |||||
} | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | if (opr->param().format == param::ConvBias::Format::NCHW44 || | ||||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | ||||
@@ -647,19 +647,26 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
return false; | return false; | ||||
} | } | ||||
if (param.src_type.enumv() != param.filter_type.enumv() && | |||||
param.src_type.enumv() != DTypeEnum::Int8 && | |||||
param.src_type.enumv() != DTypeEnum::QuantizedS8 && | |||||
param.src_type.enumv() != DTypeEnum::Quantized8Asymm && | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
param.src_type.enumv() != DTypeEnum::Float16 && | |||||
#endif | |||||
param.src_type.enumv() != DTypeEnum::Float32) { | |||||
return false; | |||||
} | |||||
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is | ||||
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not | //! identity otherwise return false mean that 8x8x32 and 8x8x16 not | ||||
//! support PostProcess | //! support PostProcess | ||||
if (param.src_type.enumv() == param.filter_type.enumv() && | |||||
((param.src_type.enumv() == DTypeEnum::Int8 && | |||||
(param.dst_type.enumv() == DTypeEnum::Int16 || | |||||
param.dst_type.enumv() == DTypeEnum::Int32)) || | |||||
((param.src_type.enumv() == DTypeEnum::QuantizedS8 || | |||||
param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32)) && | |||||
param.bias_mode != megdnn::BiasMode::NO_BIAS && | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | |||||
if (param.dst_type.enumv() == DTypeEnum::Int16 || | |||||
param.dst_type.enumv() == DTypeEnum::Int32 || | |||||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) { | |||||
if (param.bias_mode != megdnn::BiasMode::NO_BIAS || | |||||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { | |||||
return false; | |||||
} | |||||
} | } | ||||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | if (opr->param().format == param::ConvBias::Format::NCHW44 || | ||||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | ||||
@@ -188,6 +188,24 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle, | |||||
} | } | ||||
} | } | ||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD_IM2COL_8X8X16) { | |||||
using namespace conv_bias; | |||||
param::ConvBias cur_param; | |||||
using NLMode = param::ConvBias::NonlineMode; | |||||
std::vector<conv_bias::TestArg> args = get_conv_bias_args( | |||||
{1, 3}, {0}, {NLMode::IDENTITY, NLMode::RELU}, {1}, false, true); | |||||
NormalRNG default_rng; | |||||
Checker<ConvBias> checker(handle()); | |||||
checker.set_dtype(0, dtype::Int8{}); | |||||
checker.set_dtype(1, dtype::Int8{}); | |||||
checker.set_dtype(2, dtype::Int16{}); | |||||
checker.set_dtype(4, dtype::Int16{}); | |||||
for (auto&& arg : args) { | |||||
checker.set_param(arg.param).execs( | |||||
{arg.src, arg.filter, arg.bias, {}, {}}); | |||||
} | |||||
} | |||||
TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | TEST_F(FALLBACK_MULTI_THREADS, CONV_BIAS_FORWARD) { | ||||
using namespace conv_bias; | using namespace conv_bias; | ||||
param::ConvBias cur_param; | param::ConvBias cur_param; | ||||
@@ -1671,7 +1671,9 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const { | |||||
rewriter.get_var(typecvt->input(0))->owner_opr()); | rewriter.get_var(typecvt->input(0))->owner_opr()); | ||||
if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || | if (!conv_bias || m_deps.count(typecvt->input(0)) != 1 || | ||||
typecvt->output(0)->dtype().enumv() != | typecvt->output(0)->dtype().enumv() != | ||||
DTypeTrait<dtype::QuantizedS8>::enumv) | |||||
DTypeTrait<dtype::QuantizedS8>::enumv || | |||||
typecvt->input(0)->dtype().enumv() != | |||||
DTypeTrait<dtype::QuantizedS32>::enumv) | |||||
return nullptr; | return nullptr; | ||||
auto config = conv_bias->config(); | auto config = conv_bias->config(); | ||||