
fix(gopt): fix global layout transform fold conv typecvt

GitOrigin-RevId: 66a23a927e
release-1.11.1
Megvii Engine Team (Wanwan1996), 2 years ago
commit 6f9f25a882
3 changed files with 31 additions and 12 deletions:
  1. dnn/src/cuda/conv_bias/conv_nchwqs8.cpp (+1, -1)
  2. src/gopt/impl/folding_conv_typecvt.cpp (+27, -11)
  3. src/gopt/impl/framework.cpp (+3, -0)

dnn/src/cuda/conv_bias/conv_nchwqs8.cpp (+1, -1)

@@ -110,7 +110,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
     bool is_version_ok = CUDNN_VERSION >= 7500;
     bool is_dtype_ok =
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 ||
+             (args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
              args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
     bool is_bias_ok =
             args.bias_layout->ndim == 0 ||
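
This hunk fixes an always-true dtype check: a destination dtype can never equal both QuantizedS4 and Quantized4Asymm, so the `||` form accepted every dtype and the fallback NCHW QS8 algorithm reported itself as available even for 4-bit outputs. A minimal sketch (plain enum stand-ins rather than the real MegDNN DType machinery) showing why only the `&&` form rejects 4-bit destinations:

#include <cassert>

// Plain enum stand-ins for the DTypeEnum values used above (illustrative only).
enum class DType { QuantizedS8, QuantizedS4, Quantized4Asymm };

// Old predicate: "dst != S4 || dst != 4Asymm" holds for every dtype, because
// no value equals both, so 4-bit destinations were never rejected.
bool dst_ok_old(DType dst) {
    return dst != DType::QuantizedS4 || dst != DType::Quantized4Asymm;
}

// Fixed predicate: reject a destination that is either 4-bit dtype.
bool dst_ok_new(DType dst) {
    return dst != DType::QuantizedS4 && dst != DType::Quantized4Asymm;
}

int main() {
    assert(dst_ok_old(DType::QuantizedS4));   // bug: 4-bit dst still passes
    assert(!dst_ok_new(DType::QuantizedS4));  // fix: 4-bit dst rejected
    assert(dst_ok_new(DType::QuantizedS8));   // 8-bit dst still allowed
    return 0;
}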


src/gopt/impl/folding_conv_typecvt.cpp (+27, -11)

@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         if (conv_bias == nullptr)
             return false;
         auto inp_dtype_conv = conv_bias->input(0)->dtype(),
-             out_dtype_conv = conv_bias->input(0)->dtype();
+             out_dtype_conv = conv_bias->output(0)->dtype();
         bool is_s8nhwc =
                 inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
                  inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
                 out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
                 conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC;
-        if (!(is_s8nhwc || is_s4nhwc))
+        bool is_s8nchw =
+                inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
+                out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
+                conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW;
+        if (!(is_s8nhwc || is_s4nhwc || is_s8nchw))
             return false;
         if (conv_bias->input().size() != 3)
             return false;
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
         auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32)
                                 ? opr::TypeCvt::make(bias, dtype::Float32()).node()
                                 : bias;
-        auto new_param = conv_bias->param();
-        new_param.format = megdnn::param::ConvBias::Format::NHWC;
-        auto conv_bias_typecvt = opr::ConvBias::make(
-                src, filter, new_bias, new_param, conv_bias->execution_policy(),
-                OperatorNodeConfig{out_dtype_typecvt});
-        rewriter.replace_var(
-                opr->output(0), conv_bias_typecvt.node(),
-                mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
-                             "to conv_bias(NHWC)"));
+        if (is_s8nchw && is_s82s4) {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NCHW;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NCHW) + typecvt "
+                                 "to conv_bias(NCHW)"));
+        } else {
+            auto new_param = conv_bias->param();
+            new_param.format = megdnn::param::ConvBias::Format::NHWC;
+            auto conv_bias_typecvt = opr::ConvBias::make(
+                    src, filter, new_bias, new_param, conv_bias->execution_policy(),
+                    OperatorNodeConfig{out_dtype_typecvt});
+            rewriter.replace_var(
+                    opr->output(0), conv_bias_typecvt.node(),
+                    mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
+                                 "to conv_bias(NHWC)"));
+        }
         return true;
     };
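
Three things change here: `out_dtype_conv` now reads `conv_bias->output(0)->dtype()` instead of repeating `input(0)` (the old copy-paste made the input/output dtype comparison trivially true), the match set gains `is_s8nchw` so a QuantizedS8 NCHW conv_bias followed by a typecvt can also be folded, and the folded ConvBias keeps NCHW for the s8 NCHW case when the pre-existing `is_s82s4` flag is set, while every other matched case keeps NHWC as before. A compilable sketch of the match-and-format decision, using illustrative stand-in types rather than the MegBrain ones:

#include <iostream>

// Illustrative stand-ins, not the MegBrain/MegDNN types.
enum class DTypeEnum { QuantizedS8, QuantizedS4, Quantized4Asymm };
enum class Format { NCHW, NHWC };

struct ConvInfo {
    DTypeEnum inp_dtype;  // conv_bias->input(0)->dtype()
    DTypeEnum out_dtype;  // conv_bias->output(0)->dtype()  (the corrected read)
    Format format;        // conv_bias->param().format
};

// Mirrors the extended match: s8 NHWC, s4/4Asymm NHWC, and now also s8 NCHW.
bool matches(const ConvInfo& c) {
    bool same_io = c.out_dtype == c.inp_dtype;
    bool is_s8nhwc = c.inp_dtype == DTypeEnum::QuantizedS8 && same_io &&
                     c.format == Format::NHWC;
    bool is_s4nhwc = (c.inp_dtype == DTypeEnum::QuantizedS4 ||
                      c.inp_dtype == DTypeEnum::Quantized4Asymm) &&
                     same_io && c.format == Format::NHWC;
    bool is_s8nchw = c.inp_dtype == DTypeEnum::QuantizedS8 && same_io &&
                     c.format == Format::NCHW;
    return is_s8nhwc || is_s4nhwc || is_s8nchw;
}

// Mirrors the new branch: an s8 NCHW conv whose trailing typecvt goes to a
// 4-bit dtype keeps NCHW after folding; every other matched case keeps NHWC.
Format folded_format(bool is_s8nchw, bool is_s82s4) {
    return (is_s8nchw && is_s82s4) ? Format::NCHW : Format::NHWC;
}

int main() {
    ConvInfo s8nchw{DTypeEnum::QuantizedS8, DTypeEnum::QuantizedS8, Format::NCHW};
    std::cout << matches(s8nchw) << " "                              // 1: matched after this commit
              << (folded_format(true, true) == Format::NCHW) << "\n";  // 1: keeps NCHW
    return 0;
}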



src/gopt/impl/framework.cpp (+3, -0)

@@ -823,6 +823,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
     add_pass<FuseConvBiasNonlinPass>();
     if (options.target == Target::CUDA)
         add_pass<FuseConvBiasZPass>();
+#if CUDA_VERSION >= 10020
+    add_pass<FoldingConvBiasTypecvtPass>();
+#endif
     add_pass(LayoutTransformPass::make(options.target));
     add_pass<ShuffleShuffleRemovePass>();
     if (options.target == Target::CUDA) {
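
This registers the typecvt-folding pass in the graph-tuning pipeline used by the global layout transform, guarded by `CUDA_VERSION >= 10020` and placed before `LayoutTransformPass`, so the layout transform already sees the folded ConvBias. A toy, self-contained sketch of that ordering (string pass names standing in for the real pass classes):

#include <iostream>
#include <string>
#include <vector>

// Toy pipeline mirroring the registration order above: the typecvt fold is
// added before the layout transform, so the global layout transform operates
// on the already-folded ConvBias. Pass names are strings, not the real classes.
struct Pipeline {
    std::vector<std::string> passes;
    Pipeline& add_pass(std::string name) {
        passes.push_back(std::move(name));
        return *this;
    }
};

int main() {
    Pipeline p;
    p.add_pass("FuseConvBiasNonlinPass").add_pass("FuseConvBiasZPass");
#if defined(CUDA_VERSION) && CUDA_VERSION >= 10020
    p.add_pass("FoldingConvBiasTypecvtPass");  // only when built with CUDA >= 10.2
#endif
    p.add_pass("LayoutTransformPass").add_pass("ShuffleShuffleRemovePass");
    for (const auto& name : p.passes)
        std::cout << name << "\n";
    return 0;
}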

