|
|
@@ -76,7 +76,7 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { |
|
|
|
if (conv_bias == nullptr) |
|
|
|
return false; |
|
|
|
auto inp_dtype_conv = conv_bias->input(0)->dtype(), |
|
|
|
out_dtype_conv = conv_bias->input(0)->dtype(); |
|
|
|
out_dtype_conv = conv_bias->output(0)->dtype(); |
|
|
|
bool is_s8nhwc = |
|
|
|
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && |
|
|
@@ -86,7 +86,11 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { |
|
|
|
inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) && |
|
|
|
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && |
|
|
|
conv_bias->param().format == megdnn::param::ConvBias::Format::NHWC; |
|
|
|
if (!(is_s8nhwc || is_s4nhwc)) |
|
|
|
bool is_s8nchw = |
|
|
|
inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
out_dtype_conv.enumv() == inp_dtype_conv.enumv() && |
|
|
|
conv_bias->param().format == megdnn::param::ConvBias::Format::NCHW; |
|
|
|
if (!(is_s8nhwc || is_s4nhwc || is_s8nchw)) |
|
|
|
return false; |
|
|
|
if (conv_bias->input().size() != 3) |
|
|
|
return false; |
|
|
@@ -107,15 +111,27 @@ void FoldingConvBiasTypecvtPass::apply(OptState& opt) const { |
|
|
|
auto new_bias = (out_dtype_typecvt.enumv() == DTypeEnum::Float32) |
|
|
|
? opr::TypeCvt::make(bias, dtype::Float32()).node() |
|
|
|
: bias; |
|
|
|
auto new_param = conv_bias->param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NHWC; |
|
|
|
auto conv_bias_typecvt = opr::ConvBias::make( |
|
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
|
OperatorNodeConfig{out_dtype_typecvt}); |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(0), conv_bias_typecvt.node(), |
|
|
|
mgb_cstr_log("replace conv_bias(NHWC) + typecvt " |
|
|
|
"to conv_bias(NHWC)")); |
|
|
|
if (is_s8nchw && is_s82s4) { |
|
|
|
auto new_param = conv_bias->param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NCHW; |
|
|
|
auto conv_bias_typecvt = opr::ConvBias::make( |
|
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
|
OperatorNodeConfig{out_dtype_typecvt}); |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(0), conv_bias_typecvt.node(), |
|
|
|
mgb_cstr_log("replace conv_bias(NCHW) + typecvt " |
|
|
|
"to conv_bias(NCHW)")); |
|
|
|
} else { |
|
|
|
auto new_param = conv_bias->param(); |
|
|
|
new_param.format = megdnn::param::ConvBias::Format::NHWC; |
|
|
|
auto conv_bias_typecvt = opr::ConvBias::make( |
|
|
|
src, filter, new_bias, new_param, conv_bias->execution_policy(), |
|
|
|
OperatorNodeConfig{out_dtype_typecvt}); |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(0), conv_bias_typecvt.node(), |
|
|
|
mgb_cstr_log("replace conv_bias(NHWC) + typecvt " |
|
|
|
"to conv_bias(NHWC)")); |
|
|
|
} |
|
|
|
return true; |
|
|
|
}; |
|
|
|
|
|
|
|