Browse Source

fix(cuda): fix direct conv speed and memory problem

GitOrigin-RevId: 6faeeff3b8
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
ac26bdcef5
2 changed files with 28 additions and 31 deletions
  1. +4
    -6
      dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
  2. +24
    -25
      dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh

+ 4
- 6
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh View File

@@ -142,7 +142,7 @@ struct ConvTraitInner {
} }


#define CHECK_AB_BWD(a, b) \ #define CHECK_AB_BWD(a, b) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
using FilterTileConfig_ = FilterTileConfig<unroll_fh, a + 2>; \ using FilterTileConfig_ = FilterTileConfig<unroll_fh, a + 2>; \
using ThreadConfig_ = ThreadConfig<4, 32>; \ using ThreadConfig_ = ThreadConfig<4, 32>; \
using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, b + 1>; \ using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, b + 1>; \
@@ -165,11 +165,9 @@ struct ConvTraitInner {
return true; \ return true; \
} }


#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb( \
a, \
15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb(a, 15) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 3) \
} }


#define CHECK(cb) \ #define CHECK(cb) \


+ 24
- 25
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh View File

@@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem<
// Backprop input direction is the same as forward direction with the filter // Backprop input direction is the same as forward direction with the filter
// rotated by 180°. // rotated by 180°.
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW( __global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const __half* input, const __half* filter, __half* output) { const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half; using T = __half;
@@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);


int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x; off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem); T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;


int off_ichannel = off_ochannel / param.chl_mul, int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl, off_fchannel = off_ichannel % param.src_chl,
@@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
} }
} }


template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32( __global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const __half* input, const __half* filter, __half* output) { const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half; using T = __half;
@@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);


int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x; off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem); T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;


int off_ichannel = off_ochannel / param.chl_mul, int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl, off_fchannel = off_ichannel % param.src_chl,
@@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
} }
#endif #endif


template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW( __global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, float* output) { const Param param, const float* input, const float* filter, float* output) {
using T = float; using T = float;
@@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);


int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x; off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem); T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;


int off_ichannel = off_ochannel / param.chl_mul, int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl, off_fchannel = off_ichannel % param.src_chl,
@@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
} }
} }


template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32( __global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const float* input, const float* filter, float* output) { const Param param, const float* input, const float* filter, float* output) {
using T = float; using T = float;
@@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);


int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x; off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem); T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]); T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;


int off_ichannel = off_ochannel / param.chl_mul, int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl, off_fchannel = off_ichannel % param.src_chl,
@@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU(
void (*kernel)(const Param, const T*, const T*, T*); void (*kernel)(const Param, const T*, const T*, T*);


if (param.is_compute_deafult) { if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
} else { } else {
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection, stride>;
} }
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output); kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch(); after_kernel_launch();
} }


#define INSTANCE_AB(type1, type2, a, b, direction) \ #define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \ (param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \ LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \
@@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU(
} \ } \
} }


#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \
type1, type2, a, 3, direction) \
} }


#define INSTANCE(type1, type2, direction) \ #define INSTANCE(type1, type2, direction) \


Loading…
Cancel
Save