|
|
@@ -119,11 +119,12 @@ __device__ __forceinline__ void Global2SharedMem< |
|
|
|
#pragma unroll |
|
|
|
for (int i = 0; i < h_per_thread; ++i) { |
|
|
|
int smem_h_idx = y_base_idx + i * load_h; |
|
|
|
int bank_offset = smem_h_idx / TileCount::bank_offset_line; |
|
|
|
int src_h_idx; |
|
|
|
if (is_fwd) { |
|
|
|
src_h_idx = start_h + smem_h_idx; |
|
|
|
} else { |
|
|
|
src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1; |
|
|
|
src_h_idx = start_h - smem_h_idx; |
|
|
|
} |
|
|
|
if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h) |
|
|
|
continue; |
|
|
@@ -146,7 +147,8 @@ __device__ __forceinline__ void Global2SharedMem< |
|
|
|
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { |
|
|
|
val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w]; |
|
|
|
} |
|
|
|
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val; |
|
|
|
*(sh_ptr_as_copy_t( |
|
|
|
smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) = val; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@@ -261,26 +263,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
static_cast<int>(param.src_w), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - |
|
|
|
(param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2)), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - |
|
|
|
(param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2)), |
|
|
|
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), |
|
|
|
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_h), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_w)}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
is_fwd ? 0 : static_cast<int>(param.flt_h - 1), |
|
|
|
0, |
|
|
|
static_cast<int>(param.flt_h), |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -290,14 +297,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
|
|
|
|
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
|
gl2sh_flt.copy(); |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { |
|
|
|
int flt_offset = |
|
|
|
(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * (f_h / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
@@ -305,10 +349,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[s_h * t2_src_unroll_w + s_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_src_ptr + src_offset); |
|
|
|
reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
2 * ((off_oh * stride_h + fh + s_h) / |
|
|
|
SrcTileCount::bank_offset_line)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
@@ -317,20 +366,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
int flt_offset = |
|
|
|
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + |
|
|
|
f_w * 2; |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset); |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * ((fh + f_h) / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = |
|
|
|
T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
@@ -342,9 +392,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[ow * stride_w % 2] |
|
|
|
reg_flt[0][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
@@ -352,14 +403,92 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
gl2sh_src.commit(); |
|
|
|
gl2sh_flt.commit(); |
|
|
|
gl2sh_src.iter_forward(); |
|
|
|
gl2sh_flt.iter_forward(); |
|
|
|
__syncthreads(); |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh * stride_h + fh + 1 + s_h) % |
|
|
|
SrcTileCount::smem_h * SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
2 * ((off_oh * stride_h + fh + 1 + s_h) / |
|
|
|
SrcTileCount::bank_offset_line)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { |
|
|
|
int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w * 2; |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[1][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[1] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (param.flt_h % 2 != 0) { |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[0][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o] = megdnn::cuda::hadd2(sum[o], __shfl_xor(sum[o], i, 32)); |
|
|
@@ -429,26 +558,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
static_cast<int>(param.src_w), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - |
|
|
|
(param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2)), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - |
|
|
|
(param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2)), |
|
|
|
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), |
|
|
|
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_h), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_w)}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
is_fwd ? 0 : static_cast<int>(param.flt_h - 1), |
|
|
|
0, |
|
|
|
static_cast<int>(param.flt_h), |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -458,14 +592,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
|
|
|
|
float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
|
gl2sh_flt.copy(); |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { |
|
|
|
int flt_offset = |
|
|
|
(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2; |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * (f_h / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
@@ -473,10 +644,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[s_h * t2_src_unroll_w + s_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_src_ptr + src_offset); |
|
|
|
reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
2 * ((off_oh * stride_h + fh + s_h) / |
|
|
|
SrcTileCount::bank_offset_line)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
@@ -485,20 +661,82 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
int flt_offset = |
|
|
|
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + |
|
|
|
f_w * 2; |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset); |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * ((fh + f_h) / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ |
|
|
|
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[0][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh * stride_h + fh + 1 + s_h) % |
|
|
|
SrcTileCount::smem_h * SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_src_ptr + src_offset + |
|
|
|
2 * ((off_oh * stride_h + fh + 1 + s_h) / |
|
|
|
SrcTileCount::bank_offset_line)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (off_ow == ThreadConfig::thread_x - 1) { |
|
|
|
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { |
|
|
|
int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w * 2; |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>( |
|
|
|
smem_flt_ptr + flt_offset + |
|
|
|
2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line)); |
|
|
|
if (f_w > 0) { |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y, |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} else { |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
} |
|
|
|
} |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = |
|
|
|
T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0}; |
|
|
|
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{ |
|
|
|
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
@@ -510,24 +748,42 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[ow * stride_w % 2] |
|
|
|
reg_flt[1][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
reg_src[1] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
gl2sh_src.commit(); |
|
|
|
gl2sh_flt.commit(); |
|
|
|
gl2sh_src.iter_forward(); |
|
|
|
gl2sh_flt.iter_forward(); |
|
|
|
__syncthreads(); |
|
|
|
if (param.flt_h % 2 != 0) { |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[0][ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o].x += __shfl_xor(sum[o].x, i, 32); |
|
|
@@ -595,26 +851,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
static_cast<int>(param.src_w), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - |
|
|
|
(param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2)), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - |
|
|
|
(param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2)), |
|
|
|
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), |
|
|
|
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_h), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_w)}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
is_fwd ? 0 : static_cast<int>(param.flt_h - 1), |
|
|
|
0, |
|
|
|
static_cast<int>(param.flt_h), |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -624,22 +885,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], |
|
|
|
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; |
|
|
|
T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; |
|
|
|
|
|
|
|
T sum[OutTileConfig::unroll_size] = {0.0}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
|
gl2sh_flt.copy(); |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w + |
|
|
|
f_h / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w]; |
|
|
|
s_w + |
|
|
|
(off_oh * stride_h + fh + s_h) / |
|
|
|
SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -647,14 +929,54 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(fh + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w]; |
|
|
|
f_w + (fh + f_h) / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w + |
|
|
|
(off_oh * stride_h + fh + 1 + s_h) / |
|
|
|
SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(fh + 1 + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
@@ -663,22 +985,38 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[1] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
gl2sh_src.commit(); |
|
|
|
gl2sh_flt.commit(); |
|
|
|
gl2sh_src.iter_forward(); |
|
|
|
gl2sh_flt.iter_forward(); |
|
|
|
__syncthreads(); |
|
|
|
if (param.flt_h % 2 != 0) { |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o] += __shfl_xor(sum[o], i, 32); |
|
|
@@ -743,26 +1081,31 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
static_cast<int>(param.src_w), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - |
|
|
|
(param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2)), |
|
|
|
static_cast<int>( |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - |
|
|
|
(param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2)), |
|
|
|
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h), |
|
|
|
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_h), |
|
|
|
is_fwd ? 1 : static_cast<int>(param.stride_w)}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
is_fwd ? 0 : static_cast<int>(param.flt_h - 1), |
|
|
|
0, |
|
|
|
static_cast<int>(param.flt_h), |
|
|
|
static_cast<int>(param.flt_w), |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -772,22 +1115,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], |
|
|
|
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; |
|
|
|
T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; |
|
|
|
|
|
|
|
T sum[OutTileConfig::unroll_size] = {0.0}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
|
gl2sh_flt.copy(); |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w + |
|
|
|
f_h / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w]; |
|
|
|
s_w + |
|
|
|
(off_oh * stride_h + fh + s_h) / |
|
|
|
SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -795,14 +1159,54 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(fh + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w]; |
|
|
|
f_w + (fh + f_h) / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w + |
|
|
|
(off_oh * stride_h + fh + 1 + s_h) / |
|
|
|
SrcTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(fh + 1 + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line]; |
|
|
|
} |
|
|
|
} |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
@@ -811,22 +1215,38 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[1] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
gl2sh_src.commit(); |
|
|
|
gl2sh_flt.commit(); |
|
|
|
gl2sh_src.iter_forward(); |
|
|
|
gl2sh_flt.iter_forward(); |
|
|
|
__syncthreads(); |
|
|
|
if (param.flt_h % 2 != 0) { |
|
|
|
#pragma unroll |
|
|
|
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[0] |
|
|
|
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o] += __shfl_xor(sum[o], i, 32); |
|
|
@@ -899,11 +1319,10 @@ void LaunchDepthwiseConv2dGPU( |
|
|
|
} \ |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE_A(type1, type2, a, direction) \ |
|
|
|
if (param.flt_w > a * 4) { \ |
|
|
|
INSTANCE_AB(type1, type2, a, 15, direction) \ |
|
|
|
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \ |
|
|
|
type1, type2, a, 3, direction) \ |
|
|
|
#define INSTANCE_A(type1, type2, a, direction) \ |
|
|
|
if (param.flt_w > a * 4) { \ |
|
|
|
INSTANCE_AB(type1, type2, a, 7, direction) \ |
|
|
|
else INSTANCE_AB(type1, type2, a, 3, direction) \ |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE(type1, type2, direction) \ |
|
|
|