diff --git a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp index 3b50a10b..4899531e 100644 --- a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp @@ -167,12 +167,21 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec( bias_ptr = reinterpret_cast(args.workspace.raw_ptr + args.filter_layout->span().dist_byte()); - reorder_imma_filter_bias<4, 64>( - reinterpret_cast(filter_ptr), - reinterpret_cast(bias_ptr), - reinterpret_cast(args.filter_tensor->raw_ptr), - args.bias_tensor->compatible_ptr(), co, ci, fh, fw, - stream); + if (args.z_layout->ndim > 0) { + reorder_imma_filter_bias<4, 64>( + reinterpret_cast(filter_ptr), + reinterpret_cast(bias_ptr), + reinterpret_cast(args.filter_tensor->raw_ptr), + args.bias_tensor->compatible_ptr(), co, ci, fh, fw, + stream); + } else { + reorder_imma_filter_bias<4, 64, true>( + reinterpret_cast(filter_ptr), + reinterpret_cast(bias_ptr), + reinterpret_cast(args.filter_tensor->raw_ptr), + args.bias_tensor->compatible_ptr(), co, ci, fh, fw, + stream); + } } uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,