|
|
@@ -431,7 +431,9 @@ ConvolutionImpl::AlgoDataType ConvolutionImpl::NCBKernSizeParam::deduce_algo_dat |
|
|
|
} |
|
|
|
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { |
|
|
|
return ConvolutionImpl::AlgoDataType::QUINT8X8X32; |
|
|
|
} else if (src_type.enumv() == DTypeEnum::QuantizedS4) { |
|
|
|
} else if ( |
|
|
|
src_type.enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
src_type.enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
return ConvolutionImpl::AlgoDataType::QINT4x4x32; |
|
|
|
} else { |
|
|
|
megdnn_throw(ssprintf( |
|
|
@@ -477,7 +479,8 @@ void ConvolutionBackwardDataImpl::exec( |
|
|
|
_megdnn_workspace workspace) { |
|
|
|
if (param().format == param::Convolution::Format::NHWCD4 || |
|
|
|
param().format == param::Convolution::Format::NCHW4 || |
|
|
|
(param().format == param::Convolution::Format::NCHW && |
|
|
|
((param().format == param::Convolution::Format::NCHW || |
|
|
|
param().format == param::Convolution::Format::NHWC) && |
|
|
|
grad.layout.dtype.enumv() == DTypeEnum::QuantizedS8)) { |
|
|
|
return naive::ConvolutionBackwardDataImpl::exec(filter, diff, grad, workspace); |
|
|
|
} |
|
|
@@ -499,7 +502,8 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( |
|
|
|
|
|
|
|
if (param().format == param::Convolution::Format::NHWCD4 || |
|
|
|
param().format == param::Convolution::Format::NCHW4 || |
|
|
|
(param().format == param::Convolution::Format::NCHW && |
|
|
|
((param().format == param::Convolution::Format::NCHW || |
|
|
|
param().format == param::Convolution::Format::NHWC) && |
|
|
|
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { |
|
|
|
return naive::ConvolutionBackwardDataImpl::get_workspace_in_bytes( |
|
|
|
filter, diff, grad); |
|
|
@@ -514,7 +518,8 @@ std::vector<ConvolutionBackwardDataImpl::Algorithm*> ConvolutionBackwardDataImpl |
|
|
|
const TensorLayout& grad) { |
|
|
|
if (param().format == param::Convolution::Format::NHWCD4 || |
|
|
|
param().format == param::Convolution::Format::NCHW4 || |
|
|
|
(param().format == param::Convolution::Format::NCHW && |
|
|
|
((param().format == param::Convolution::Format::NCHW || |
|
|
|
param().format == param::Convolution::Format::NHWC) && |
|
|
|
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { |
|
|
|
return naive::ConvolutionBackwardDataImpl::get_all_algorithms( |
|
|
|
filter, diff, grad); |
|
|
@@ -541,7 +546,8 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: |
|
|
|
const AlgoAttribute& negative_attr) { |
|
|
|
if (param().format == param::Convolution::Format::NHWCD4 || |
|
|
|
param().format == param::Convolution::Format::NCHW4 || |
|
|
|
(param().format == param::Convolution::Format::NCHW && |
|
|
|
((param().format == param::Convolution::Format::NCHW || |
|
|
|
param().format == param::Convolution::Format::NHWC) && |
|
|
|
grad.dtype.enumv() == DTypeEnum::QuantizedS8)) { |
|
|
|
return naive::ConvolutionBackwardDataImpl::get_algorithm_heuristic( |
|
|
|
filter, diff, grad, workspace_limit_in_bytes, positive_attr, |
|
|
|