From d968942fe369de2e05807a2ee8efb42d61444a03 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 5 Apr 2022 16:38:02 +0800 Subject: [PATCH] perf(cuda): speedup direct large kernel conv GitOrigin-RevId: 3ff6a9caebbd1dc4c5c1c23b51945f7574f186ca --- .../conv_bias/chanwise/depthwise_large_filter.cuh | 26 +- .../chanwise/depthwise_large_filter_algo.cuh | 761 ++++++++++++++++----- 2 files changed, 611 insertions(+), 176 deletions(-) diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh index 2de44d5d..c458a1ae 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh @@ -59,14 +59,15 @@ struct ConvTraitInner { static int const smem_src_h = (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h; static int const smem_buff_h = FilterTileConfig::unroll_h; - static int const smem_load_h = smem_src_h + smem_buff_h; + static int const smem_load_h = smem_src_h + smem_buff_h * + FilterTileConfig::unroll_w * + ThreadConfig::thread_x; static int const smem_h = smem_load_h + smem_buff_h; static int const smem_w = DIVUP((OutTileConfig::block_w - 1) * stride_w + FilterTileConfig::unroll_w * ThreadConfig::thread_x, 2) * 2; - static int const smem_size = smem_h * smem_w; static int const load_w = smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; static int const load_h = 1; @@ -74,21 +75,36 @@ struct ConvTraitInner { static int const reg_w = DIVUP(smem_w, load_w); static bool constexpr check_bounds_h = smem_h % load_h != 0; static bool constexpr check_bounds_w = smem_w % load_w != 0; + // to avoid bank conflict, every bank_offset_line in 8 lines, add one offset + static int const bank_w = smem_w / (4 / sizeof(CompType)); + static int const bank_offset_line = + (bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0) + ? 1 + : (bank_w % 16 == 0 ? 
2 : 4); + static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) * + (4 / sizeof(CompType)); }; struct FilterTileCount { static int const smem_flt_h = FilterTileConfig::unroll_h; static int const smem_buff_h = FilterTileConfig::unroll_h; - static int const smem_load_h = smem_flt_h + smem_buff_h; - static int const smem_h = smem_load_h + smem_buff_h; static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x; - static int const smem_size = smem_h * smem_w; + static int const smem_load_h = smem_flt_h + smem_buff_h * smem_w; + static int const smem_h = smem_load_h + smem_buff_h; static int const load_w = smem_w > 32 ? 32 : smem_w; static int const load_h = ThreadConfig::nr_threads / load_w; static int const reg_h = 1; static int const reg_w = DIVUP(smem_w, load_w); static bool constexpr check_bounds_h = smem_h % load_h != 0; static bool constexpr check_bounds_w = smem_w % load_w != 0; + // to avoid bank conflict, every bank_offset_line in 8 lines, add one offset + static int const bank_w = smem_w / (4 / sizeof(CompType)); + static int const bank_offset_line = + (bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0) + ? 1 + : (bank_w % 16 == 0 ? 
2 : 4); + static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) * + (4 / sizeof(CompType)); }; }; diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh index e434b3cf..af2d001c 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh @@ -119,11 +119,12 @@ __device__ __forceinline__ void Global2SharedMem< #pragma unroll for (int i = 0; i < h_per_thread; ++i) { int smem_h_idx = y_base_idx + i * load_h; + int bank_offset = smem_h_idx / TileCount::bank_offset_line; int src_h_idx; if (is_fwd) { src_h_idx = start_h + smem_h_idx; } else { - src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1; + src_h_idx = start_h - smem_h_idx; } if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h) continue; @@ -146,7 +147,8 @@ __device__ __forceinline__ void Global2SharedMem< TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w]; } - *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val; + *(sh_ptr_as_copy_t( + smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) = val; } } } @@ -261,26 +263,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, - param.src_w, - is_fwd ? src_start_h - : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2), - is_fwd ? src_start_w - : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2), - is_fwd ? param.src_h : param.src_h * param.stride_h, - is_fwd ? param.src_w : param.src_w * param.stride_w, - is_fwd ? 1 : param.stride_h, - is_fwd ? 1 : param.stride_w}; - - FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, - param.flt_w, - is_fwd ? 
0 : param.flt_h - 2, - 0, - param.flt_h, - param.flt_w, - 1, - 1}; + static_cast(param.src_w), + static_cast( + is_fwd ? src_start_h + : src_start_h - + (param.out_h / 2 + param.flt_h / 2 - param.pad_h - + param.src_h * param.stride_h / 2)), + static_cast( + is_fwd ? src_start_w + : src_start_w - + (param.out_w / 2 + param.flt_w / 2 - param.pad_w - + param.src_w * param.stride_w / 2)), + static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), + static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), + is_fwd ? 1 : static_cast(param.stride_h), + is_fwd ? 1 : static_cast(param.stride_w)}; + + FilterGlobal2ShareVisitor gl2sh_flt = { + smem_flt, + static_cast(param.flt_w), + is_fwd ? 0 : static_cast(param.flt_h - 1), + 0, + static_cast(param.flt_h), + static_cast(param.flt_w), + 1, + 1}; gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; @@ -290,14 +297,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( __syncthreads(); - T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], - reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; + T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w], + reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; - for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { - gl2sh_src.copy(); - gl2sh_flt.copy(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { + int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w * 2; + reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + ((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2); + } + } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } + +#pragma 
unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { + int flt_offset = + (f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * (f_h / FilterTileCount::bank_offset_line)); + if (f_w > 0) { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } else { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } + } + reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + } + + for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) { #pragma unroll for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { #pragma unroll @@ -305,10 +349,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * SrcTileCount::smem_w + s_w * 2; - reg_src[s_h * t2_src_unroll_w + s_w] = - *reinterpret_cast(smem_src_ptr + src_offset); + reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + 2 * ((off_oh * stride_h + fh + s_h) / + SrcTileCount::bank_offset_line)); } } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } #pragma unroll for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { @@ -317,20 +366,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( int flt_offset = (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; - reg_flt[0][f_h * t2_flt_unroll_w + f_w] = - *reinterpret_cast(smem_flt_ptr + flt_offset); + reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = 
*reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * ((fh + f_h) / FilterTileCount::bank_offset_line)); if (f_w > 0) { - reg_flt[1][f_h * t2_flt_unroll_w + f_w] = - T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, - reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; } else { - reg_flt[1][f_h * t2_flt_unroll_w + f_w] = - T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; } } - reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; - reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = - T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ + reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; } #pragma unroll @@ -342,9 +392,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( #pragma unroll for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( - reg_flt[ow * stride_w % 2] + reg_flt[0][ow * stride_w % 2] [inner_fh * t2_flt_unroll_w + fw], - reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + + reg_src[0] + [(inner_fh + oh) * t2_src_unroll_w + fw + ow * stride_w / 2], sum[oh * t2_out_unroll_w + ow]); } @@ -352,14 +403,92 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( } } - __syncthreads(); - gl2sh_src.commit(); - gl2sh_flt.commit(); - gl2sh_src.iter_forward(); - gl2sh_flt.iter_forward(); - __syncthreads(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { + int src_offset = (off_oh * stride_h + fh + 1 + s_h) % + SrcTileCount::smem_h * SrcTileCount::smem_w + + s_w * 2; + reg_src[0][s_h 
* t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + 2 * ((off_oh * stride_h + fh + 1 + s_h) / + SrcTileCount::bank_offset_line)); + } + } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { + int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h * + FilterTileCount::smem_w + + f_w * 2; + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line)); + if (f_w > 0) { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } else { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } + } + reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ + reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + } + +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( + reg_flt[1][ow * stride_w % 2] + [inner_fh * t2_flt_unroll_w + fw], + reg_src[1] + [(inner_fh + oh) * t2_src_unroll_w + fw + + ow * stride_w / 2], + sum[oh * t2_out_unroll_w + ow]); + } + } + } + } } + if (param.flt_h % 2 != 0) { +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 
0; fw < t2_flt_unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( + reg_flt[0][ow * stride_w % 2] + [inner_fh * t2_flt_unroll_w + fw], + reg_src[0] + [(inner_fh + oh) * t2_src_unroll_w + fw + + ow * stride_w / 2], + sum[oh * t2_out_unroll_w + ow]); + } + } + } + } + } + + __syncthreads(); + for (int o = 0; o < OutTileConfig::unroll_size; ++o) { for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { sum[o] = megdnn::cuda::hadd2(sum[o], __shfl_xor(sum[o], i, 32)); @@ -429,26 +558,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, - param.src_w, - is_fwd ? src_start_h - : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2), - is_fwd ? src_start_w - : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2), - is_fwd ? param.src_h : param.src_h * param.stride_h, - is_fwd ? param.src_w : param.src_w * param.stride_w, - is_fwd ? 1 : param.stride_h, - is_fwd ? 1 : param.stride_w}; - - FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, - param.flt_w, - is_fwd ? 0 : param.flt_h - 2, - 0, - param.flt_h, - param.flt_w, - 1, - 1}; + static_cast(param.src_w), + static_cast( + is_fwd ? src_start_h + : src_start_h - + (param.out_h / 2 + param.flt_h / 2 - param.pad_h - + param.src_h * param.stride_h / 2)), + static_cast( + is_fwd ? src_start_w + : src_start_w - + (param.out_w / 2 + param.flt_w / 2 - param.pad_w - + param.src_w * param.stride_w / 2)), + static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), + static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), + is_fwd ? 1 : static_cast(param.stride_h), + is_fwd ? 1 : static_cast(param.stride_w)}; + + FilterGlobal2ShareVisitor gl2sh_flt = { + smem_flt, + static_cast(param.flt_w), + is_fwd ? 
0 : static_cast(param.flt_h - 1), + 0, + static_cast(param.flt_h), + static_cast(param.flt_w), + 1, + 1}; gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; @@ -458,14 +592,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( __syncthreads(); - T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], - reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; + T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w], + reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; - for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { - gl2sh_src.copy(); - gl2sh_flt.copy(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { + int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w * 2; + reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + ((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2); + } + } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { + int flt_offset = + (f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * (f_h / FilterTileCount::bank_offset_line)); + if (f_w > 0) { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } else { + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; + } + } + reg_flt[0][0][f_h * t2_flt_unroll_w 
+ t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + } + + for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) { #pragma unroll for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { #pragma unroll @@ -473,10 +644,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * SrcTileCount::smem_w + s_w * 2; - reg_src[s_h * t2_src_unroll_w + s_w] = - *reinterpret_cast(smem_src_ptr + src_offset); + reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + 2 * ((off_oh * stride_h + fh + s_h) / + SrcTileCount::bank_offset_line)); } } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } #pragma unroll for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { @@ -485,20 +661,82 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( int flt_offset = (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; - reg_flt[0][f_h * t2_flt_unroll_w + f_w] = - *reinterpret_cast(smem_flt_ptr + flt_offset); + reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * ((fh + f_h) / FilterTileCount::bank_offset_line)); + if (f_w > 0) { + reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; + } else { + reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; + } + } + reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ + reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + } + +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { 
+#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( + reg_flt[0][ow * stride_w % 2] + [inner_fh * t2_flt_unroll_w + fw], + reg_src[0] + [(inner_fh + oh) * t2_src_unroll_w + fw + + ow * stride_w / 2], + sum[oh * t2_out_unroll_w + ow]); + } + } + } + } + +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { + int src_offset = (off_oh * stride_h + fh + 1 + s_h) % + SrcTileCount::smem_h * SrcTileCount::smem_w + + s_w * 2; + reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast( + smem_src_ptr + src_offset + + 2 * ((off_oh * stride_h + fh + 1 + s_h) / + SrcTileCount::bank_offset_line)); + } + } + if (off_ow == ThreadConfig::thread_x - 1) { + reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { + int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h * + FilterTileCount::smem_w + + f_w * 2; + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast( + smem_flt_ptr + flt_offset + + 2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line)); if (f_w > 0) { - reg_flt[1][f_h * t2_flt_unroll_w + f_w] = - T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, - reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; } else { - reg_flt[1][f_h * t2_flt_unroll_w + f_w] = - T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = + T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; } } - reg_flt[0][f_h * t2_flt_unroll_w + 
t2_flt_unroll_w - 1] = T2{0.0, 0.0}; - reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = - T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; + reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; + reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ + reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; } #pragma unroll @@ -510,24 +748,42 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( #pragma unroll for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( - reg_flt[ow * stride_w % 2] + reg_flt[1][ow * stride_w % 2] [inner_fh * t2_flt_unroll_w + fw], - reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + + reg_src[1] + [(inner_fh + oh) * t2_src_unroll_w + fw + ow * stride_w / 2], sum[oh * t2_out_unroll_w + ow]); } } } } + } - __syncthreads(); - gl2sh_src.commit(); - gl2sh_flt.commit(); - gl2sh_src.iter_forward(); - gl2sh_flt.iter_forward(); - __syncthreads(); + if (param.flt_h % 2 != 0) { +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( + reg_flt[0][ow * stride_w % 2] + [inner_fh * t2_flt_unroll_w + fw], + reg_src[0] + [(inner_fh + oh) * t2_src_unroll_w + fw + + ow * stride_w / 2], + sum[oh * t2_out_unroll_w + ow]); + } + } + } + } } + __syncthreads(); + for (int o = 0; o < OutTileConfig::unroll_size; ++o) { for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { sum[o].x += __shfl_xor(sum[o].x, i, 32); @@ -595,26 +851,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, - param.src_w, - is_fwd ? 
src_start_h - : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2), - is_fwd ? src_start_w - : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2), - is_fwd ? param.src_h : param.src_h * param.stride_h, - is_fwd ? param.src_w : param.src_w * param.stride_w, - is_fwd ? 1 : param.stride_h, - is_fwd ? 1 : param.stride_w}; - - FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, - param.flt_w, - is_fwd ? 0 : param.flt_h - 2, - 0, - param.flt_h, - param.flt_w, - 1, - 1}; + static_cast(param.src_w), + static_cast( + is_fwd ? src_start_h + : src_start_h - + (param.out_h / 2 + param.flt_h / 2 - param.pad_h - + param.src_h * param.stride_h / 2)), + static_cast( + is_fwd ? src_start_w + : src_start_w - + (param.out_w / 2 + param.flt_w / 2 - param.pad_w - + param.src_w * param.stride_w / 2)), + static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), + static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), + is_fwd ? 1 : static_cast(param.stride_h), + is_fwd ? 1 : static_cast(param.stride_w)}; + + FilterGlobal2ShareVisitor gl2sh_flt = { + smem_flt, + static_cast(param.flt_w), + is_fwd ? 
0 : static_cast(param.flt_h - 1), + 0, + static_cast(param.flt_h), + static_cast(param.flt_w), + 1, + 1}; gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; @@ -624,22 +885,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( __syncthreads(); - T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], - reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; + T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], + reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; T sum[OutTileConfig::unroll_size] = {0.0}; - for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { - gl2sh_src.copy(); - gl2sh_flt.copy(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { + reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + [(off_oh * stride_h + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line]; + } + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { + reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + [(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w + + f_h / FilterTileCount::bank_offset_line]; + } + } + + for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) { #pragma unroll for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { #pragma unroll for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { - reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * SrcTileCount::smem_w + - s_w]; + s_w + + (off_oh * stride_h + fh + s_h) / + SrcTileCount::bank_offset_line]; } 
} @@ -647,14 +929,54 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { #pragma unroll for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { - reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + - f_w]; + f_w + (fh + f_h) / FilterTileCount::bank_offset_line]; + } + } +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * OutTileConfig::unroll_w + ow] += + reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[0] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + ow * stride_w]; + } + } + } + } + +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { + reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + [(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w + + (off_oh * stride_h + fh + 1 + s_h) / + SrcTileCount::bank_offset_line]; } } #pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { + reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + [(fh + 1 + f_h) % FilterTileCount::smem_h * + FilterTileCount::smem_w + + f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line]; + } + } +#pragma unroll for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { #pragma unroll for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { @@ -663,22 +985,38 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( #pragma 
unroll for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { sum[oh * OutTileConfig::unroll_w + ow] += - reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * - reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[1] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + ow * stride_w]; } } } } + } - __syncthreads(); - gl2sh_src.commit(); - gl2sh_flt.commit(); - gl2sh_src.iter_forward(); - gl2sh_flt.iter_forward(); - __syncthreads(); + if (param.flt_h % 2 != 0) { +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * OutTileConfig::unroll_w + ow] += + reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[0] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + ow * stride_w]; + } + } + } + } } + __syncthreads(); + for (int o = 0; o < OutTileConfig::unroll_size; ++o) { for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { sum[o] += __shfl_xor(sum[o], i, 32); @@ -743,26 +1081,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( SrcGlobal2ShareVisitor gl2sh_src = { smem_src, - param.src_w, - is_fwd ? src_start_h - : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - - param.src_h * param.stride_h / 2), - is_fwd ? src_start_w - : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - - param.src_w * param.stride_w / 2), - is_fwd ? param.src_h : param.src_h * param.stride_h, - is_fwd ? param.src_w : param.src_w * param.stride_w, - is_fwd ? 1 : param.stride_h, - is_fwd ? 1 : param.stride_w}; - - FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, - param.flt_w, - is_fwd ? 0 : param.flt_h - 2, - 0, - param.flt_h, - param.flt_w, - 1, - 1}; + static_cast(param.src_w), + static_cast( + is_fwd ? 
src_start_h + : src_start_h - + (param.out_h / 2 + param.flt_h / 2 - param.pad_h - + param.src_h * param.stride_h / 2)), + static_cast( + is_fwd ? src_start_w + : src_start_w - + (param.out_w / 2 + param.flt_w / 2 - param.pad_w - + param.src_w * param.stride_w / 2)), + static_cast(is_fwd ? param.src_h : param.src_h * param.stride_h), + static_cast(is_fwd ? param.src_w : param.src_w * param.stride_w), + is_fwd ? 1 : static_cast(param.stride_h), + is_fwd ? 1 : static_cast(param.stride_w)}; + + FilterGlobal2ShareVisitor gl2sh_flt = { + smem_flt, + static_cast(param.flt_w), + is_fwd ? 0 : static_cast(param.flt_h - 1), + 0, + static_cast(param.flt_h), + static_cast(param.flt_w), + 1, + 1}; gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; @@ -772,22 +1115,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( __syncthreads(); - T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], - reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; + T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], + reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; T sum[OutTileConfig::unroll_size] = {0.0}; - for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { - gl2sh_src.copy(); - gl2sh_flt.copy(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { + reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + [(off_oh * stride_h + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line]; + } + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { + reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + [(f_h) % FilterTileCount::smem_h * 
FilterTileCount::smem_w + f_w + + f_h / FilterTileCount::bank_offset_line]; + } + } + + for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) { #pragma unroll for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { #pragma unroll for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { - reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * SrcTileCount::smem_w + - s_w]; + s_w + + (off_oh * stride_h + fh + s_h) / + SrcTileCount::bank_offset_line]; } } @@ -795,14 +1159,54 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { #pragma unroll for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { - reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + - f_w]; + f_w + (fh + f_h) / FilterTileCount::bank_offset_line]; + } + } +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * OutTileConfig::unroll_w + ow] += + reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[0] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + ow * stride_w]; + } + } + } + } + +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { + reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + [(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w + + (off_oh * stride_h + fh + 1 + s_h) / + SrcTileCount::bank_offset_line]; } 
} #pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { + reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + [(fh + 1 + f_h) % FilterTileCount::smem_h * + FilterTileCount::smem_w + + f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line]; + } + } +#pragma unroll for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { #pragma unroll for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { @@ -811,22 +1215,38 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( #pragma unroll for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { sum[oh * OutTileConfig::unroll_w + ow] += - reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * - reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[1] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + ow * stride_w]; } } } } + } - __syncthreads(); - gl2sh_src.commit(); - gl2sh_flt.commit(); - gl2sh_src.iter_forward(); - gl2sh_flt.iter_forward(); - __syncthreads(); + if (param.flt_h % 2 != 0) { +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * OutTileConfig::unroll_w + ow] += + reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[0] + [(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + ow * stride_w]; + } + } + } + } } + __syncthreads(); + for (int o = 0; o < OutTileConfig::unroll_size; ++o) { for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { sum[o] += __shfl_xor(sum[o], i, 32); @@ -899,11 +1319,10 @@ void LaunchDepthwiseConv2dGPU( } \ } -#define INSTANCE_A(type1, type2, a, direction) \ - if (param.flt_w > a * 4) { \ - 
INSTANCE_AB(type1, type2, a, 15, direction) \ - else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \ - type1, type2, a, 3, direction) \ +#define INSTANCE_A(type1, type2, a, direction) \ + if (param.flt_w > a * 4) { \ + INSTANCE_AB(type1, type2, a, 7, direction) \ + else INSTANCE_AB(type1, type2, a, 3, direction) \ } #define INSTANCE(type1, type2, direction) \