Browse Source

fix(cuda): fix direct conv speed and memory problem

GitOrigin-RevId: 6faeeff3b8
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
ac26bdcef5
2 changed files with 28 additions and 31 deletions
  1. +4
    -6
      dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
  2. +24
    -25
      dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh

+ 4
- 6
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh View File

@@ -142,7 +142,7 @@ struct ConvTraitInner {
}

#define CHECK_AB_BWD(a, b) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
using FilterTileConfig_ = FilterTileConfig<unroll_fh, a + 2>; \
using ThreadConfig_ = ThreadConfig<4, 32>; \
using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, b + 1>; \
@@ -165,11 +165,9 @@ struct ConvTraitInner {
return true; \
}

#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb( \
a, \
15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb(a, 15) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 3) \
}

#define CHECK(cb) \


+ 24
- 25
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh View File

@@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem<
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
#if CUDA_VERSION >= 9000
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
@@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;

int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
@@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}

template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
@@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;

int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
@@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
}
#endif

template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
@@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;

int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
@@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}

template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection, int stride>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
@@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
@@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
constexpr int stride_h = is_fwd ? stride : 1;
constexpr int stride_w = is_fwd ? stride : 1;

int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
@@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU(
void (*kernel)(const Param, const T*, const T*, T*);

if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection, stride>;
} else {
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection, stride>;
}
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
}

#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
if (param.out_w > b * 4 || b == 3) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \
@@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU(
} \
}

#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \
type1, type2, a, 3, direction) \
}

#define INSTANCE(type1, type2, direction) \


Loading…
Cancel
Save