GitOrigin-RevId: 29cd73f87b
release-1.2
@@ -39,7 +39,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
'NCHW44','NCHW44_DOT', | 'NCHW44','NCHW44_DOT', | ||||
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | ||||
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | ||||
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | |||||
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | |||||
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), | |||||
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), | |||||
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), | |||||
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | ||||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | ||||
) | ) | ||||
@@ -48,38 +48,52 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv()); | megdnn_assert(src.dtype.enumv() == filter.dtype.enumv()); | ||||
} | } | ||||
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { | if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { | ||||
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | |||||
float scale_filter = 0.f; | |||||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
//!int8 winogradf23_44 using float,QuantizedS32 take the scale | |||||
scale_filter = filter.dtype.param<dtype::QuantizedS32>().scale; | |||||
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | |||||
float scale_filter = 0.f; | |||||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
//! int8 winogradf23_44 using float,QuantizedS32 take the | |||||
//! scale | |||||
scale_filter = | |||||
filter.dtype.param<dtype::QuantizedS32>().scale; | |||||
} else { | |||||
scale_filter = | |||||
filter.dtype.param<dtype::QuantizedS16>().scale; | |||||
} | |||||
} else { | } else { | ||||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||||
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | |||||
} | } | ||||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||||
megdnn_assert( | |||||
std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||||
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src, | |||||
scale_filter, scale_bias); | |||||
} else { | } else { | ||||
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | |||||
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32); | |||||
} | } | ||||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||||
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||||
"scale_src: %f scale_filter: %f scale_bias: %f", | |||||
scale_src, scale_filter, scale_bias); | |||||
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) { | } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) { | ||||
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | |||||
float scale_filter = 0.f; | |||||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||||
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||||
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | |||||
float scale_filter = 0.f; | |||||
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | |||||
} else { | |||||
scale_filter = | |||||
filter.dtype.param<dtype::Quantized8Asymm>().scale; | |||||
} | |||||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||||
megdnn_assert( | |||||
std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||||
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src, | |||||
scale_filter, scale_bias); | |||||
} else { | } else { | ||||
scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale; | |||||
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32); | |||||
} | } | ||||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||||
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||||
"scale_src: %f scale_filter: %f scale_bias: %f", | |||||
scale_src, scale_filter, scale_bias); | |||||
} | } | ||||
auto ret = check_layout_fwd(src, filter, dst); | auto ret = check_layout_fwd(src, filter, dst); | ||||
@@ -101,7 +115,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
if (check_eq(bias, dst)) | if (check_eq(bias, dst)) | ||||
return ret; | return ret; | ||||
if (param().format == param::ConvBias::Format::NCHW || | if (param().format == param::ConvBias::Format::NCHW || | ||||
param().format == param::ConvBias::Format::NCHW_WINOGRAD) { | |||||
param().format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW4_NCHW) { | |||||
megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | ||||
bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
@@ -116,7 +131,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
} else if (param().format == param::ConvBias::Format::NCHW4 || | } else if (param().format == param::ConvBias::Format::NCHW4 || | ||||
param().format == param::ConvBias::Format::NCHW44 || | param().format == param::ConvBias::Format::NCHW44 || | ||||
param().format == param::ConvBias::Format::NCHW44_DOT || | param().format == param::ConvBias::Format::NCHW44_DOT || | ||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||||
param().format == param::ConvBias::Format::NCHW32_NCHW4) { | |||||
megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | ||||
bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
@@ -132,7 +148,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
megdnn_assert(bias.shape[2] == 1); | megdnn_assert(bias.shape[2] == 1); | ||||
megdnn_assert(bias.shape[3] == 1); | megdnn_assert(bias.shape[3] == 1); | ||||
megdnn_assert(bias.shape[4] == 8); | megdnn_assert(bias.shape[4] == 8); | ||||
} else if (param().format == param::ConvBias::Format::NCHW32) { | |||||
} else if (param().format == param::ConvBias::Format::NCHW32 || | |||||
param().format == param::ConvBias::Format::NCHW4_NCHW32) { | |||||
megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | ||||
bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
@@ -163,6 +180,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
param::ConvBias::Format::NCHW88_WINOGRAD); | param::ConvBias::Format::NCHW88_WINOGRAD); | ||||
megdnn_assert(param().format != | megdnn_assert(param().format != | ||||
param::ConvBias::Format::NCHW44_WINOGRAD); | param::ConvBias::Format::NCHW44_WINOGRAD); | ||||
megdnn_assert(param().format != param::ConvBias::Format::NCHW4_NCHW32); | |||||
megdnn_assert(param().format != param::ConvBias::Format::NCHW32_NCHW4); | |||||
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | ||||
megdnn_assert(z.eq_shape(dst)); | megdnn_assert(z.eq_shape(dst)); | ||||
} | } | ||||
@@ -443,7 +443,10 @@ void make_canonized_filter_meta_nchwx( | |||||
*/ | */ | ||||
megdnn_assert(param.format == Param::Format::NCHW4 || | megdnn_assert(param.format == Param::Format::NCHW4 || | ||||
param.format == Param::Format::NCHW8 || | param.format == Param::Format::NCHW8 || | ||||
param.format == Param::Format::NCHW32); | |||||
param.format == Param::Format::NCHW32 || | |||||
param.format == Param::Format::NCHW4_NCHW || | |||||
param.format == Param::Format::NCHW4_NCHW32 || | |||||
param.format == Param::Format::NCHW32_NCHW4); | |||||
auto img_ndim = src_ndim - 3; | auto img_ndim = src_ndim - 3; | ||||
size_t flt_start = 0, flt_spatial_start = 2; | size_t flt_start = 0, flt_spatial_start = 2; | ||||
if (param.sparse == Param::Sparse::DENSE) { | if (param.sparse == Param::Sparse::DENSE) { | ||||
@@ -568,7 +571,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||||
make_canonized_filter_meta_nhwcd4<Parameter>(src_ndim, filter, | make_canonized_filter_meta_nhwcd4<Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
} | } | ||||
} else if (param().format == Param::Format::NCHW4) { | |||||
} else if (param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW4_NCHW || | |||||
param().format == Param::Format::NCHW4_NCHW32) { | |||||
make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
} else if (param().format == Param::Format::NCHW8) { | } else if (param().format == Param::Format::NCHW8) { | ||||
@@ -583,7 +588,8 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||||
param().format == Param::Format::NCHW44_WINOGRAD) { | param().format == Param::Format::NCHW44_WINOGRAD) { | ||||
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
} else if (param().format == Param::Format::NCHW32) { | |||||
} else if (param().format == Param::Format::NCHW32 || | |||||
param().format == Param::Format::NCHW32_NCHW4) { | |||||
make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, | ||||
param(), ret); | param(), ret); | ||||
} else if (param().format == Param::Format::CHWN4) { | } else if (param().format == Param::Format::CHWN4) { | ||||
@@ -627,6 +633,9 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||||
if (dst.valid() && dst.enumv() == src.enumv()) { | if (dst.valid() && dst.enumv() == src.enumv()) { | ||||
supported_dst_dtype.push_back(dst); | supported_dst_dtype.push_back(dst); | ||||
} | } | ||||
if (src.enumv() == DTypeEnum::QuantizedS8) { | |||||
supported_dst_dtype.push_back(dtype::Float32()); | |||||
} | |||||
} else if (src.enumv() == DTypeEnum::QuantizedS32) { | } else if (src.enumv() == DTypeEnum::QuantizedS32) { | ||||
//! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src) | //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src) | ||||
megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8); | megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8); | ||||
@@ -697,10 +706,13 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
} else { | } else { | ||||
megdnn_assert(param().format == Param::Format::NHWCD4 || | megdnn_assert(param().format == Param::Format::NHWCD4 || | ||||
param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
param().format == Param::Format::NCHW4_NCHW || | |||||
param().format == Param::Format::NCHW4_NCHW32 || | |||||
param().format == Param::Format::NCHW44 || | param().format == Param::Format::NCHW44 || | ||||
param().format == Param::Format::NCHW44_DOT || | param().format == Param::Format::NCHW44_DOT || | ||||
param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
param().format == Param::Format::NCHW32 || | param().format == Param::Format::NCHW32 || | ||||
param().format == Param::Format::NCHW32_NCHW4 || | |||||
param().format == Param::Format::NCHW88 || | param().format == Param::Format::NCHW88 || | ||||
param().format == Param::Format::NCHW88_WINOGRAD || | param().format == Param::Format::NCHW88_WINOGRAD || | ||||
param().format == Param::Format::NCHW44_WINOGRAD || | param().format == Param::Format::NCHW44_WINOGRAD || | ||||
@@ -720,13 +732,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
filter.ndim == img_dim + 4 || | filter.ndim == img_dim + 4 || | ||||
filter.ndim == img_dim + 5, | filter.ndim == img_dim + 5, | ||||
"%s", errmsg().c_str()); | "%s", errmsg().c_str()); | ||||
if (param().format == Param::Format::NCHW4) { | |||||
if (param().format == Param::Format::NCHW4 || | |||||
param().format == Param::Format::NCHW4_NCHW || | |||||
param().format == Param::Format::NCHW4_NCHW32) { | |||||
megdnn_assert(src.ndim == 5 && | megdnn_assert(src.ndim == 5 && | ||||
(filter.ndim == 5 || filter.ndim == 6 || | (filter.ndim == 5 || filter.ndim == 6 || | ||||
filter.ndim == 7) && | filter.ndim == 7) && | ||||
src[src.ndim - 1] == 4 && | src[src.ndim - 1] == 4 && | ||||
filter[filter.ndim - 1] == 4, | filter[filter.ndim - 1] == 4, | ||||
"NCHW4 require src and filter's ndim is 5 or 6, and " | |||||
"NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and " | |||||
"filter's ndim is " | |||||
"5 or 6, and " | |||||
"last shape " | "last shape " | ||||
"is 4 " | "is 4 " | ||||
"but got src %s, filter %s", | "but got src %s, filter %s", | ||||
@@ -742,15 +758,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
"but got src %s, filter %s", | "but got src %s, filter %s", | ||||
src.to_string().c_str(), filter.to_string().c_str()); | src.to_string().c_str(), filter.to_string().c_str()); | ||||
} | } | ||||
if (param().format == Param::Format::NCHW32) { | |||||
megdnn_assert( | |||||
src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && | |||||
src[src.ndim - 1] == 32 && | |||||
filter[filter.ndim - 1] == 32, | |||||
"NCHW32 require src and filter's ndim is 5 or 6, and last " | |||||
"shape is 32 " | |||||
"but got src %s, filter %s", | |||||
src.to_string().c_str(), filter.to_string().c_str()); | |||||
if (param().format == Param::Format::NCHW32 || | |||||
param().format == Param::Format::NCHW32_NCHW4) { | |||||
megdnn_assert(src.ndim == 5 && | |||||
(filter.ndim == 5 || filter.ndim == 6) && | |||||
src[src.ndim - 1] == 32 && | |||||
filter[filter.ndim - 1] == 32, | |||||
"NCHW32/NCHW32_NCHW4 require src and filter's ndim " | |||||
"is 5 or 6, and last " | |||||
"shape is 32 " | |||||
"but got src %s, filter %s", | |||||
src.to_string().c_str(), filter.to_string().c_str()); | |||||
} | } | ||||
if (param().format == Param::Format::NCHW88 || | if (param().format == Param::Format::NCHW88 || | ||||
param().format == Param::Format::NCHW88_WINOGRAD) { | param().format == Param::Format::NCHW88_WINOGRAD) { | ||||
@@ -943,6 +961,55 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1], | dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1], | ||||
cflt.stride[1], cflt.padding[1]); | cflt.stride[1], cflt.padding[1]); | ||||
dst[4] = 4; | dst[4] = 4; | ||||
} else if (param().format == Param::Format::NCHW4_NCHW) { | |||||
megdnn_assert(src.ndim == 5, | |||||
"invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", | |||||
src.ndim); | |||||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, | |||||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||||
cflt.group); | |||||
dst.ndim = 4; | |||||
dst[0] = src[0]; | |||||
auto oc = cflt.ocpg * cflt.group; | |||||
dst[1] = oc; | |||||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||||
cflt.stride[0], cflt.padding[0]); | |||||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||||
cflt.stride[1], cflt.padding[1]); | |||||
} else if (param().format == Param::Format::NCHW4_NCHW32) { | |||||
megdnn_assert(src.ndim == 5, | |||||
"invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", | |||||
src.ndim); | |||||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, | |||||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||||
cflt.group); | |||||
dst.ndim = src.ndim; | |||||
dst[0] = src[0]; | |||||
auto oc = cflt.ocpg * cflt.group; | |||||
megdnn_assert(oc % 32 == 0); | |||||
dst[1] = oc / 32; | |||||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||||
cflt.stride[0], cflt.padding[0]); | |||||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||||
cflt.stride[1], cflt.padding[1]); | |||||
dst[4] = 32; | |||||
} else if (param().format == Param::Format::NCHW32_NCHW4) { | |||||
megdnn_assert(src.ndim == 5, | |||||
"invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", | |||||
src.ndim); | |||||
megdnn_assert(cflt.icpg * cflt.group == src[1] * 32, | |||||
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||||
cflt.group); | |||||
dst.ndim = src.ndim; | |||||
dst[0] = src[0]; | |||||
auto oc = cflt.ocpg * cflt.group; | |||||
megdnn_assert(oc % 4 == 0); | |||||
dst[1] = oc / 4; | |||||
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||||
cflt.stride[0], cflt.padding[0]); | |||||
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||||
cflt.stride[1], cflt.padding[1]); | |||||
dst[4] = 4; | |||||
} else { | } else { | ||||
megdnn_assert(param().format == Param::Format::NHWCD4); | megdnn_assert(param().format == Param::Format::NHWCD4); | ||||
megdnn_assert(src.ndim == 5, | megdnn_assert(src.ndim == 5, | ||||
@@ -31,6 +31,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||||
args.bias_layout->eq_shape(*args.dst_layout)) | args.bias_layout->eq_shape(*args.dst_layout)) | ||||
return false; | return false; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
if (param.format == param::ConvBias::Format::NCHW4_NCHW32 || | |||||
param.format == param::ConvBias::Format::NCHW32_NCHW4) | |||||
return false; | |||||
if (param.format == param::ConvBias::Format::NCHW && | if (param.format == param::ConvBias::Format::NCHW && | ||||
(param.dilate_h != 1 || param.dilate_w != 1) && | (param.dilate_h != 1 || param.dilate_w != 1) && | ||||
m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { | m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { | ||||
@@ -152,16 +155,24 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
} | } | ||||
}; | }; | ||||
megdnn_assert(args.src_layout->dtype.category() == | |||||
args.dst_layout->dtype.category() && | |||||
args.src_tensor->layout.dtype.category() == | |||||
args.filter_layout->dtype.category()); | |||||
auto src_dtype = args.src_layout->dtype, | |||||
filter_dtype = args.filter_layout->dtype, | |||||
dst_dtype = args.dst_layout->dtype; | |||||
megdnn_assert( | |||||
(src_dtype.category() == dst_dtype.category()) || | |||||
(args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW && | |||||
src_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
dst_dtype.enumv() == DTypeEnum::Float32)); | |||||
megdnn_assert(src_dtype.category() == filter_dtype.category()); | |||||
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) { | if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) { | ||||
auto expected_bias_scale = get_scale(args.src_layout->dtype) * | auto expected_bias_scale = get_scale(args.src_layout->dtype) * | ||||
get_scale(args.filter_layout->dtype); | get_scale(args.filter_layout->dtype); | ||||
alpha = expected_bias_scale / get_scale(args.dst_layout->dtype); | |||||
if (args.z_layout->ndim > 0) { | |||||
alpha = expected_bias_scale; | |||||
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) | |||||
alpha /= get_scale(args.dst_layout->dtype); | |||||
if (args.z_layout->ndim > 0 && | |||||
args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||||
beta = get_scale(args.z_layout->dtype) / | beta = get_scale(args.z_layout->dtype) / | ||||
get_scale(args.dst_layout->dtype); | get_scale(args.dst_layout->dtype); | ||||
} | } | ||||
@@ -232,10 +243,23 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
break; | break; | ||||
case param::ConvBias::NonlineMode::H_SWISH: { | case param::ConvBias::NonlineMode::H_SWISH: { | ||||
megdnn_assert(args.dst_layout->dtype.category() == | megdnn_assert(args.dst_layout->dtype.category() == | ||||
DTypeCategory::QUANTIZED); | |||||
auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>(); | |||||
elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH; | |||||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||||
DTypeCategory::QUANTIZED || | |||||
(args.dst_layout->dtype.category() == | |||||
DTypeCategory::FLOAT && | |||||
args.opr->param().format == | |||||
param::ConvBias::Format::NCHW4_NCHW)); | |||||
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) { | |||||
auto&& elem_opr = | |||||
args.handle->create_operator<ElemwiseMultiType>(); | |||||
elem_opr->param().mode = | |||||
ElemwiseMultiType::Param::Mode::QH_SWISH; | |||||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||||
} else { | |||||
auto&& elem_opr = | |||||
args.handle->create_operator<ElemwiseForward>(); | |||||
elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH; | |||||
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); | |||||
} | |||||
break; | break; | ||||
} | } | ||||
default: | default: | ||||
@@ -171,7 +171,8 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { | |||||
bool check_bias_share_in_channel(const TensorLayout& bias, | bool check_bias_share_in_channel(const TensorLayout& bias, | ||||
const param::ConvBias::Format format) { | const param::ConvBias::Format format) { | ||||
bool share_in_channel = false; | bool share_in_channel = false; | ||||
if (format == param::ConvBias::Format::NCHW) { | |||||
if (format == param::ConvBias::Format::NCHW || | |||||
format == param::ConvBias::Format::NCHW4_NCHW) { | |||||
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && | share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && | ||||
bias[3] == 1); | bias[3] == 1); | ||||
} else if (format == param::ConvBias::Format::NHWC) { | } else if (format == param::ConvBias::Format::NHWC) { | ||||
@@ -179,7 +180,9 @@ bool check_bias_share_in_channel(const TensorLayout& bias, | |||||
bias[2] == 1); | bias[2] == 1); | ||||
} else if (format == param::ConvBias::Format::NCHW4 || | } else if (format == param::ConvBias::Format::NCHW4 || | ||||
format == param::ConvBias::Format::NCHW8 || | format == param::ConvBias::Format::NCHW8 || | ||||
format == param::ConvBias::Format::NCHW32) { | |||||
format == param::ConvBias::Format::NCHW32 || | |||||
format == param::ConvBias::Format::NCHW4_NCHW32 || | |||||
format == param::ConvBias::Format::NCHW32_NCHW4) { | |||||
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && | share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && | ||||
bias[3] == 1); | bias[3] == 1); | ||||
} else if (format == param::ConvBias::Format::NHWCD4) { | } else if (format == param::ConvBias::Format::NHWCD4) { | ||||
@@ -72,12 +72,19 @@ namespace conv_bias { | |||||
const TensorLayout& dst, const TensorLayout& bias, | const TensorLayout& dst, const TensorLayout& bias, | ||||
const TensorLayout& z, | const TensorLayout& z, | ||||
const param::ConvBias& param) { | const param::ConvBias& param) { | ||||
src_desc.set(src, param.format); | |||||
using Format = param::ConvBias::Format; | |||||
Format src_format, dst_format; | |||||
src_format = dst_format = param.format; | |||||
if (param.format == Format::NCHW4_NCHW) { | |||||
src_format = Format::NCHW4; | |||||
dst_format = Format::NCHW; | |||||
} | |||||
src_desc.set(src, src_format); | |||||
filter_desc.set(filter); | filter_desc.set(filter); | ||||
if (z.ndim > 0) { | if (z.ndim > 0) { | ||||
z_desc.set(z, param.format); | |||||
z_desc.set(z, dst_format); | |||||
} | } | ||||
dst_desc.set(dst, param.format); | |||||
dst_desc.set(dst, dst_format); | |||||
conv_desc.set_conv_bias(src.dtype, param, filter.group); | conv_desc.set_conv_bias(src.dtype, param, filter.group); | ||||
// cudnn requires the bias to be float tensor. | // cudnn requires the bias to be float tensor. | ||||
@@ -91,6 +98,12 @@ namespace conv_bias { | |||||
float_bias_layout[1] * float_bias_layout[4], | float_bias_layout[1] * float_bias_layout[4], | ||||
float_bias_layout[2], float_bias_layout[3]}); | float_bias_layout[2], float_bias_layout[3]}); | ||||
bias_desc.set(float_bias_layout); | bias_desc.set(float_bias_layout); | ||||
} else if (param.format == param::ConvBias::Format::NCHW4_NCHW) { | |||||
megdnn_assert(float_bias_layout.ndim == 4, | |||||
"NCHW4_NCHW format assumes bias tensor is stored " | |||||
"in NCHW layout, ndim(expected:4,got:%zu)", | |||||
float_bias_layout.ndim); | |||||
bias_desc.set(float_bias_layout); | |||||
} else { | } else { | ||||
bias_desc.set(float_bias_layout, param.format); | bias_desc.set(float_bias_layout, param.format); | ||||
} | } | ||||
@@ -99,9 +112,16 @@ namespace conv_bias { | |||||
void set_conv(const TensorLayout& src, | void set_conv(const TensorLayout& src, | ||||
const CanonizedFilterMeta& filter, | const CanonizedFilterMeta& filter, | ||||
const TensorLayout& dst, const param::ConvBias& param) { | const TensorLayout& dst, const param::ConvBias& param) { | ||||
src_desc.set(src, param.format); | |||||
using Format = param::ConvBias::Format; | |||||
Format src_format, dst_format; | |||||
src_format = dst_format = param.format; | |||||
if (param.format == Format::NCHW4_NCHW) { | |||||
src_format = Format::NCHW4; | |||||
dst_format = Format::NCHW; | |||||
} | |||||
src_desc.set(src, src_format); | |||||
filter_desc.set(filter); | filter_desc.set(filter); | ||||
dst_desc.set(dst, param.format); | |||||
dst_desc.set(dst, dst_format); | |||||
conv_desc.set_conv(src.dtype, param, filter.group); | conv_desc.set_conv(src.dtype, param, filter.group); | ||||
} | } | ||||
}; | }; | ||||
@@ -187,11 +187,15 @@ void FilterDesc<Param>::set( | |||||
megdnn_assert(filter_meta.group == 1); | megdnn_assert(filter_meta.group == 1); | ||||
#endif | #endif | ||||
auto filter_format = filter_meta.format; | |||||
if (filter_format == param::ConvBias::Format::NCHW4_NCHW) { | |||||
filter_format = param::ConvBias::Format::NCHW4; | |||||
} | |||||
// cuDNN version 6 or below filter_meta.group always is 1. | // cuDNN version 6 or below filter_meta.group always is 1. | ||||
// So it is compatible for all cuDNN versions. | // So it is compatible for all cuDNN versions. | ||||
cudnn_check(cudnnSetFilter4dDescriptor( | cudnn_check(cudnnSetFilter4dDescriptor( | ||||
desc, to_cudnn_dtype(filter_meta.dtype, filter_meta.format), | |||||
to_cudnn_format(filter_meta.format), | |||||
desc, to_cudnn_dtype(filter_meta.dtype, filter_format), | |||||
to_cudnn_format(filter_format), | |||||
filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1 | filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1 | ||||
filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1])); | filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1])); | ||||
} | } | ||||
@@ -203,6 +203,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
DISPATCH(Int8, Int16) | DISPATCH(Int8, Int16) | ||||
DISPATCH(Int8, Int32) | DISPATCH(Int8, Int32) | ||||
DISPATCH(QuantizedS8, QuantizedS32) | DISPATCH(QuantizedS8, QuantizedS32) | ||||
DISPATCH(QuantizedS8, Float32) | |||||
DISPATCH(Quantized8Asymm, QuantizedS32) | DISPATCH(Quantized8Asymm, QuantizedS32) | ||||
DISPATCH(Quantized4Asymm, QuantizedS32) | DISPATCH(Quantized4Asymm, QuantizedS32) | ||||
DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32, | DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32, | ||||
@@ -66,6 +66,15 @@ inline void StrategyFwd::on(dt_quint8& s, dt_quint8& f, dt_qint32& d, | |||||
} | } | ||||
template <> | template <> | ||||
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_float32& d, | |||||
DType src_dt, DType filt_dt, DType) { | |||||
auto cast = [](const dt_qint8& val, DType dt) { | |||||
return dt.param<dtype::QuantizedS8>().dequantize(val); | |||||
}; | |||||
d += cast(s, src_dt) * cast(f, filt_dt); | |||||
} | |||||
template <> | |||||
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType, | inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType, | ||||
DType, DType) { | DType, DType) { | ||||
auto cast = [](const dt_qint8& val) { | auto cast = [](const dt_qint8& val) { | ||||
@@ -149,8 +158,11 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
filter_meta.format == Format::NCHW44 || | filter_meta.format == Format::NCHW44 || | ||||
filter_meta.format == Format::NCHW44_DOT || | filter_meta.format == Format::NCHW44_DOT || | ||||
filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
filter_meta.format == Format::NCHW4_NCHW || | |||||
filter_meta.format == Format::NCHW4_NCHW32 || | |||||
filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
filter_meta.format == Format::NCHW32) { | |||||
filter_meta.format == Format::NCHW32 || | |||||
filter_meta.format == Format::NCHW32_NCHW4) { | |||||
spatial_start = 2; | spatial_start = 2; | ||||
channel_pos = 1; | channel_pos = 1; | ||||
batch_pos = 0; | batch_pos = 0; | ||||
@@ -176,20 +188,25 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
if (filter_meta.format == Format::NCHW4 || | if (filter_meta.format == Format::NCHW4 || | ||||
filter_meta.format == Format::CHWN4 || | filter_meta.format == Format::CHWN4 || | ||||
filter_meta.format == Format::NCHW44_DOT || | filter_meta.format == Format::NCHW44_DOT || | ||||
filter_meta.format == Format::NCHW44) { | |||||
filter_meta.format == Format::NCHW44 || | |||||
filter_meta.format == Format::NCHW32_NCHW4) { | |||||
OC *= 4; | OC *= 4; | ||||
} else if (filter_meta.format == Format::NCHW8 || | } else if (filter_meta.format == Format::NCHW8 || | ||||
filter_meta.format == Format::NCHW88) { | filter_meta.format == Format::NCHW88) { | ||||
OC *= 8; | OC *= 8; | ||||
} else if (filter_meta.format == Format::NCHW32) { | |||||
} else if (filter_meta.format == Format::NCHW32 || | |||||
filter_meta.format == Format::NCHW4_NCHW32) { | |||||
OC *= 32; | OC *= 32; | ||||
} | } | ||||
size_t FS_G, FS_OC, FS_IC, FS_SPATIAL; | size_t FS_G, FS_OC, FS_IC, FS_SPATIAL; | ||||
if (filter_meta.format == Format::NCHW || | if (filter_meta.format == Format::NCHW || | ||||
filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
filter_meta.format == Format::NCHW4_NCHW || | |||||
filter_meta.format == Format::NCHW4_NCHW32 || | |||||
filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
filter_meta.format == Format::NCHW32) { | |||||
filter_meta.format == Format::NCHW32 || | |||||
filter_meta.format == Format::NCHW32_NCHW4) { | |||||
// g, oc, ic, fh, fw | // g, oc, ic, fh, fw | ||||
FS_SPATIAL = 1; | FS_SPATIAL = 1; | ||||
FS_IC = FH * FW; | FS_IC = FH * FW; | ||||
@@ -299,10 +316,39 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | ||||
h * layout.stride[2] + w * layout.stride[3] + | h * layout.stride[2] + w * layout.stride[3] + | ||||
(c & 0x1F) * layout.stride[4]; | (c & 0x1F) * layout.stride[4]; | ||||
} else if (filter_meta.format == Format::NCHW32_NCHW4) { | |||||
if (is_output) { | |||||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3] + | |||||
(c & 0b11) * layout.stride[4]; | |||||
} else { | |||||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3] + | |||||
(c & 0x1F) * layout.stride[4]; | |||||
} | |||||
} else if (filter_meta.format == Format::CHWN4) { | } else if (filter_meta.format == Format::CHWN4) { | ||||
return (c / 4) * layout.stride[0] + h * layout.stride[1] + | return (c / 4) * layout.stride[0] + h * layout.stride[1] + | ||||
w * layout.stride[2] + n * layout.stride[3] + | w * layout.stride[2] + n * layout.stride[3] + | ||||
(c % 4) * layout.stride[4]; | (c % 4) * layout.stride[4]; | ||||
} else if (filter_meta.format == Format::NCHW4_NCHW) { | |||||
if (is_output) { | |||||
return n * layout.stride[0] + c * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3]; | |||||
} else { | |||||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3] + | |||||
(c & 0b11) * layout.stride[4]; | |||||
} | |||||
} else if (filter_meta.format == Format::NCHW4_NCHW32) { | |||||
if (is_output) { | |||||
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3] + | |||||
(c & 0x1F) * layout.stride[4]; | |||||
} else { | |||||
return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||||
h * layout.stride[2] + w * layout.stride[3] + | |||||
(c & 0b11) * layout.stride[4]; | |||||
} | |||||
} else { | } else { | ||||
megdnn_assert(filter_meta.format == Format::NCHW4, | megdnn_assert(filter_meta.format == Format::NCHW4, | ||||
"invalid conv format"); | "invalid conv format"); | ||||
@@ -314,7 +360,9 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0, | auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0, | ||||
size_t fh, size_t fw) { | size_t fh, size_t fw) { | ||||
if (filter_meta.format == Format::NCHW4) { | |||||
if (filter_meta.format == Format::NCHW4 || | |||||
filter_meta.format == Format::NCHW4_NCHW || | |||||
filter_meta.format == Format::NCHW4_NCHW32) { | |||||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
(ic - ic0) / 4 * FS_IC * 4 + | (ic - ic0) / 4 * FS_IC * 4 + | ||||
(fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11); | (fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11); | ||||
@@ -322,7 +370,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
(ic - ic0) / 8 * FS_IC * 8 + | (ic - ic0) / 8 * FS_IC * 8 + | ||||
(fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111); | (fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111); | ||||
} else if (filter_meta.format == Format::NCHW32) { | |||||
} else if (filter_meta.format == Format::NCHW32 || | |||||
filter_meta.format == Format::NCHW32_NCHW4) { | |||||
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
(ic - ic0) / 32 * FS_IC * 32 + | (ic - ic0) / 32 * FS_IC * 32 + | ||||
(fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F); | (fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F); | ||||
@@ -569,12 +618,16 @@ template <typename stype, typename ftype, typename dtype, typename comp_type> | |||||
// Forward convolution entry point: validates the filter metadata and hands
// the actual work to compute2d using the StrategyFwd strategy.
//
// The template parameters stype, ftype, dtype and comp_type come from the
// template header that immediately precedes this definition.
//
// src          input tensor, passed through to compute2d unchanged
// fptr         filter data; compute2d takes a non-const ftype*, hence the
//              const_cast below
// dst          output tensor, passed through to compute2d unchanged
// filter_meta  canonized filter description; must describe a 2-D convolution
//              in one of the formats listed in the assertion below
void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
             const Convolution::CanonizedFilterMeta& filter_meta) {
    megdnn_assert(filter_meta.spatial_ndim == 2);
    // Exactly one format check: besides the single-layout formats, the
    // mixed-layout formats NCHW4_NCHW, NCHW4_NCHW32 and NCHW32_NCHW4 are
    // accepted. A second, narrower assertion must not precede this one, or
    // those three formats would be rejected before reaching compute2d.
    megdnn_assert(
            filter_meta.format == param::Convolution::Format::NCHW ||
            filter_meta.format == param::Convolution::Format::NHWC ||
            filter_meta.format == param::Convolution::Format::NCHW88 ||
            filter_meta.format == param::Convolution::Format::NCHW44 ||
            filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
            filter_meta.format == param::Convolution::Format::NCHW4 ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW ||
            filter_meta.format == param::Convolution::Format::NCHW4_NCHW32 ||
            filter_meta.format == param::Convolution::Format::NCHW32_NCHW4);
    compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
            src, const_cast<ftype*>(fptr), dst, filter_meta);
}
@@ -631,8 +684,11 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
case param::Convolution::Format::NCHW44_DOT: | case param::Convolution::Format::NCHW44_DOT: | ||||
case param::Convolution::Format::NHWC: | case param::Convolution::Format::NHWC: | ||||
case param::Convolution::Format::NCHW4: | case param::Convolution::Format::NCHW4: | ||||
case param::Convolution::Format::NCHW4_NCHW: | |||||
case param::Convolution::Format::NCHW4_NCHW32: | |||||
case param::Convolution::Format::NCHW8: | case param::Convolution::Format::NCHW8: | ||||
case param::Convolution::Format::NCHW32: | case param::Convolution::Format::NCHW32: | ||||
case param::Convolution::Format::NCHW32_NCHW4: | |||||
case param::Convolution::Format::CHWN4: | case param::Convolution::Format::CHWN4: | ||||
compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta, | compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta, | ||||
FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst, | FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst, | ||||
@@ -666,7 +722,8 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
using Format = param::ConvBias::Format; | using Format = param::ConvBias::Format; | ||||
switch (filter_meta.format) { | switch (filter_meta.format) { | ||||
case Format::NCHW: { | |||||
case Format::NCHW: | |||||
case Format::NCHW4_NCHW: { | |||||
int dst_batch = dst.layout.shape[0]; | int dst_batch = dst.layout.shape[0]; | ||||
int dst_channel = dst.layout.shape[1]; | int dst_channel = dst.layout.shape[1]; | ||||
int chann_stride = dst.layout.shape[2] * dst.layout.shape[3]; | int chann_stride = dst.layout.shape[2] * dst.layout.shape[3]; | ||||
@@ -707,6 +764,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
} while (0) | } while (0) | ||||
case Format::NCHW44: | case Format::NCHW44: | ||||
case Format::NCHW44_DOT: | case Format::NCHW44_DOT: | ||||
case Format::NCHW32_NCHW4: | |||||
case Format::NCHW4: { | case Format::NCHW4: { | ||||
BIAS_ADD_NCHWx(4); | BIAS_ADD_NCHWx(4); | ||||
break; | break; | ||||
@@ -715,6 +773,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
BIAS_ADD_NCHWx(8); | BIAS_ADD_NCHWx(8); | ||||
break; | break; | ||||
}; | }; | ||||
case Format::NCHW4_NCHW32: | |||||
case Format::NCHW32: { | case Format::NCHW32: { | ||||
BIAS_ADD_NCHWx(32); | BIAS_ADD_NCHWx(32); | ||||
break; | break; | ||||
@@ -429,6 +429,62 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4) { | |||||
checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}}); | checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}}); | ||||
} | } | ||||
// Exercises ConvBiasForward through the CUDA handle with the NCHW4_NCHW
// format: tensors 0 and 1 are QuantizedS8, tensors 2..4 are Float32.
// Dense cases run first, then grouped cases, then a dense H_SWISH case.
TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4_NCHW) {
    require_compute_capability(6, 1);
    using namespace conv_bias;

    Checker<ConvBiasForward> checker(handle_cuda());
    UniformIntRNG small_int_rng{-3, 3};
    UniformFloatRNG wide_float_rng{-50, 50};

    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW4_NCHW;
    param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;

    // Data types per tensor index.
    checker.set_dtype(0, dtype::QuantizedS8(1.9980618f));
    checker.set_dtype(1, dtype::QuantizedS8(1.9980927f));
    checker.set_dtype(2, dtype::Float32());
    checker.set_dtype(3, dtype::Float32());
    checker.set_dtype(4, dtype::Float32());
    // Random generators per tensor index.
    checker.set_rng(0, &small_int_rng);
    checker.set_rng(1, &small_int_rng);
    checker.set_rng(2, &wide_float_rng);
    checker.set_rng(3, &wide_float_rng);
    checker.set_param(param);

    auto layout_opr = handle_cuda()->create_operator<ConvBias>();

    // Deduces the layout of the last-but-one tensor from the first two shapes
    // under the current `param`, then runs the checker on the full tensor
    // list. `param` is captured by reference, so later mutations are seen.
    auto check_case = [&](const TensorShapeArray& case_shapes) {
        layout_opr->param() = param;
        TensorLayout deduced_dst;
        layout_opr->deduce_layout({case_shapes[0], dtype::Float32()},
                                  {case_shapes[1], dtype::Float32()}, {}, {},
                                  deduced_dst);
        checker.execs({case_shapes[0], case_shapes[1], case_shapes[2],
                       deduced_dst,
                       {}});
    };

    // Dense, IDENTITY.
    check_case({{1, 4, 4, 4, 4}, {4, 4, 3, 3, 4}, {1, 4, 1, 1}});
    check_case({{20, 1, 24, 24, 4}, {24, 1, 2, 2, 4}, {1, 24, 1, 1}});
    check_case({{20, 2, 24, 24, 4}, {24, 2, 3, 3, 4}, {1, 24, 1, 1}});

    // Grouped, RELU.
    param.sparse = ConvBias::Param::Sparse::GROUP;
    param.nonlineMode = ConvBias::Param::NonlineMode::RELU;
    checker.set_param(param);
    check_case({{1, 4, 24, 24, 4}, {4, 4, 1, 1, 1, 4}, {1, 16, 1, 1}});
    check_case({{20, 8, 24, 24, 4}, {4, 24, 2, 2, 2, 4}, {1, 96, 1, 1}});
    check_case({{1, 3, 24, 24, 4}, {3, 8, 1, 3, 3, 4}, {1, 24, 1, 1}});

    // Grouped, RELU, with padding 1 and stride 2.
    param.pad_w = 1;
    param.pad_h = 1;
    param.stride_w = 2;
    param.stride_h = 2;
    checker.set_param(param);
    check_case({{10, 16, 28, 28, 4}, {8, 8, 2, 3, 3, 4}, {1, 64, 1, 1}});

    // A case that cuDNN does not support.
    param.sparse = ConvBias::Param::Sparse::DENSE;
    param.pad_w = 1;
    param.pad_h = 1;
    param.stride_w = 1;
    param.stride_h = 1;
    param.nonlineMode = ConvBias::Param::NonlineMode::H_SWISH;
    checker.set_param(param);
    checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 16, 1, 1}, {}, {}});
}
#endif | #endif | ||||
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE) { | TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE) { | ||||