Browse Source

fix(dnn/fallback): let cpu be able to execute int4 model

GitOrigin-RevId: 1a6b78f3b6
tags/v1.7.0.m1
Megvii Engine Team 3 years ago
parent
commit
5e07e1e0f9
1 changed files with 11 additions and 5 deletions
  1. +11
    -5
      dnn/src/fallback/convolution/opr_impl.cpp

+ 11
- 5
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -431,7 +431,9 @@ ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_dat
}
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
- } else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
+ } else if (
+ src_type.enumv() == DTypeEnum::QuantizedS4 ||
+ src_type.enumv() == DTypeEnum::Quantized4Asymm) {
return ConvolutionImpl::AlgoDataType::QINT4x4x32;
} else {
megdnn_throw(ssprintf(
@@ -477,7 +479,8 @@ void ConvolutionBackwardDataImpl::exec(
_megdnn_workspace workspace) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
- (param().format == param::Convolution::Format::NCHW &&
+ ((param().format == param::Convolution::Format::NCHW ||
+ param().format == param::Convolution::Format::NHWC) &&
grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace);
}
@@ -499,7 +502,8 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(

if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
- (param().format == param::Convolution::Format::NCHW &&
+ ((param().format == param::Convolution::Format::NCHW ||
+ param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes(
filter, diff, grad);
@@ -514,7 +518,8 @@ std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl
const TensorLayout& grad) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
- (param().format == param::Convolution::Format::NCHW &&
+ ((param().format == param::Convolution::Format::NCHW ||
+ param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_all_algorithms(
filter, diff, grad);
@@ -541,7 +546,8 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
const AlgoAttribute& negative_attr) {
if (param().format == param::Convolution::Format::NHWCD4 ||
param().format == param::Convolution::Format::NCHW4 ||
- (param().format == param::Convolution::Format::NCHW &&
+ ((param().format == param::Convolution::Format::NCHW ||
+ param().format == param::Convolution::Format::NHWC) &&
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) {
return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic(
filter, diff, grad, workspace_limit_in_bytes, positive_attr,


Loading…
Cancel
Save