GitOrigin-RevId: b8ddcd108a
tags/v1.5.0
@@ -69,6 +69,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||||
return false; | return false; | ||||
} | } | ||||
if (args.src_layout->dtype.enumv() == DTypeEnum::Float16 && | |||||
args.dst_layout->dtype.enumv() == DTypeEnum::Float16 && | |||||
param.format == param::ConvBias::Format::NHWC) { | |||||
return false; | |||||
} | |||||
//! FIXME: conv kernel of cudnn for NCHW4_NCHW tensor format causes illegal | //! FIXME: conv kernel of cudnn for NCHW4_NCHW tensor format causes illegal | ||||
//! memory access errors, so we have to disable this kernel here. | //! memory access errors, so we have to disable this kernel here. | ||||
if (param.format == param::ConvBias::Format::NCHW4_NCHW || | if (param.format == param::ConvBias::Format::NCHW4_NCHW || | ||||
@@ -151,14 +151,14 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { | |||||
if (args.handle->is_tegra_k1()) | if (args.handle->is_tegra_k1()) | ||||
return false; | return false; | ||||
// TODO: We only support NCHW format now. It seems cuDNN provides support | |||||
// for NHWC as well. | |||||
if (args.filter_meta.format == param::Convolution::Format::NCHW4) { | |||||
if (args.filter_meta.format == param::Convolution::Format::NCHW4 || | |||||
args.filter_meta.format == param::Convolution::Format::NCHW32) { | |||||
if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 && | if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 && | ||||
args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) { | args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) { | ||||
return false; | return false; | ||||
} | } | ||||
} else if (args.filter_meta.format != param::Convolution::Format::NCHW) { | |||||
} else if (args.filter_meta.format != param::Convolution::Format::NCHW && | |||||
args.filter_meta.format != param::Convolution::Format::NHWC) { | |||||
return false; | return false; | ||||
} | } | ||||
auto& fm = args.filter_meta; | auto& fm = args.filter_meta; | ||||
@@ -216,6 +216,41 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) { | |||||
} | } | ||||
} | } | ||||
//! Runs ConvBiasForward through the Checker on the CUDA handle with the NHWC
//! format, IDENTITY nonlinearity and all five tensors typed as Float16.
//! Three executions are issued, in this order:
//!   1. 1x1 filter {24, 1, 1, 4} with the Param's initial compute mode,
//!   2. the same shapes with compute_mode switched to FLOAT32,
//!   3. a GROUP-sparse case with a five-dimensional filter shape.
TEST_F(CUDA, CONV_BIAS_FORWARD_FLOAT16) {
    require_compute_capability(6, 1);

    using Param = ConvBias::Param;

    Checker<ConvBiasForward> checker(handle_cuda());

    Param param;
    param.format = Param::Format::NHWC;
    param.nonlineMode = Param::NonlineMode::IDENTITY;

    // Half precision results are compared with a tolerance of 2e-2.
    checker.set_epsilon(2e-2);
    // Tensor slots 0..4 all use Float16.
    for (size_t slot = 0; slot < 5; ++slot) {
        checker.set_dtype(slot, dtype::Float16());
    }

    // Applies the current state of `param` and executes one case; the last
    // two tensor shapes are left empty exactly as in every call below.
    auto run_case = [&](const TensorShape& src, const TensorShape& filter,
                        const TensorShape& bias) {
        checker.set_param(param).execs({src, filter, bias, {}, {}});
    };

    {
        const TensorShape src{20, 224, 224, 4};
        const TensorShape filter{24, 1, 1, 4};
        const TensorShape bias{1, 1, 1, 24};

        // First pass: compute mode as initialised by Param.
        run_case(src, filter, bias);

        // Second pass: identical shapes, FLOAT32 compute mode.
        param.compute_mode = Param::ComputeMode::FLOAT32;
        run_case(src, filter, bias);
    }

    {
        // NOTE: compute_mode is still FLOAT32 here; only `sparse` changes.
        param.sparse = Param::Sparse::GROUP;

        const TensorShape src{20, 224, 224, 16};
        // Presumably {group, out-per-group, kh, kw, in-per-group} -- TODO
        // confirm against the grouped NHWC filter layout.
        const TensorShape filter{4, 4, 1, 1, 4};
        const TensorShape bias{1, 1, 1, 16};

        run_case(src, filter, bias);
    }
}
TEST_F(CUDA, CONV_BIAS_NCHW_QS8) { | TEST_F(CUDA, CONV_BIAS_NCHW_QS8) { | ||||
//! not support NonlineMode::SIGMOID and NonlineMode::H_SWISH | //! not support NonlineMode::SIGMOID and NonlineMode::H_SWISH | ||||
require_compute_capability(6, 1); | require_compute_capability(6, 1); | ||||