|
|
@@ -57,10 +57,13 @@ struct Global2SharedMem { |
|
|
|
T* smem; |
|
|
|
int stride; |
|
|
|
int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h; |
|
|
|
// just used in backward src data |
|
|
|
int stride_h, stride_w; |
|
|
|
const T* g_ptr; |
|
|
|
|
|
|
|
__device__ __forceinline__ |
|
|
|
Global2SharedMem(T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w); |
|
|
|
__device__ __forceinline__ Global2SharedMem( |
|
|
|
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_, |
|
|
|
int stride_w_); |
|
|
|
|
|
|
|
__device__ __forceinline__ void first_copy(); |
|
|
|
__device__ __forceinline__ void copy(); |
|
|
@@ -77,7 +80,7 @@ struct Global2SharedMem { |
|
|
|
|
|
|
|
template < |
|
|
|
typename ldg_dtype, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, |
|
|
|
typename OutTileConfig_, typename FilterTileConfig_> |
|
|
|
typename OutTileConfig_, typename FilterTileConfig_, int stride_w, int stride_h> |
|
|
|
struct ConvTrait { |
|
|
|
using ThreadConfig = ThreadConfig_; |
|
|
|
using OutTileConfig = OutTileConfig_; |
|
|
@@ -88,19 +91,19 @@ struct ConvTrait { |
|
|
|
static int const unroll_h = |
|
|
|
OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1; |
|
|
|
static int const unroll_w = |
|
|
|
OutTileConfig::unroll_w + FilterTileConfig::unroll_w - 1; |
|
|
|
(OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w; |
|
|
|
static int const unroll_size = unroll_h * unroll_w; |
|
|
|
}; |
|
|
|
|
|
|
|
struct SrcTileCount { |
|
|
|
static int const smem_src_h = |
|
|
|
OutTileConfig::block_h + FilterTileConfig::unroll_h - 1; |
|
|
|
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h; |
|
|
|
static int const smem_buff_h = FilterTileConfig::unroll_h; |
|
|
|
static int const smem_load_h = smem_src_h + smem_buff_h; |
|
|
|
static int const smem_h = smem_load_h + smem_buff_h; |
|
|
|
static int const smem_w = |
|
|
|
DIVUP(OutTileConfig::block_w + |
|
|
|
FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1, |
|
|
|
DIVUP((OutTileConfig::block_w - 1) * stride_w + |
|
|
|
FilterTileConfig::unroll_w * ThreadConfig::thread_x, |
|
|
|
2) * |
|
|
|
2; |
|
|
|
static int const smem_size = smem_h * smem_w; |
|
|
@@ -140,20 +143,25 @@ template < |
|
|
|
typename TileCount_> |
|
|
|
__device__ __forceinline__ |
|
|
|
Global2SharedMem<T, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem( |
|
|
|
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w) |
|
|
|
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_, |
|
|
|
int stride_w_) |
|
|
|
: smem(smem_), |
|
|
|
stride(stride_), |
|
|
|
start_h(s_h), |
|
|
|
start_w(s_w), |
|
|
|
bound_h(b_h), |
|
|
|
bound_w(b_w), |
|
|
|
ring_smem_h(TileCount::smem_load_h) { |
|
|
|
ring_smem_h(TileCount::smem_load_h), |
|
|
|
stride_h(stride_h_), |
|
|
|
stride_w(stride_w_) { |
|
|
|
if (is_fwd) { |
|
|
|
ring_src_h = s_h + TileCount::smem_load_h; |
|
|
|
w_offset = 0; |
|
|
|
} else { |
|
|
|
ring_src_h = s_h - 1; |
|
|
|
w_offset = TileCount::smem_w - b_w; |
|
|
|
// stride_h and stride_w just used in backward src data. |
|
|
|
stride_h = stride_w = 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -195,9 +203,10 @@ __device__ __forceinline__ void Global2SharedMem< |
|
|
|
T val = 0.0f; |
|
|
|
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 && |
|
|
|
src_w_idx < bound_w && |
|
|
|
(is_fwd || (TileCount::smem_load_h - smem_h_idx - 1 >= 0 && |
|
|
|
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { |
|
|
|
val = g_ptr[src_h_idx * stride + src_w_idx]; |
|
|
|
((is_fwd && src_h_idx % stride_h == 0 && src_w_idx % stride_w == 0) || |
|
|
|
(!is_fwd && TileCount::smem_load_h - smem_h_idx - 1 >= 0 && |
|
|
|
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { |
|
|
|
val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w]; |
|
|
|
} |
|
|
|
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val; |
|
|
|
} |
|
|
@@ -223,8 +232,9 @@ __device__ __forceinline__ void Global2SharedMem< |
|
|
|
T val = 0.0f; |
|
|
|
if (ring_src_h >= 0 && ring_src_h < bound_h && src_w_idx >= 0 && |
|
|
|
src_w_idx < bound_w && |
|
|
|
(is_fwd || TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0)) { |
|
|
|
val = g_ptr[ring_src_h * stride + src_w_idx]; |
|
|
|
((is_fwd && ring_src_h % stride_h == 0 && src_w_idx % stride_w == 0) || |
|
|
|
(!is_fwd && TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { |
|
|
|
val = g_ptr[ring_src_h / stride_h * stride + src_w_idx / stride_w]; |
|
|
|
} |
|
|
|
reg[j] = val; |
|
|
|
} |
|
|
@@ -286,21 +296,23 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
|
|
|
|
|
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2; |
|
|
|
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2; |
|
|
|
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2; |
|
|
|
constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2; |
|
|
|
constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2; |
|
|
|
constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2; |
|
|
|
|
|
|
|
extern __shared__ __align__(8) unsigned char smem[]; |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
|
out_start_h = off_obh * OutTileConfig::block_h, |
|
|
|
out_start_w = off_obw * OutTileConfig::block_w, |
|
|
|
src_start_h = out_start_h - param.pad_h, |
|
|
|
src_start_w = out_start_w - param.pad_w, |
|
|
|
src_start_h = out_start_h * stride_h - param.pad_h, |
|
|
|
src_start_w = out_start_w * stride_w - param.pad_w, |
|
|
|
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; |
|
|
|
|
|
|
|
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; |
|
|
@@ -308,12 +320,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src( |
|
|
|
smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w); |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, param.flt_h, param.flt_w}; |
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -326,7 +354,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
|
|
|
|
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; |
|
|
|
float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
@@ -335,7 +363,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[s_h * t2_src_unroll_w + s_w] = |
|
|
@@ -373,9 +401,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_flt[ow * stride_w % 2] |
|
|
|
[inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow / 2], |
|
|
|
ow * stride_w / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
@@ -392,7 +421,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o] += __shfl_xor(sum[o], i, 32); |
|
|
|
sum[o].x += __shfl_xor(sum[o].x, i, 32); |
|
|
|
sum[o].y += __shfl_xor(sum[o].y, i, 32); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -406,9 +436,9 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
int out_w_idx = out_start_w + j; |
|
|
|
if (out_w_idx >= param.out_w) |
|
|
|
return; |
|
|
|
out_base_ptr[out_h_idx * param.out_w + out_w_idx] = |
|
|
|
out_base_ptr[out_h_idx * param.out_w + out_w_idx] = __float2half( |
|
|
|
sum[i * OutTileConfig::unroll_w + j].x + |
|
|
|
sum[i * OutTileConfig::unroll_w + j].y; |
|
|
|
sum[i * OutTileConfig::unroll_w + j].y); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@@ -433,21 +463,19 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
|
|
|
|
|
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2; |
|
|
|
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2; |
|
|
|
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2; |
|
|
|
|
|
|
|
extern __shared__ __align__(8) unsigned char smem[]; |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
|
out_start_h = off_obh * OutTileConfig::block_h, |
|
|
|
out_start_w = off_obw * OutTileConfig::block_w, |
|
|
|
src_start_h = out_start_h - param.pad_h, |
|
|
|
src_start_w = out_start_w - param.pad_w, |
|
|
|
src_start_h = out_start_h * stride_h - param.pad_h, |
|
|
|
src_start_w = out_start_w * stride_w - param.pad_w, |
|
|
|
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; |
|
|
|
|
|
|
|
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; |
|
|
@@ -455,12 +483,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; |
|
|
|
|
|
|
|
SrcGlobal2ShareVisitor gl2sh_src( |
|
|
|
smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w); |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = { |
|
|
|
smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, param.flt_h, param.flt_w}; |
|
|
|
SrcGlobal2ShareVisitor gl2sh_src = { |
|
|
|
smem_src, |
|
|
|
param.src_w, |
|
|
|
is_fwd ? src_start_h |
|
|
|
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h - |
|
|
|
param.src_h * param.stride_h / 2), |
|
|
|
is_fwd ? src_start_w |
|
|
|
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w - |
|
|
|
param.src_w * param.stride_w / 2), |
|
|
|
is_fwd ? param.src_h : param.src_h * param.stride_h, |
|
|
|
is_fwd ? param.src_w : param.src_w * param.stride_w, |
|
|
|
is_fwd ? 1 : param.stride_h, |
|
|
|
is_fwd ? 1 : param.stride_w}; |
|
|
|
|
|
|
|
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt, |
|
|
|
param.flt_w, |
|
|
|
is_fwd ? 0 : param.flt_h - 2, |
|
|
|
0, |
|
|
|
param.flt_h, |
|
|
|
param.flt_w, |
|
|
|
1, |
|
|
|
1}; |
|
|
|
|
|
|
|
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; |
|
|
|
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; |
|
|
@@ -470,10 +514,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
__syncthreads(); |
|
|
|
|
|
|
|
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w], |
|
|
|
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w]; |
|
|
|
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], |
|
|
|
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; |
|
|
|
|
|
|
|
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}}; |
|
|
|
T sum[OutTileConfig::unroll_size] = {0.0}; |
|
|
|
|
|
|
|
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { |
|
|
|
gl2sh_src.copy(); |
|
|
@@ -481,34 +525,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
#pragma unroll |
|
|
|
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) { |
|
|
|
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w * 2; |
|
|
|
reg_src[s_h * t2_src_unroll_w + s_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_src_ptr + src_offset); |
|
|
|
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { |
|
|
|
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr |
|
|
|
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * |
|
|
|
SrcTileCount::smem_w + |
|
|
|
s_w]; |
|
|
|
if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && off_oh == 30 && |
|
|
|
off_ow == 0) { |
|
|
|
printf("reg_src[%d] = %f\n", s_h * SrcTileConfig::unroll_w + s_w, |
|
|
|
reg_src[s_h * SrcTileConfig::unroll_w + s_w]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { |
|
|
|
#pragma unroll |
|
|
|
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) { |
|
|
|
int flt_offset = |
|
|
|
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + |
|
|
|
f_w * 2; |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w] = |
|
|
|
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset); |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { |
|
|
|
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y |
|
|
|
: static_cast<T>(0.0), |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; |
|
|
|
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { |
|
|
|
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr |
|
|
|
[(fh + f_h) % FilterTileCount::smem_h * |
|
|
|
FilterTileCount::smem_w + |
|
|
|
f_w]; |
|
|
|
} |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = { |
|
|
|
static_cast<T>(0.0), static_cast<T>(0.0)}; |
|
|
|
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = { |
|
|
|
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, |
|
|
|
static_cast<T>(0.0)}; |
|
|
|
} |
|
|
|
|
|
|
|
#pragma unroll |
|
|
@@ -516,14 +554,22 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
#pragma unroll |
|
|
|
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { |
|
|
|
#pragma unroll |
|
|
|
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) { |
|
|
|
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { |
|
|
|
#pragma unroll |
|
|
|
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { |
|
|
|
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2( |
|
|
|
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw], |
|
|
|
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw + |
|
|
|
ow / 2], |
|
|
|
sum[oh * t2_out_unroll_w + ow]); |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow] += |
|
|
|
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * |
|
|
|
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + |
|
|
|
ow * stride_w]; |
|
|
|
if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && |
|
|
|
off_oh == 30) { |
|
|
|
printf("sum[%d] += %f * %f\nsum = %f\n", |
|
|
|
oh * OutTileConfig::unroll_w + ow, |
|
|
|
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw], |
|
|
|
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + |
|
|
|
fw + ow * stride_w], |
|
|
|
sum[oh * OutTileConfig::unroll_w + ow]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@@ -539,8 +585,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
for (int o = 0; o < OutTileConfig::unroll_size; ++o) { |
|
|
|
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { |
|
|
|
sum[o].x += __shfl_xor(sum[o].x, i, 32); |
|
|
|
sum[o].y += __shfl_xor(sum[o].y, i, 32); |
|
|
|
sum[o] += __shfl_xor(sum[o], i, 32); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -555,8 +600,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
if (out_w_idx >= param.out_w) |
|
|
|
return; |
|
|
|
out_base_ptr[out_h_idx * param.out_w + out_w_idx] = |
|
|
|
sum[i * OutTileConfig::unroll_w + j].x + |
|
|
|
sum[i * OutTileConfig::unroll_w + j].y; |
|
|
|
sum[i * OutTileConfig::unroll_w + j]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@@ -565,7 +609,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( |
|
|
|
|
|
|
|
template < |
|
|
|
typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw, |
|
|
|
int unroll_ow> |
|
|
|
int unroll_ow, int stride> |
|
|
|
void LaunchDepthwiseConv2dGPUSmall( |
|
|
|
const Param& param, const T* input, const T* filter, T* output, |
|
|
|
cudaStream_t stream) { |
|
|
@@ -574,8 +618,9 @@ void LaunchDepthwiseConv2dGPUSmall( |
|
|
|
using FilterTileConfig = FilterTileConfig<unroll_fh, unroll_fw>; |
|
|
|
using ThreadConfig = ThreadConfig<4, 32>; |
|
|
|
using OutTileConfig = OutTileConfig<ThreadConfig, unroll_oh, unroll_ow>; |
|
|
|
using IConvTrait = |
|
|
|
ConvTrait<T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig>; |
|
|
|
using IConvTrait = ConvTrait< |
|
|
|
T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig, stride, |
|
|
|
stride>; |
|
|
|
using SrcTileCount = typename IConvTrait::SrcTileCount; |
|
|
|
using FilterTileCount = typename IConvTrait::FilterTileCount; |
|
|
|
|
|
|
@@ -593,10 +638,17 @@ void LaunchDepthwiseConv2dGPUSmall( |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE_AB(type1, type2, a, b, direction) \ |
|
|
|
if (param.out_w > b * 4) { \ |
|
|
|
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1>( \ |
|
|
|
param, src, flt, dst, stream); \ |
|
|
|
#define INSTANCE_AB(type1, type2, a, b, direction) \ |
|
|
|
if (param.out_w > b * 4) { \ |
|
|
|
printf("param.out_w = %d, b = %d\n", param.out_w, b); \ |
|
|
|
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ |
|
|
|
(param.stride_h == 1 && param.stride_w == 1)) { \ |
|
|
|
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 1>( \ |
|
|
|
param, src, flt, dst, stream); \ |
|
|
|
} else if (param.stride_h == 2 && param.stride_w == 2) { \ |
|
|
|
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 2>( \ |
|
|
|
param, src, flt, dst, stream); \ |
|
|
|
} \ |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE_A(type1, type2, a, direction) \ |
|
|
|