|
@@ -235,7 +235,175 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;
 
-    constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
+    constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 3) / 2;
+    constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
+    constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
+
+    extern __shared__ __align__(8) unsigned char smem[];
+    static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
+    T* smem_src = reinterpret_cast<T*>(smem);
+    T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
+
+    int off_ichannel = off_ochannel / param.chl_mul,
+        off_fchannel = off_ichannel % param.src_chl,
+        out_start_h = off_obh * OutTileConfig::block_h,
+        out_start_w = off_obw * OutTileConfig::block_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
+        out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
+
+    T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
+    T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
+
+    T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
+
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};
+
+    gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
+    gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
+
+    gl2sh_src.first_copy();
+    gl2sh_flt.first_copy();
+
+    __syncthreads();
+
+    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
+            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+
+    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+
+    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
+        gl2sh_src.copy();
+        gl2sh_flt.copy();
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+                int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                         SrcTileCount::smem_w +
+                                 s_w * 2;
+                reg_src[s_h * t2_src_unroll_w + s_w] =
+                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+            }
+        }
+
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+                int flt_offset =
+                        (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
+                        f_w * 2;
+                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
+                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
+                if (f_w > 0) {
+                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                } else {
+                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                }
+            }
+            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+        }
+
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[ow * stride_w % 2]
+                                       [inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
+                                        ow * stride_w / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+        gl2sh_src.commit();
+        gl2sh_flt.commit();
+        gl2sh_src.iter_forward();
+        gl2sh_flt.iter_forward();
+        __syncthreads();
+    }
+
+    for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
+        for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
+            sum[o] = megdnn::cuda::hadd2(sum[o], __shfl_xor(sum[o], i, 32));
+        }
+    }
+
+    if (threadIdx.x == 0) {
+#pragma unroll
+        for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
+            int out_h_idx = out_base_h_idx + i;
+            if (out_h_idx < param.out_h) {
+#pragma unroll
+                for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
+                    int out_w_idx = out_start_w + j;
+                    if (out_w_idx >= param.out_w)
+                        return;
+                    out_base_ptr[out_h_idx * param.out_w + out_w_idx] = __float2half(
+                            __half2float(sum[i * OutTileConfig::unroll_w + j].x) +
+                            __half2float(sum[i * OutTileConfig::unroll_w + j].y));
+                }
+            }
+        }
+    }
+}
+
+template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
+__global__ void DepthwiseConv2dGPUKernelNCHWC32(
+        const Param param, const __half* input, const __half* filter, __half* output) {
+    using T = __half;
+    using T2 = __half2;
+    using ThreadConfig = typename ConvTrait::ThreadConfig;
+    using SrcTileConfig = typename ConvTrait::SrcTileConfig;
+    using FilterTileConfig = typename ConvTrait::FilterTileConfig;
+    using OutTileConfig = typename ConvTrait::OutTileConfig;
+    using SrcTileCount = typename ConvTrait::SrcTileCount;
+    using FilterTileCount = typename ConvTrait::FilterTileCount;
+    using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
+    using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
+    const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
+
+    int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
+        off_oh = threadIdx.y, off_ow = threadIdx.x;
+
+    constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 3) / 2;
     constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
     constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
 
@@ -320,17 +488,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
                 reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
                         *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
                 if (f_w > 0) {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
-                            reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
-                            reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
                 } else {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
-                            0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
                 }
             }
-            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {0.0, 0.0};
-            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
-                    reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
         }
 
 #pragma unroll
@@ -535,6 +703,154 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     }
 }
 
+template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
+__global__ void DepthwiseConv2dGPUKernelNCHWC32(
+        const Param param, const float* input, const float* filter, float* output) {
+    using T = float;
+    using T2 = float2;
+    using ThreadConfig = typename ConvTrait::ThreadConfig;
+    using SrcTileConfig = typename ConvTrait::SrcTileConfig;
+    using FilterTileConfig = typename ConvTrait::FilterTileConfig;
+    using OutTileConfig = typename ConvTrait::OutTileConfig;
+    using SrcTileCount = typename ConvTrait::SrcTileCount;
+    using FilterTileCount = typename ConvTrait::FilterTileCount;
+    using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
+    using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
+    const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
+
+    int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
+        off_oh = threadIdx.y, off_ow = threadIdx.x;
+
+    extern __shared__ __align__(8) unsigned char smem[];
+    static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
+    T* smem_src = reinterpret_cast<T*>(smem);
+    T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
+
+    int off_ichannel = off_ochannel / param.chl_mul,
+        off_fchannel = off_ichannel % param.src_chl,
+        out_start_h = off_obh * OutTileConfig::block_h,
+        out_start_w = off_obw * OutTileConfig::block_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
+        out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
+
+    T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
+    T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
+
+    T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
+
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};
+
+    gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
+    gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
+
+    gl2sh_src.first_copy();
+    gl2sh_flt.first_copy();
+
+    __syncthreads();
+
+    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
+            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
+
+    T sum[OutTileConfig::unroll_size] = {0.0};
+
+    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
+        gl2sh_src.copy();
+        gl2sh_flt.copy();
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w];
+            }
+        }
+
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + f_h) % FilterTileCount::smem_h *
+                                 FilterTileCount::smem_w +
+                         f_w];
+            }
+        }
+
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * OutTileConfig::unroll_w + ow] +=
+                                reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
+                                        ow * stride_w];
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+        gl2sh_src.commit();
+        gl2sh_flt.commit();
+        gl2sh_src.iter_forward();
+        gl2sh_flt.iter_forward();
+        __syncthreads();
+    }
+
+    for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
+        for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
+            sum[o] += __shfl_xor(sum[o], i, 32);
+        }
+    }
+
+    if (threadIdx.x == 0) {
+#pragma unroll
+        for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
+            int out_h_idx = out_base_h_idx + i;
+            if (out_h_idx < param.out_h) {
+#pragma unroll
+                for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
+                    int out_w_idx = out_start_w + j;
+                    if (out_w_idx >= param.out_w)
+                        return;
+                    out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
+                            sum[i * OutTileConfig::unroll_w + j];
+                }
+            }
+        }
+    }
+}
+
 template <
         typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
         int unroll_ow, int stride>
@@ -561,7 +877,12 @@ void LaunchDepthwiseConv2dGPU(
             (SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
 
     void (*kernel)(const Param, const T*, const T*, T*);
-    kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
+    if (param.is_compute_deafult) {
+        kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
+    } else {
+        kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>;
+    }
     kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
     after_kernel_launch();
 }
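
Both kernel variants end the same way: each lane accumulates a register tile of partial sums, the lanes along threadIdx.x combine them with an XOR-shuffle butterfly, and lane 0 alone writes the output tile. In the fp16 path the accumulation itself is packed: megdnn::cuda::fma2 handles two taps per __half2, and the writeback collapses .x + .y in fp32. The standalone sketch below reproduces just that accumulate-reduce-collapse shape with plain CUDA intrinsics; it is a minimal illustration, not code from the patch. The kernel name, sizes, the *_sync shuffle variants, and the full 32-lane reduction width are assumptions made here for clarity (the patched kernels reduce over ThreadConfig::thread_x lanes and go through megdnn's fma2/hadd2 helpers), and half arithmetic requires an fp16-capable GPU (sm_53 or newer).

```cuda
// Illustrative sketch only: packed fp16 multiply-accumulate followed by a warp
// butterfly reduction, the same shape as the accumulation/reduction steps above.
#include <cstdio>
#include <cuda_fp16.h>

__global__ void dot_product_fp16(const __half* a, const __half* b, float* out, int n) {
    // View the half arrays as packed __half2, as the kernels do with shared memory.
    const __half2* a2 = reinterpret_cast<const __half2*>(a);
    const __half2* b2 = reinterpret_cast<const __half2*>(b);

    // Each lane accumulates a strided slice; one __hfma2 does two products + two adds.
    __half2 acc = __float2half2_rn(0.0f);
    for (int i = threadIdx.x; i < n / 2; i += blockDim.x)
        acc = __hfma2(a2[i], b2[i], acc);

    // XOR-shuffle butterfly across the warp; every lane ends up with the full sum.
    for (int lane_mask = 1; lane_mask < 32; lane_mask <<= 1)
        acc = __hadd2(acc, __shfl_xor_sync(0xffffffffu, acc, lane_mask, 32));

    // Lane 0 folds the two packed halves together in fp32 before the store,
    // mirroring the __float2half(__half2float(.x) + __half2float(.y)) writeback.
    if (threadIdx.x == 0)
        *out = __low2float(acc) + __high2float(acc);
}

int main() {
    const int n = 256;  // even, so the data packs cleanly into __half2
    __half *a, *b;
    float* out;
    cudaMallocManaged(&a, n * sizeof(__half));
    cudaMallocManaged(&b, n * sizeof(__half));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) {
        a[i] = __float2half(1.0f);
        b[i] = __float2half(0.5f);
    }
    dot_product_fp16<<<1, 32>>>(a, b, out, n);  // one warp, like one thread_x row
    cudaDeviceSynchronize();
    printf("dot = %f (expect %f)\n", *out, n * 0.5f);
    cudaFree(a);
    cudaFree(b);
    cudaFree(out);
    return 0;
}
```

The float NCHWC32 kernel follows the same reduction on scalar float, which is why its shuffle loop adds the shuffled value directly instead of going through hadd2.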
|
|