GitOrigin-RevId: 29cd73f87b
release-1.2
@@ -39,7 +39,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||
'NCHW44','NCHW44_DOT', | |||
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | |||
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | |||
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | |||
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | |||
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | |||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | |||
) | |||
@@ -48,38 +48,52 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv()); | |||
} | |||
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { | |||
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_filter = 0.f; | |||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
//!int8 winogradf23_44 using float,QuantizedS32 take the scale | |||
scale_filter = filter.dtype.param<dtype::QuantizedS32>().scale; | |||
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_filter = 0.f; | |||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
//! int8 winogradf23_44 using float,QuantizedS32 take the | |||
//! scale | |||
scale_filter = | |||
filter.dtype.param<dtype::QuantizedS32>().scale; | |||
} else { | |||
scale_filter = | |||
filter.dtype.param<dtype::QuantizedS16>().scale; | |||
} | |||
} else { | |||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | |||
} | |||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||
megdnn_assert( | |||
std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src, | |||
scale_filter, scale_bias); | |||
} else { | |||
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | |||
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32); | |||
} | |||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||
"scale_src: %f scale_filter: %f scale_bias: %f", | |||
scale_src, scale_filter, scale_bias); | |||
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) { | |||
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | |||
float scale_filter = 0.f; | |||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | |||
float scale_filter = 0.f; | |||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||
} else { | |||
scale_filter = | |||
filter.dtype.param<dtype::Quantized8Asymm>().scale; | |||
} | |||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||
megdnn_assert( | |||
std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src, | |||
scale_filter, scale_bias); | |||
} else { | |||
scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale; | |||
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32); | |||
} | |||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||
"scale_src: %f scale_filter: %f scale_bias: %f", | |||
scale_src, scale_filter, scale_bias); | |||
} | |||
auto ret = check_layout_fwd(src, filter, dst); | |||
@@ -101,7 +115,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||
if (check_eq(bias, dst)) | |||
return ret; | |||
if (param().format == param::ConvBias::Format::NCHW || | |||
param().format == param::ConvBias::Format::NCHW_WINOGRAD) { | |||
param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW4_NCHW) { | |||
megdnn_assert(bias.shape[0] == 1); | |||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | |||
bias.to_string().c_str(), dst.to_string().c_str()); | |||
@@ -116,7 +131,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||
} else if (param().format == param::ConvBias::Format::NCHW4 || | |||
param().format == param::ConvBias::Format::NCHW44 || | |||
param().format == param::ConvBias::Format::NCHW44_DOT || | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||
param().format == param::ConvBias::Format::NCHW32_NCHW4) { | |||
megdnn_assert(bias.shape[0] == 1); | |||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | |||
bias.to_string().c_str(), dst.to_string().c_str()); | |||
@@ -132,7 +148,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||
megdnn_assert(bias.shape[2] == 1); | |||
megdnn_assert(bias.shape[3] == 1); | |||
megdnn_assert(bias.shape[4] == 8); | |||
} else if (param().format == param::ConvBias::Format::NCHW32) { | |||
} else if (param().format == param::ConvBias::Format::NCHW32 || | |||
param().format == param::ConvBias::Format::NCHW4_NCHW32) { | |||
megdnn_assert(bias.shape[0] == 1); | |||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | |||
bias.to_string().c_str(), dst.to_string().c_str()); | |||
@@ -163,6 +180,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||
param::ConvBias::Format::NCHW88_WINOGRAD); | |||
megdnn_assert(param().format != | |||
param::ConvBias::Format::NCHW44_WINOGRAD); | |||
megdnn_assert(param().format != param::ConvBias::Format::NCHW4_NCHW32); | |||
megdnn_assert(param().format != param::ConvBias::Format::NCHW32_NCHW4); | |||
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | |||
megdnn_assert(z.eq_shape(dst)); | |||
} | |||
@@ -443,7 +443,10 @@ void make_canonized_filter_meta_nchwx( | |||
*/ | |||
megdnn_assert(param.format == Param::Format::NCHW4 || | |||
param.format == Param::Format::NCHW8 || | |||
param.format == Param::Format::NCHW32); | |||
param.format == Param::Format::NCHW32 || | |||
param.format == Param::Format::NCHW4_NCHW || | |||
param.format == Param::Format::NCHW4_NCHW32 || | |||
param.format == Param::Format::NCHW32_NCHW4); | |||
auto img_ndim = src_ndim - 3; | |||
size_t flt_start = 0, flt_spatial_start = 2; | |||
if (param.sparse == Param::Sparse::DENSE) { | |||
@@ -568,7 +571,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||
make_canonized_filter_meta_nhwcd4<Parameter>(src_ndim, filter, | |||
param(), ret); | |||
} | |||
} else if (param().format == Param::Format::NCHW4) { | |||
} else if (param().format == Param::Format::NCHW4 || | |||
param().format == Param::Format::NCHW4_NCHW || | |||
param().format == Param::Format::NCHW4_NCHW32) { | |||
make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, | |||
param(), ret); | |||
} else if (param().format == Param::Format::NCHW8) { | |||
@@ -583,7 +588,8 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||
param().format == Param::Format::NCHW44_WINOGRAD) { | |||
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | |||
param(), ret); | |||
} else if (param().format == Param::Format::NCHW32) { | |||
} else if (param().format == Param::Format::NCHW32 || | |||
param().format == Param::Format::NCHW32_NCHW4) { | |||
make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, | |||
param(), ret); | |||
} else if (param().format == Param::Format::CHWN4) { | |||
@@ -627,6 +633,9 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||
if (dst.valid() && dst.enumv() == src.enumv()) { | |||
supported_dst_dtype.push_back(dst); | |||
} | |||
if (src.enumv() == DTypeEnum::QuantizedS8) { | |||
supported_dst_dtype.push_back(dtype::Float32()); | |||
} | |||
} else if (src.enumv() == DTypeEnum::QuantizedS32) { | |||
//! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src) | |||
megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8); | |||
@@ -697,10 +706,13 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||
} else { | |||
megdnn_assert(param().format == Param::Format::NHWCD4 || | |||
param().format == Param::Format::NCHW4 || | |||
param().format == Param::Format::NCHW4_NCHW || | |||
param().format == Param::Format::NCHW4_NCHW32 || | |||
param().format == Param::Format::NCHW44 || | |||
param().format == Param::Format::NCHW44_DOT || | |||
param().format == Param::Format::NCHW8 || | |||
param().format == Param::Format::NCHW32 || | |||
param().format == Param::Format::NCHW32_NCHW4 || | |||
param().format == Param::Format::NCHW88 || | |||
param().format == Param::Format::NCHW88_WINOGRAD || | |||
param().format == Param::Format::NCHW44_WINOGRAD || | |||
@@ -720,13 +732,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||
filter.ndim == img_dim + 4 || | |||
filter.ndim == img_dim + 5, | |||
"%s", errmsg().c_str()); | |||
if (param().format == Param::Format::NCHW4) { | |||
if (param().format == Param::Format::NCHW4 || | |||
param().format == Param::Format::NCHW4_NCHW || | |||
param().format == Param::Format::NCHW4_NCHW32) { | |||
megdnn_assert(src.ndim == 5 && | |||
(filter.ndim == 5 || filter.ndim == 6 || | |||
filter.ndim == 7) && | |||
src[src.ndim - 1] == 4 && | |||
filter[filter.ndim - 1] == 4, | |||
"NCHW4 require src and filter's ndim is 5 or 6, and " | |||
"NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and " | |||
"filter's ndim is " | |||
"5 or 6, and " | |||
"last shape " | |||
"is 4 " | |||
"but got src %s, filter %s", | |||
@@ -742,15 +758,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||
"but got src %s, filter %s", | |||
src.to_string().c_str(), filter.to_string().c_str()); | |||
} | |||
if (param().format == Param::Format::NCHW32) { | |||
megdnn_assert( | |||
src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && | |||
src[src.ndim - 1] == 32 && | |||
filter[filter.ndim - 1] == 32, | |||
"NCHW32 require src and filter's ndim is 5 or 6, and last " | |||
"shape is 32 " | |||
"but got src %s, filter %s", | |||
src.to_string().c_str(), filter.to_string().c_str()); | |||
if (param().format == Param::Format::NCHW32 || | |||
param().format == Param::Format::NCHW32_NCHW4) { | |||
megdnn_assert(src.ndim == 5 && | |||
(filter.ndim == 5 || filter.ndim == 6) && | |||
src[src.ndim - 1] == 32 && | |||
filter[filter.ndim - 1] == 32, | |||
"NCHW32/NCHW32_NCHW4 require src and filter's ndim " | |||
"is 5 or 6, and last " | |||
"shape is 32 " | |||
"but got src %s, filter %s", | |||
src.to_string().c_str(), filter.to_string().c_str()); | |||
} | |||
if (param().format == Param::Format::NCHW88 || | |||
param().format == Param::Format::NCHW88_WINOGRAD) { | |||
@@ -943,6 +961,55 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1], | |||
cflt.stride[1], cflt.padding[1]); | |||
dst[4] = 4; | |||
} else if (param().format == Param::Format::NCHW4_NCHW) { | |||
megdnn_assert(src.ndim == 5, | |||
"invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", | |||
src.ndim); | |||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, | |||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||
cflt.group); | |||
dst.ndim = 4; | |||
dst[0] = src[0]; | |||
auto oc = cflt.ocpg * cflt.group; | |||
dst[1] = oc; | |||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||
cflt.stride[0], cflt.padding[0]); | |||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||
cflt.stride[1], cflt.padding[1]); | |||
} else if (param().format == Param::Format::NCHW4_NCHW32) { | |||
megdnn_assert(src.ndim == 5, | |||
"invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", | |||
src.ndim); | |||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, | |||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||
cflt.group); | |||
dst.ndim = src.ndim; | |||
dst[0] = src[0]; | |||
auto oc = cflt.ocpg * cflt.group; | |||
megdnn_assert(oc % 32 == 0); | |||
dst[1] = oc / 32; | |||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||
cflt.stride[0], cflt.padding[0]); | |||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||
cflt.stride[1], cflt.padding[1]); | |||
dst[4] = 32; | |||
} else if (param().format == Param::Format::NCHW32_NCHW4) { | |||
megdnn_assert(src.ndim == 5, | |||
"invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", | |||
src.ndim); | |||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 32, | |||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||
cflt.group); | |||
dst.ndim = src.ndim; | |||
dst[0] = src[0]; | |||
auto oc = cflt.ocpg * cflt.group; | |||
megdnn_assert(oc % 4 == 0); | |||
dst[1] = oc / 4; | |||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||
cflt.stride[0], cflt.padding[0]); | |||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||
cflt.stride[1], cflt.padding[1]); | |||
dst[4] = 4; | |||
} else { | |||
megdnn_assert(param().format == Param::Format::NHWCD4); | |||
megdnn_assert(src.ndim == 5, | |||
@@ -31,6 +31,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||
args.bias_layout->eq_shape(*args.dst_layout)) | |||
return false; | |||
auto&& param = args.opr->param(); | |||
if (param.format == param::ConvBias::Format::NCHW4_NCHW32 || | |||
param.format == param::ConvBias::Format::NCHW32_NCHW4) | |||
return false; | |||
if (param.format == param::ConvBias::Format::NCHW && | |||
(param.dilate_h != 1 || param.dilate_w != 1) && | |||
m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { | |||
@@ -152,16 +155,24 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||
} | |||
}; | |||
megdnn_assert(args.src_layout->dtype.category() == | |||
args.dst_layout->dtype.category() && | |||
args.src_tensor->layout.dtype.category() == | |||
args.filter_layout->dtype.category()); | |||
auto src_dtype = args.src_layout->dtype, | |||
filter_dtype = args.filter_layout->dtype, | |||
dst_dtype = args.dst_layout->dtype; | |||
megdnn_assert( | |||
(src_dtype.category() == dst_dtype.category()) || | |||
(args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW && | |||
src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
dst_dtype.enumv() == DTypeEnum::Float32)); | |||
megdnn_assert(src_dtype.category() == filter_dtype.category()); | |||
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||
auto expected_bias_scale = get_scale(args.src_layout->dtype) * | |||
get_scale(args.filter_layout->dtype); | |||
alpha = expected_bias_scale / get_scale(args.dst_layout->dtype); | |||
if (args.z_layout->ndim > 0) { | |||
alpha = expected_bias_scale; | |||
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) | |||
alpha /= get_scale(args.dst_layout->dtype); | |||
if (args.z_layout->ndim > 0 && | |||
args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||
beta = get_scale(args.z_layout->dtype) / | |||
get_scale(args.dst_layout->dtype); | |||
} | |||
@@ -232,10 +243,23 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||
break; | |||
case param::ConvBias::NonlineMode::H_SWISH: { | |||
megdnn_assert(args.dst_layout->dtype.category() == | |||
DTypeCategory::QUANTIZED); | |||
auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>(); | |||
elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH; | |||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||
DTypeCategory::QUANTIZED || | |||
(args.dst_layout->dtype.category() == | |||
DTypeCategory::FLOAT && | |||
args.opr->param().format == | |||
param::ConvBias::Format::NCHW4_NCHW)); | |||
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||
auto&& elem_opr = | |||
args.handle->create_operator<ElemwiseMultiType>(); | |||
elem_opr->param().mode = | |||
ElemwiseMultiType::Param::Mode::QH_SWISH; | |||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||
} else { | |||
auto&& elem_opr = | |||
args.handle->create_operator<ElemwiseForward>(); | |||
elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH; | |||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||
} | |||
break; | |||
} | |||
default: | |||
@@ -171,7 +171,8 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { | |||
bool check_bias_share_in_channel(const TensorLayout& bias, | |||
const param::ConvBias::Format format) { | |||
bool share_in_channel = false; | |||
if (format == param::ConvBias::Format::NCHW) { | |||
if (format == param::ConvBias::Format::NCHW || | |||
format == param::ConvBias::Format::NCHW4_NCHW) { | |||
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && | |||
bias[3] == 1); | |||
} else if (format == param::ConvBias::Format::NHWC) { | |||
@@ -179,7 +180,9 @@ bool check_bias_share_in_channel(const TensorLayout& bias, | |||
bias[2] == 1); | |||
} else if (format == param::ConvBias::Format::NCHW4 || | |||
format == param::ConvBias::Format::NCHW8 || | |||
format == param::ConvBias::Format::NCHW32) { | |||
format == param::ConvBias::Format::NCHW32 || | |||
format == param::ConvBias::Format::NCHW4_NCHW32 || | |||
format == param::ConvBias::Format::NCHW32_NCHW4) { | |||
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && | |||
bias[3] == 1); | |||
} else if (format == param::ConvBias::Format::NHWCD4) { | |||
@@ -72,12 +72,19 @@ namespace conv_bias { | |||
const TensorLayout& dst, const TensorLayout& bias, | |||
const TensorLayout& z, | |||
const param::ConvBias& param) { | |||
src_desc.set(src, param.format); | |||
using Format = param::ConvBias::Format; | |||
Format src_format, dst_format; | |||
src_format = dst_format = param.format; | |||
if (param.format == Format::NCHW4_NCHW) { | |||
src_format = Format::NCHW4; | |||
dst_format = Format::NCHW; | |||
} | |||
src_desc.set(src, src_format); | |||
filter_desc.set(filter); | |||
if (z.ndim > 0) { | |||
z_desc.set(z, param.format); | |||
z_desc.set(z, dst_format); | |||
} | |||
dst_desc.set(dst, param.format); | |||
dst_desc.set(dst, dst_format); | |||
conv_desc.set_conv_bias(src.dtype, param, filter.group); | |||
// cudnn requires the bias to be float tensor. | |||
@@ -91,6 +98,12 @@ namespace conv_bias { | |||
float_bias_layout[1] * float_bias_layout[4], | |||
float_bias_layout[2], float_bias_layout[3]}); | |||
bias_desc.set(float_bias_layout); | |||
} else if (param.format == param::ConvBias::Format::NCHW4_NCHW) { | |||
megdnn_assert(float_bias_layout.ndim == 4, | |||
"NCHW4_NCHW format assumes bias tensor is stored " | |||
"in NCHW layout, ndim(expected:4,got:%zu)", | |||
float_bias_layout.ndim); | |||
bias_desc.set(float_bias_layout); | |||
} else { | |||
bias_desc.set(float_bias_layout, param.format); | |||
} | |||
@@ -99,9 +112,16 @@ namespace conv_bias { | |||
void set_conv(const TensorLayout& src, | |||
const CanonizedFilterMeta& filter, | |||
const TensorLayout& dst, const param::ConvBias& param) { | |||
src_desc.set(src, param.format); | |||
using Format = param::ConvBias::Format; | |||
Format src_format, dst_format; | |||
src_format = dst_format = param.format; | |||
if (param.format == Format::NCHW4_NCHW) { | |||
src_format = Format::NCHW4; | |||
dst_format = Format::NCHW; | |||
} | |||
src_desc.set(src, src_format); | |||
filter_desc.set(filter); | |||
dst_desc.set(dst, param.format); | |||
dst_desc.set(dst, dst_format); | |||
conv_desc.set_conv(src.dtype, param, filter.group); | |||
} | |||
}; | |||
@@ -187,11 +187,15 @@ void FilterDesc<Param>::set( | |||
megdnn_assert(filter_meta.group == 1); | |||
#endif | |||
auto filter_format = filter_meta.format; | |||
if (filter_format == param::ConvBias::Format::NCHW4_NCHW) { | |||
filter_format = param::ConvBias::Format::NCHW4; | |||
} | |||
// cuDNN version 6 or below filter_meta.group always is 1. | |||
// So it is compatible for all cuDNN versions. | |||
cudnn_check(cudnnSetFilter4dDescriptor( | |||
desc, to_cudnn_dtype(filter_meta.dtype, filter_meta.format), | |||
to_cudnn_format(filter_meta.format), | |||
desc, to_cudnn_dtype(filter_meta.dtype, filter_format), | |||
to_cudnn_format(filter_format), | |||
filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1 | |||
filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1])); | |||
} | |||
@@ -203,6 +203,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
DISPATCH(Int8, Int16) | |||
DISPATCH(Int8, Int32) | |||
DISPATCH(QuantizedS8, QuantizedS32) | |||
DISPATCH(QuantizedS8, Float32) | |||
DISPATCH(Quantized8Asymm, QuantizedS32) | |||
DISPATCH(Quantized4Asymm, QuantizedS32) | |||
DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32, | |||
@@ -66,6 +66,15 @@ inline void StrategyFwd::on(dt_quint8& s, dt_quint8& f, dt_qint32& d, | |||
} | |||
template <> | |||
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_float32& d, | |||
DType src_dt, DType filt_dt, DType) { | |||
auto cast = [](const dt_qint8& val, DType dt) { | |||
return dt.param<dtype::QuantizedS8>().dequantize(val); | |||
}; | |||
d += cast(s, src_dt) * cast(f, filt_dt); | |||
} | |||
template <> | |||
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType, | |||
DType, DType) { | |||
auto cast = [](const dt_qint8& val) { | |||
@@ -149,8 +158,11 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
filter_meta.format == Format::NCHW44 || | |||
filter_meta.format == Format::NCHW44_DOT || | |||
filter_meta.format == Format::NCHW4 || | |||
filter_meta.format == Format::NCHW4_NCHW || | |||
filter_meta.format == Format::NCHW4_NCHW32 || | |||
filter_meta.format == Format::NCHW8 || | |||
filter_meta.format == Format::NCHW32) { | |||
filter_meta.format == Format::NCHW32 || | |||
filter_meta.format == Format::NCHW32_NCHW4) { | |||
spatial_start = 2; | |||
channel_pos = 1; | |||
batch_pos = 0; | |||
@@ -176,20 +188,25 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
if (filter_meta.format == Format::NCHW4 || | |||
filter_meta.format == Format::CHWN4 || | |||
filter_meta.format == Format::NCHW44_DOT || | |||
filter_meta.format == Format::NCHW44) { | |||
filter_meta.format == Format::NCHW44 || | |||
filter_meta.format == Format::NCHW32_NCHW4) { | |||
OC *= 4; | |||
} else if (filter_meta.format == Format::NCHW8 || | |||
filter_meta.format == Format::NCHW88) { | |||
OC *= 8; | |||
} else if (filter_meta.format == Format::NCHW32) { | |||
} else if (filter_meta.format == Format::NCHW32 || | |||
filter_meta.format == Format::NCHW4_NCHW32) { | |||
OC *= 32; | |||
} | |||
size_t FS_G, FS_OC, FS_IC, FS_SPATIAL; | |||
if (filter_meta.format == Format::NCHW || | |||
filter_meta.format == Format::NCHW4 || | |||
filter_meta.format == Format::NCHW4_NCHW || | |||
filter_meta.format == Format::NCHW4_NCHW32 || | |||
filter_meta.format == Format::NCHW8 || | |||
filter_meta.format == Format::NCHW32) { | |||
filter_meta.format == Format::NCHW32 || | |||
filter_meta.format == Format::NCHW32_NCHW4) { | |||
// g, oc, ic, fh, fw | |||
FS_SPATIAL = 1; | |||
FS_IC = FH * FW; | |||
@@ -299,10 +316,39 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0x1F) * layout.stride[4]; | |||
} else if (filter_meta.format == Format::NCHW32_NCHW4) { | |||
if (is_output) { | |||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0b11) * layout.stride[4]; | |||
} else { | |||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0x1F) * layout.stride[4]; | |||
} | |||
} else if (filter_meta.format == Format::CHWN4) { | |||
return (c / 4) * layout.stride[0] + h * layout.stride[1] + | |||
w * layout.stride[2] + n * layout.stride[3] + | |||
(c % 4) * layout.stride[4]; | |||
} else if (filter_meta.format == Format::NCHW4_NCHW) { | |||
if (is_output) { | |||
return n * layout.stride[0] + c * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3]; | |||
} else { | |||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0b11) * layout.stride[4]; | |||
} | |||
} else if (filter_meta.format == Format::NCHW4_NCHW32) { | |||
if (is_output) { | |||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0x1F) * layout.stride[4]; | |||
} else { | |||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||
h * layout.stride[2] + w * layout.stride[3] + | |||
(c & 0b11) * layout.stride[4]; | |||
} | |||
} else { | |||
megdnn_assert(filter_meta.format == Format::NCHW4, | |||
"invalid conv format"); | |||
@@ -314,7 +360,9 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0, | |||
size_t fh, size_t fw) { | |||
if (filter_meta.format == Format::NCHW4) { | |||
if (filter_meta.format == Format::NCHW4 || | |||
filter_meta.format == Format::NCHW4_NCHW || | |||
filter_meta.format == Format::NCHW4_NCHW32) { | |||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | |||
(ic - ic0) / 4 * FS_IC * 4 + | |||
(fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11); | |||
@@ -322,7 +370,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | |||
(ic - ic0) / 8 * FS_IC * 8 + | |||
(fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111); | |||
} else if (filter_meta.format == Format::NCHW32) { | |||
} else if (filter_meta.format == Format::NCHW32 || | |||
filter_meta.format == Format::NCHW32_NCHW4) { | |||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | |||
(ic - ic0) / 32 * FS_IC * 32 + | |||
(fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F); | |||
@@ -569,12 +618,16 @@ template <typename stype, typename ftype, typename dtype, typename comp_type> | |||
void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst, | |||
const Convolution::CanonizedFilterMeta& filter_meta) { | |||
megdnn_assert(filter_meta.spatial_ndim == 2); | |||
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW || | |||
filter_meta.format == param::Convolution::Format::NHWC || | |||
filter_meta.format == param::Convolution::Format::NCHW88 || | |||
filter_meta.format == param::Convolution::Format::NCHW44 || | |||
filter_meta.format == param::Convolution::Format::NCHW44_DOT || | |||
filter_meta.format == param::Convolution::Format::NCHW4); | |||
megdnn_assert( | |||
filter_meta.format == param::Convolution::Format::NCHW || | |||
filter_meta.format == param::Convolution::Format::NHWC || | |||
filter_meta.format == param::Convolution::Format::NCHW88 || | |||
filter_meta.format == param::Convolution::Format::NCHW44 || | |||
filter_meta.format == param::Convolution::Format::NCHW44_DOT || | |||
filter_meta.format == param::Convolution::Format::NCHW4 || | |||
filter_meta.format == param::Convolution::Format::NCHW4_NCHW || | |||
filter_meta.format == param::Convolution::Format::NCHW4_NCHW32 || | |||
filter_meta.format == param::Convolution::Format::NCHW32_NCHW4); | |||
compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( | |||
src, const_cast<ftype*>(fptr), dst, filter_meta); | |||
} | |||
@@ -631,8 +684,11 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
case param::Convolution::Format::NCHW44_DOT: | |||
case param::Convolution::Format::NHWC: | |||
case param::Convolution::Format::NCHW4: | |||
case param::Convolution::Format::NCHW4_NCHW: | |||
case param::Convolution::Format::NCHW4_NCHW32: | |||
case param::Convolution::Format::NCHW8: | |||
case param::Convolution::Format::NCHW32: | |||
case param::Convolution::Format::NCHW32_NCHW4: | |||
case param::Convolution::Format::CHWN4: | |||
compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta, | |||
FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst, | |||
@@ -666,7 +722,8 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
using Format = param::ConvBias::Format; | |||
switch (filter_meta.format) { | |||
case Format::NCHW: { | |||
case Format::NCHW: | |||
case Format::NCHW4_NCHW: { | |||
int dst_batch = dst.layout.shape[0]; | |||
int dst_channel = dst.layout.shape[1]; | |||
int chann_stride = dst.layout.shape[2] * dst.layout.shape[3]; | |||
@@ -707,6 +764,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
} while (0) | |||
case Format::NCHW44: | |||
case Format::NCHW44_DOT: | |||
case Format::NCHW32_NCHW4: | |||
case Format::NCHW4: { | |||
BIAS_ADD_NCHWx(4); | |||
break; | |||
@@ -715,6 +773,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
BIAS_ADD_NCHWx(8); | |||
break; | |||
}; | |||
case Format::NCHW4_NCHW32: | |||
case Format::NCHW32: { | |||
BIAS_ADD_NCHWx(32); | |||
break; | |||
@@ -429,6 +429,62 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4) { | |||
checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}}); | |||
} | |||
TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4_NCHW) { | |||
require_compute_capability(6, 1); | |||
using namespace conv_bias; | |||
Checker<ConvBiasForward> checker(handle_cuda()); | |||
UniformIntRNG int_rng{-3, 3}; | |||
UniformFloatRNG float_rng{-50, 50}; | |||
ConvBias::Param param; | |||
param.format = ConvBias::Param::Format::NCHW4_NCHW; | |||
param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY; | |||
checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) | |||
.set_dtype(1, dtype::QuantizedS8(1.9980927f)) | |||
.set_dtype(2, dtype::Float32()) | |||
.set_dtype(3, dtype::Float32()) | |||
.set_dtype(4, dtype::Float32()) | |||
.set_rng(0, &int_rng) | |||
.set_rng(1, &int_rng) | |||
.set_rng(2, &float_rng) | |||
.set_rng(3, &float_rng) | |||
.set_param(param); | |||
auto opr = handle_cuda()->create_operator<ConvBias>(); | |||
auto run = [&](const TensorShapeArray& shapes) { | |||
opr->param() = param; | |||
TensorLayout dst_layout; | |||
opr->deduce_layout({shapes[0], dtype::Float32()}, | |||
{shapes[1], dtype::Float32()}, {}, {}, dst_layout); | |||
checker.execs({shapes[0], shapes[1], shapes[2], dst_layout, {}}); | |||
}; | |||
run({{1, 4, 4, 4, 4}, {4, 4, 3, 3, 4}, {1, 4, 1, 1}}); | |||
run({{20, 1, 24, 24, 4}, {24, 1, 2, 2, 4}, {1, 24, 1, 1}}); | |||
run({{20, 2, 24, 24, 4}, {24, 2, 3, 3, 4}, {1, 24, 1, 1}}); | |||
param.sparse = ConvBias::Param::Sparse::GROUP; | |||
param.nonlineMode = ConvBias::Param::NonlineMode::RELU; | |||
checker.set_param(param); | |||
run({{1, 4, 24, 24, 4}, {4, 4, 1, 1, 1, 4}, {1, 16, 1, 1}}); | |||
run({{20, 8, 24, 24, 4}, {4, 24, 2, 2, 2, 4}, {1, 96, 1, 1}}); | |||
run({{1, 3, 24, 24, 4}, {3, 8, 1, 3, 3, 4}, {1, 24, 1, 1}}); | |||
param.pad_h = param.pad_w = 1; | |||
param.stride_h = param.stride_w = 2; | |||
checker.set_param(param); | |||
run({{10, 16, 28, 28, 4}, {8, 8, 2, 3, 3, 4}, {1, 64, 1, 1}}); | |||
// case which cudnn not supported | |||
param.sparse = ConvBias::Param::Sparse::DENSE; | |||
param.pad_h = param.pad_w = 1; | |||
param.stride_h = param.stride_w = 1; | |||
param.nonlineMode = ConvBias::Param::NonlineMode::H_SWISH; | |||
checker.set_param(param); | |||
checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 16, 1, 1}, {}, {}}); | |||
} | |||
#endif | |||
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE) { | |||