diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index 6f5c9f49..edf5ddd5 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -353,7 +353,8 @@ bool megdnn::check_bias_share_in_channel(const TensorLayout& bias, format == param::ConvBias::Format::NCHW4_NCHW) { share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && bias[3] == 1); - } else if (format == param::ConvBias::Format::NHWC) { + } else if (format == param::ConvBias::Format::NHWC || + format == param::ConvBias::Format::NCHW4_NHWC) { share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && bias[2] == 1); } else if (format == param::ConvBias::Format::NCHW4 || diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp index c8b395e6..7b2d40e4 100644 --- a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp +++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp @@ -84,8 +84,12 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list( inner_dst_layout, inner_bias_layout, inner_z_layout); Param inner_conv_param = o->param(); - inner_conv_param.format = Param::Format::NCHW4; - + if (layouts[4].dtype.enumv() == DTypeEnum::Float32) { + inner_conv_param.format = Param::Format::NCHW4_NCHW; + } else { + inner_conv_param.format = Param::Format::NCHW4; + } + std::string param_str; Algorithm::serialize_write_pod(inner_conv_param, param_str); @@ -192,9 +196,9 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( inner_conv_param.format = dst_float ? Param::Format::NCHW4_NCHW : Param::Format::NCHW4; auto inner_opr = args.handle->create_operator(); + inner_opr->param() = inner_conv_param; set_execution_policy(args.opr, inner_opr.get()); - inner_opr->param() = inner_conv_param; relayout_nchw_nchw4->exec(*args.src_tensor, inner_src, {}); relayout_weight->exec(*args.filter_tensor, inner_weight, {}); diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp index 707545f8..df9da234 100644 --- a/dnn/test/cuda/conv_bias_int8.cpp +++ b/dnn/test/cuda/conv_bias_int8.cpp @@ -701,9 +701,11 @@ TEST_F(CUDA, CONV_BIAS_INT8_CHWN4_UNROLL_WIDTH_TENSORCORE_1x1_ALGO_2) { TEST_F(CUDA, FALLBACK_CONV_QS8) { require_compute_capability_eq(7, 5); Checker checker(handle_cuda()); - auto check = [&checker](const std::string&& algo) { + auto check = [&checker](const std::string&& algo, + const std::string&& sub_algo) { checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker(algo.c_str())); + conv_bias::ConvBiasAlgoChecker( + {algo.c_str(), {sub_algo.c_str()}})); UniformIntRNG rng{-3, 3}; UniformIntRNG bias_rng{-50, 50}; checker.set_rng(0, &rng) @@ -733,15 +735,17 @@ TEST_F(CUDA, FALLBACK_CONV_QS8) { {}, {}}); }; - check("FALLBACK_CONV_NCHW_QS8"); + check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"); } TEST_F(CUDA, FALLBACK_CONV_QS8_F32) { require_compute_capability_eq(7, 5); Checker checker(handle_cuda()); - auto check = [&checker](const std::string&& algo) { + auto check = [&checker](const std::string&& algo, + const std::string&& sub_algo) { checker.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker(algo.c_str())); + conv_bias::ConvBiasAlgoChecker( + {algo.c_str(), {sub_algo.c_str()}})); UniformIntRNG rng{-3, 3}; UniformFloatRNG bias_rng{-50.f, 50.f}; checker.set_rng(0, &rng) @@ -771,7 +775,7 @@ TEST_F(CUDA, FALLBACK_CONV_QS8_F32) { {}, {}}); }; - check("FALLBACK_CONV_NCHW_QS8"); + check("FALLBACK_CONV_NCHW_QS8", "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM"); } TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_WEIGHT_PREPROCESS) {