|
|
@@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem< |
|
|
|
// Backprop input direction is the same as forward direction with the filter |
|
|
|
// rotated by 180°. |
|
|
|
#if CUDA_VERSION >= 9000 |
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection> |
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride> |
|
|
|
__global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
const Param param, const __half* input, const __half* filter, __half* output) { |
|
|
|
using T = __half; |
|
|
@@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
using FilterTileCount = typename ConvTrait::FilterTileCount; |
|
|
|
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; |
|
|
|
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; |
|
|
|
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
|
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
@@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
constexpr int stride_h = is_fwd ? stride : 1; |
|
|
|
constexpr int stride_w = is_fwd ? stride : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
@@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection> |
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride> |
|
|
|
__global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
const Param param, const __half* input, const __half* filter, __half* output) { |
|
|
|
using T = __half; |
|
|
@@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
using FilterTileCount = typename ConvTrait::FilterTileCount; |
|
|
|
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; |
|
|
|
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; |
|
|
|
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
|
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
@@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
constexpr int stride_h = is_fwd ? stride : 1; |
|
|
|
constexpr int stride_w = is_fwd ? stride : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
@@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection> |
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride> |
|
|
|
__global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
const Param param, const float* input, const float* filter, float* output) { |
|
|
|
using T = float; |
|
|
@@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
using FilterTileCount = typename ConvTrait::FilterTileCount; |
|
|
|
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; |
|
|
|
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; |
|
|
|
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
|
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
@@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
constexpr int stride_h = is_fwd ? stride : 1; |
|
|
|
constexpr int stride_w = is_fwd ? stride : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
@@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection> |
|
|
|
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride> |
|
|
|
__global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
const Param param, const float* input, const float* filter, float* output) { |
|
|
|
using T = float; |
|
|
@@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
using FilterTileCount = typename ConvTrait::FilterTileCount; |
|
|
|
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; |
|
|
|
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; |
|
|
|
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); |
|
|
|
|
|
|
|
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, |
|
|
|
off_oh = threadIdx.y, off_ow = threadIdx.x; |
|
|
@@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( |
|
|
|
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); |
|
|
|
T* smem_src = reinterpret_cast<T*>(smem); |
|
|
|
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); |
|
|
|
int stride_h = is_fwd ? param.stride_h : 1; |
|
|
|
int stride_w = is_fwd ? param.stride_w : 1; |
|
|
|
constexpr int stride_h = is_fwd ? stride : 1; |
|
|
|
constexpr int stride_w = is_fwd ? stride : 1; |
|
|
|
|
|
|
|
int off_ichannel = off_ochannel / param.chl_mul, |
|
|
|
off_fchannel = off_ichannel % param.src_chl, |
|
|
@@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU( |
|
|
|
void (*kernel)(const Param, const T*, const T*, T*); |
|
|
|
|
|
|
|
if (param.is_compute_deafult) { |
|
|
|
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>; |
|
|
|
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>; |
|
|
|
} else { |
|
|
|
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>; |
|
|
|
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection, stride>; |
|
|
|
} |
|
|
|
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output); |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE_AB(type1, type2, a, b, direction) \ |
|
|
|
if (param.out_w > b * 4) { \ |
|
|
|
if (param.out_w > b * 4 || b == 3) { \ |
|
|
|
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ |
|
|
|
(param.stride_h == 1 && param.stride_w == 1)) { \ |
|
|
|
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \ |
|
|
@@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU( |
|
|
|
} \ |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE_A(type1, type2, a, direction) \ |
|
|
|
if (param.flt_w > a * 4) { \ |
|
|
|
INSTANCE_AB(type1, type2, a, 15, direction) \ |
|
|
|
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \ |
|
|
|
type1, type2, \ |
|
|
|
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \ |
|
|
|
#define INSTANCE_A(type1, type2, a, direction) \ |
|
|
|
if (param.flt_w > a * 4) { \ |
|
|
|
INSTANCE_AB(type1, type2, a, 15, direction) \ |
|
|
|
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \ |
|
|
|
type1, type2, a, 3, direction) \ |
|
|
|
} |
|
|
|
|
|
|
|
#define INSTANCE(type1, type2, direction) \ |
|
|
|