From cff61a53d4fc6cb23dce4d6dbeb70857d46188c1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 12 Apr 2021 16:54:26 +0800
Subject: [PATCH] perf(dnn/cuda): optimize int4 sass conv main loop and epilogue without fuse_z

GitOrigin-RevId: 4274e58d64b8de532c03a137cc86eb6274977a2b
---
 .../sass_implicit_gemm_int4_nchw64_imma.cpp | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
index 3b50a10b..4899531e 100644
--- a/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
+++ b/dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
@@ -167,12 +167,21 @@ void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
         bias_ptr =
                 reinterpret_cast(args.workspace.raw_ptr +
                                  args.filter_layout->span().dist_byte());
-        reorder_imma_filter_bias<4, 64>(
-                reinterpret_cast(filter_ptr),
-                reinterpret_cast(bias_ptr),
-                reinterpret_cast(args.filter_tensor->raw_ptr),
-                args.bias_tensor->compatible_ptr(), co, ci, fh, fw,
-                stream);
+        if (args.z_layout->ndim > 0) {
+            reorder_imma_filter_bias<4, 64>(
+                    reinterpret_cast(filter_ptr),
+                    reinterpret_cast(bias_ptr),
+                    reinterpret_cast(args.filter_tensor->raw_ptr),
+                    args.bias_tensor->compatible_ptr(), co, ci, fh, fw,
+                    stream);
+        } else {
+            reorder_imma_filter_bias<4, 64, true>(
+                    reinterpret_cast(filter_ptr),
+                    reinterpret_cast(bias_ptr),
+                    reinterpret_cast(args.filter_tensor->raw_ptr),
+                    args.bias_tensor->compatible_ptr(), co, ci, fh, fw,
+                    stream);
+        }
     }
 
     uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,