From 5e07e1e0f91db733e643ad3d90e8aa8df9ac2669 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 22 Nov 2021 17:50:56 +0800 Subject: [PATCH] fix(dnn/fallback): let cpu be able to execute int4 model GitOrigin-RevId: 1a6b78f3b695aac0b9de346163e4b0f3cd1dc3fb --- dnn/src/fallback/convolution/opr_impl.cpp | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 1e1622db..0624f02e 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -431,7 +431,9 @@ ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_dat } } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { return ConvolutionImpl::AlgoDataType::QUINT8X8X32; - } else if (src_type.enumv() == DTypeEnum::QuantizedS4) { + } else if ( + src_type.enumv() == DTypeEnum::QuantizedS4 || + src_type.enumv() == DTypeEnum::Quantized4Asymm) { return ConvolutionImpl::AlgoDataType::QINT4x4x32; } else { megdnn_throw(ssprintf( @@ -477,7 +479,8 @@ void ConvolutionBackwardDataImpl::exec( _megdnn_workspace workspace) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace); } @@ -499,7 +502,8 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == 
DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes( filter, diff, grad); @@ -514,7 +518,8 @@ std::vector ConvolutionBackwardDataImpl const TensorLayout& grad) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_all_algorithms( filter, diff, grad); @@ -541,7 +546,8 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: const AlgoAttribute& negative_attr) { if (param().format == param::Convolution::Format::NHWCD4 || param().format == param::Convolution::Format::NCHW4 || - (param().format == param::Convolution::Format::NCHW && + ((param().format == param::Convolution::Format::NCHW || + param().format == param::Convolution::Format::NHWC) && grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( filter, diff, grad, workspace_limit_in_bytes, positive_attr,