
perf(cuda): speedup chanwise conv with small feature map and large filter size

GitOrigin-RevId: e65b2ce856
Branch: release-1.8
Author: Megvii Engine Team 王彪, 3 years ago
Parent commit: 87a2aeebb1
9 changed files with 825 additions and 1 deletion
  1. dnn/src/cuda/conv_bias/algo.cpp (+2, -0)
  2. dnn/src/cuda/conv_bias/algo.h (+22, -1)
  3. dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl (+446, -0)
  4. dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu (+48, -0)
  5. dnn/src/cuda/conv_bias/chanwise/kern.cuh (+4, -0)
  6. dnn/src/cuda/conv_bias/depthwise_large_filter.cpp (+109, -0)
  7. dnn/src/cuda/conv_bias/opr_impl.h (+1, -0)
  8. dnn/test/cuda/conv_bias.cpp (+116, -0)
  9. dnn/test/cuda/convolution.cpp (+77, -0)
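
For reference, the computation this commit accelerates is a channel-wise (depthwise) convolution: each input channel is convolved with its own filter, with stride 1, depth multiplier 1, and SAME zero padding, as required by the new algorithm's availability check. A minimal scalar sketch of the semantics (illustrative only; the function name and signature are assumptions, not MegDNN API):

// Scalar reference for NCHW depthwise conv, stride 1, depth multiplier 1,
// SAME zero padding (output spatial size == input spatial size),
// cross-correlation mode as used by the tests below.
void depthwise_fwd_ref(
        const float* src, const float* flt, float* dst, int N, int C, int H,
        int W, int FH, int FW) {
    const int ph = FH / 2, pw = FW / 2;
    for (int n = 0; n < N; ++n)
        for (int c = 0; c < C; ++c)
            for (int oh = 0; oh < H; ++oh)
                for (int ow = 0; ow < W; ++ow) {
                    float sum = 0.f;
                    for (int fh = 0; fh < FH; ++fh)
                        for (int fw = 0; fw < FW; ++fw) {
                            const int ih = oh + fh - ph, iw = ow + fw - pw;
                            if (ih >= 0 && ih < H && iw >= 0 && iw < W)
                                sum += src[((n * C + c) * H + ih) * W + iw] *
                                       flt[(c * FH + fh) * FW + fw];
                        }
                    dst[((n * C + c) * H + oh) * W + ow] = sum;
                }
}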

dnn/src/cuda/conv_bias/algo.cpp (+2, -0)

@@ -19,6 +19,7 @@ using namespace cuda;
ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&chanwise);
non_cudnn_algos.push_back(&chanwise_small);
non_cudnn_algos.push_back(&depthwise_large_filter);

non_cudnn_algos.push_back(&inplace_matmul);
non_cudnn_algos.push_back(&matmul);
@@ -34,6 +35,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
std::vector<AlgoBase*> conv_algos;
conv_algos.push_back(&chanwise);
conv_algos.push_back(&chanwise_small);
conv_algos.push_back(&depthwise_large_filter);
conv_algos.push_back(&chanwise8x8x32);
for (auto&& algo : cudnn_convs) {
conv_algos.push_back(&algo);


dnn/src/cuda/conv_bias/algo.h (+22, -1)

@@ -22,7 +22,6 @@
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"

#include <cuda.h>
#include <memory>
@@ -57,6 +56,7 @@ public:
CUDA_CUDNN_CONVBIAS,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_DEPTHWISE_LARGE_FILTER,
CUDA_CHANWISE_INT8X8X32,
CUDA_CUDNN_CONV,
CUDA_INPLACE_MATMUL,
@@ -257,6 +257,26 @@ private:
mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;

const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasForward::algo_name<DirectParam>(
"DEPTHWISE_LARGE_FILTER", {});
}
return m_name.c_str();
}
MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }

private:
mutable std::string m_name;
};

class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
@@ -1084,6 +1104,7 @@ public:
AlgoFallbackNCHWQS8 fallback_nchw_qs8;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
AlgoDepthwiseLargeFilter depthwise_large_filter;
AlgoChanwise8x8x32 chanwise8x8x32;
AlgoInplaceMatmul inplace_matmul;
AlgoMatmul matmul;


dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl (+446, -0)

@@ -0,0 +1,446 @@
/**
* \file dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "src/cuda/cuda_shfl_compat.cuh"
namespace {

enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };

template <typename ThreadConfig_, int oh_, int ow_>
struct OutTileConfig {
using ThreadConfig = ThreadConfig_;
static int const unroll_h = oh_;
static int const unroll_w = ThreadConfig::thread_x * ow_;
static int const unroll_size = unroll_h * unroll_w;
static int const block_h = unroll_h * ThreadConfig::thread_y;
static int const block_w = unroll_w;
};

template <int fh_, int fw_>
struct FilterTileConfig {
static int const unroll_h = fh_;
static int const unroll_w = fw_;
static int const unroll_size = unroll_h * unroll_w;
};

template <int x_, int y_>
struct ThreadConfig {
static int const thread_x = x_;
static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be a power of 2!");
static int const thread_y = y_;
static int const nr_threads = x_ * y_;
};
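
// Worked example (illustrative note; these numbers follow from the launcher
// below, which uses ThreadConfig<4, 32>, FilterTileConfig<1, fw> and
// OutTileConfig<ThreadConfig, 1, ow>): for fw = 8, ow = 8,
//   OutTileConfig: unroll_h = 1, unroll_w = 4 * 8 = 32,
//                  block_h = 1 * 32 = 32, block_w = 32,
// i.e. one 4x32 thread block covers a 32x32 output tile, with the 4 threads
// along x cooperating on each output row segment.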

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
struct Global2SharedMem {
using TileCount = TileCount_;
using ThreadConfig = ThreadConfig_;
T reg[TileCount::reg_w];
const int tidx = threadIdx.x;
const int tidy = threadIdx.y;
const int tid = tidy * ThreadConfig::thread_x + tidx;
const int gl_load_y = tid / TileCount::load_w;
const int gl_load_x = tid - gl_load_y * TileCount::load_w;
const bool is_fwd = (kDirection == DIRECTION_FORWARD);
int w_offset;

T* smem;
int stride;
int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h;
const T* g_ptr;

__device__ __forceinline__
Global2SharedMem(T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w);

__device__ __forceinline__ void first_copy();
__device__ __forceinline__ void copy();
__device__ __forceinline__ void commit();
__device__ __forceinline__ void iter_forward();
__device__ __forceinline__ T* sh_ptr(int y, int x) {
return &smem[y * TileCount::smem_w + x];
}

__device__ __forceinline__ T* sh_ptr_as_copy_t(int y, int x) {
return reinterpret_cast<T*>(sh_ptr(y, x));
}
};
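
// Summary note on the visitor above (describing its use in the kernel below):
// first_copy() fills the initial smem_load_h rows of the shared-memory ring
// buffer (ring_smem_h starts at smem_load_h); each main-loop iteration then
// stages the next global row into registers (copy()), writes it into row
// ring_smem_h of the ring (commit()), and advances ring_smem_h modulo smem_h
// while stepping ring_src_h forward (iter_forward()). For the backward
// direction the source row index steps backward and the filter is read
// reversed along both axes, which realizes the 180-degree filter rotation.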

template <
typename ldg_dtype, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename OutTileConfig_, typename FilterTileConfig_>
struct ConvTrait {
using ThreadConfig = ThreadConfig_;
using OutTileConfig = OutTileConfig_;
using FilterTileConfig = FilterTileConfig_;
using CompType = ldg_dtype;

struct SrcTileConfig {
static int const unroll_h =
OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1;
static int const unroll_w =
OutTileConfig::unroll_w + FilterTileConfig::unroll_w - 1;
static int const unroll_size = unroll_h * unroll_w;
};

struct SrcTileCount {
static int const smem_src_h =
OutTileConfig::block_h + FilterTileConfig::unroll_h - 1;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = OutTileConfig::block_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x -
1;
static int const smem_size = smem_h * smem_w;
static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
static int const load_h = 1;
static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
};

struct FilterTileCount {
static int const smem_flt_h = FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_flt_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x;
static int const smem_size = smem_h * smem_w;
static int const load_w = smem_w > 32 ? 32 : smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
};

using SrcGlobal2ShareVisitor = Global2SharedMem<
CompType, DepthwiseConv2dDirection::DIRECTION_FORWARD, ThreadConfig,
SrcTileCount>;
using FilterGlobal2ShareVisitor =
Global2SharedMem<CompType, kDirection, ThreadConfig, FilterTileCount>;
};
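
// Worked example (illustrative): with ThreadConfig<4, 32>,
// FilterTileConfig<1, 8> and OutTileConfig<ThreadConfig, 1, 8>
// (e.g. a 31x31 filter and a 32-wide output tile):
//   SrcTileCount:    smem_src_h = 32 + 1 - 1 = 32, smem_h = 34,
//                    smem_w = 32 + 8 * 4 - 1 = 63  -> 34 * 63 = 2142 floats
//   FilterTileCount: smem_h = 3, smem_w = 8 * 4 = 32 -> 96 floats
// so one block needs (2142 + 96) * 4 bytes ~= 8.7 KiB of shared memory.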

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__
Global2SharedMem<T, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem(
T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w)
: smem(smem_),
stride(stride_),
start_h(s_h),
start_w(s_w),
bound_h(b_h),
bound_w(b_w),
ring_smem_h(TileCount::smem_load_h) {
if (is_fwd) {
ring_src_h = s_h + TileCount::smem_load_h;
w_offset = 0;
} else {
ring_src_h = s_h - 1;
w_offset = TileCount::smem_w - b_w;
}
}

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, kDirection, ThreadConfig_, TileCount_>::first_copy() {
static int const load_w = TileCount::smem_w > 32 ? 32 : TileCount::smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const h_per_thread = DIVUP(TileCount::smem_load_h, load_h);
static int const w_per_thread = DIVUP(TileCount::smem_w, load_w);
static bool constexpr check_bounds_h = TileCount::smem_load_h % load_h != 0;
static bool constexpr check_bounds_w = TileCount::smem_w % load_w != 0;
const int y_base_idx = tid / load_w;
const int x_base_idx = tid - y_base_idx * load_w;
#pragma unroll
for (int i = 0; i < h_per_thread; ++i) {
int smem_h_idx = y_base_idx + i * load_h;
int src_h_idx;
if (is_fwd) {
src_h_idx = start_h + smem_h_idx;
} else {
src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1;
}
if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h)
continue;
#pragma unroll
for (int j = 0; j < w_per_thread; ++j) {
int smem_w_idx = x_base_idx + j * load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
(is_fwd || (TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
val = g_ptr[src_h_idx * stride + src_w_idx];
}
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val;
}
}
}

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, kDirection, ThreadConfig_, TileCount_>::copy() {
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;
int src_w_idx;
if (is_fwd) {
src_w_idx = start_w + smem_w_idx;
} else {
src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1;
}
if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;
T val = 0.0f;
if (ring_src_h >= 0 && ring_src_h < bound_h && src_w_idx >= 0 &&
src_w_idx < bound_w &&
(is_fwd || TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0)) {
val = g_ptr[ring_src_h * stride + src_w_idx];
}
reg[j] = val;
}
}

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, kDirection, ThreadConfig_, TileCount_>::commit() {
#pragma unroll
for (int j = 0; j < TileCount::reg_w; ++j) {
int smem_w_idx = gl_load_x + j * TileCount::load_w;

if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w)
continue;

*(sh_ptr_as_copy_t(ring_smem_h, smem_w_idx)) = reg[j];
}
}

template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
typename TileCount_>
__device__ __forceinline__ void Global2SharedMem<
T, kDirection, ThreadConfig_, TileCount_>::iter_forward() {
if (is_fwd) {
ring_src_h++;
} else {
ring_src_h--;
}
ring_smem_h = (ring_smem_h + 1) % TileCount::smem_h;
}

// CUDA kernel computing the depthwise convolution forward pass in NCHW
// format, tailored for small feature maps with large filters. Stride and
// depth multiplier must be 1, and padding must be 'SAME' (output spatial
// size equals input spatial size), which allows the index computation to be
// reused. Only launch this kernel when is_available_depthwise_large_filter()
// (see dnn/src/cuda/conv_bias/depthwise_large_filter.cpp) returns true.
// Tiles of the input and filter tensors are staged in shared memory through
// a ring buffer before the convolution is computed; each thread accumulates
// an output tile, and partial sums are reduced across the thread_x lanes
// with warp shuffles. The backward-data direction equals the forward
// direction with the filter rotated by 180°.
template <typename T, typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
const Param param, const T* input, const T* filter, T* output) {
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;

extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);

int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h - param.pad_h,
src_start_w = out_start_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;

T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;

T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;

SrcGlobal2ShareVisitor gl2sh_src(
smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);

FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
0, param.flt_h, param.flt_w};

gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;

gl2sh_src.first_copy();
gl2sh_flt.first_copy();

__syncthreads();

T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];

T sum[OutTileConfig::unroll_size] = {0.0};

for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
}
}

#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w];
}
}

#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow];
}
}
}
}

__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
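
// Added note: after the main loop, each of the thread_x lanes that
// cooperated on an output row holds a partial sum; the __shfl_xor butterfly
// below folds them so that lane 0 of each group owns the full sum and is the
// only lane that writes results to global memory.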

for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o] += __shfl_xor(sum[o], i, 32);
}
}

if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
}
}
}
}
}

template <
typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
int unroll_ow>
void LaunchDepthwiseConv2dGPUSmall(
const Param& param, const T* input, const T* filter, T* output,
cudaStream_t stream) {
static int const unroll_oh = 1, unroll_fh = 1;

using FilterTileConfig = FilterTileConfig<unroll_fh, unroll_fw>;
using ThreadConfig = ThreadConfig<4, 32>;
using OutTileConfig = OutTileConfig<ThreadConfig, unroll_oh, unroll_ow>;
using IConvTrait =
ConvTrait<T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig>;
using SrcTileCount = typename IConvTrait::SrcTileCount;
using FilterTileCount = typename IConvTrait::FilterTileCount;

dim3 block(ThreadConfig::thread_x, ThreadConfig::thread_y);
dim3 grid;
grid.x = param.batch * param.src_chl * param.chl_mul;
grid.y = DIVUP(param.out_w, OutTileConfig::block_w);
grid.z = DIVUP(param.out_h, OutTileConfig::block_h);
const int shared_storage =
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);

void (*kernel)(const Param, const T*, const T*, T*);
kernel = DepthwiseConv2dGPUKernelNCHWSmall<T, IConvTrait, kDirection>;
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
}
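
// Worked example (illustrative): for batch 64, 384 channels, a 32x32
// feature map and a 31x31 filter, this launcher is instantiated with
// unroll_fw = unroll_ow = 8, giving block = (4, 32), grid = (64 * 384, 1, 1)
// and roughly 8.7 KiB of dynamic shared memory (see the tile-count example
// above).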

#define INSTANCE_AB(a, b, direction)                                           \
    if (param.out_w > b * 4) {                                                 \
        LaunchDepthwiseConv2dGPUSmall<float, float2, direction, a + 1, b + 1>( \
                param, src, flt, dst, stream);                                 \
    }

#define INSTANCE_A(a, direction)              \
    if (param.flt_w > a * 4) {                \
        INSTANCE_AB(a, 15, direction)         \
        else INSTANCE_AB(a, 14, direction)    \
        else INSTANCE_AB(a, 13, direction)    \
        else INSTANCE_AB(a, 12, direction)    \
        else INSTANCE_AB(a, 11, direction)    \
        else INSTANCE_AB(a, 10, direction)    \
        else INSTANCE_AB(a, 9, direction)     \
        else INSTANCE_AB(a, 8, direction)     \
        else INSTANCE_AB(a, 7, direction)     \
        else INSTANCE_AB(a, 6, direction)     \
        else INSTANCE_AB(a, 5, direction)     \
        else INSTANCE_AB(a, 4, direction)     \
        else INSTANCE_AB(a, 3, direction)     \
        else INSTANCE_AB(a, 2, direction)     \
        else INSTANCE_AB(a, 1, direction)     \
        else INSTANCE_AB(a, 0, direction)     \
    }

#define INSTANCE(direction)                   \
    INSTANCE_A(7, direction)                  \
    else INSTANCE_A(6, direction)             \
    else INSTANCE_A(5, direction)             \
    else INSTANCE_A(4, direction)             \
    else INSTANCE_A(3, direction)             \
    else INSTANCE_A(2, direction)             \
    else INSTANCE_A(1, direction)             \
    else INSTANCE_A(0, direction)

} // anonymous namespace
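
The INSTANCE macro chain is an unrolled dispatch table: INSTANCE_A selects the filter-width unroll factor and INSTANCE_AB the output-width unroll factor, each as the ceiling of the extent divided by 4 (capped at 8 and 16 respectively). A sketch of the equivalent selection logic, with a hypothetical helper that is not part of the diff:

// The largest a with extent > a * 4 yields an unroll factor of
// a + 1 == ceil(extent / 4).
static inline int select_unroll(int extent) {
    return (extent + 3) / 4;  // e.g. flt_w = 31 -> 8, out_w = 32 -> 8
}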

dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu (+48, -0)

@@ -0,0 +1,48 @@
/**
* \file dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
#include "src/cuda/conv_bias/chanwise/launch_config.cuh"
#include "src/cuda/fp16_help.cuh"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
using namespace chanwise;

#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"

namespace megdnn {
namespace cuda {
namespace conv_bias {
namespace chanwise {

// =====================================fwd=====================================

template <>
void run_fwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(DepthwiseConv2dDirection::DIRECTION_FORWARD)
}

} // namespace chanwise
} // namespace conv_bias
} // namespace cuda
} // namespace megdnn

// vim: syntax=cuda.doxygen

dnn/src/cuda/conv_bias/chanwise/kern.cuh (+4, -0)

@@ -61,6 +61,10 @@ template <typename T>
void run_fwd_small(
T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);

template <typename T>
void run_fwd_depthwise_large_filter(
T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);

// implemented in fwd_8x8x32.cu
void run_fwd_8x8x32(
int32_t* dst, const int8_t* src, const int8_t* flt, const Param& param,


dnn/src/cuda/conv_bias/depthwise_large_filter.cpp (+109, -0)

@@ -0,0 +1,109 @@
/**
* \file dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

namespace {
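// Summary note on the check below: it estimates the per-thread register
// demand (filter + source + output accumulators, for the 4x32 thread block
// the kernel launches with) and the shared-memory tile footprint, and bails
// out if the device cannot supply them; the algorithm additionally requires
// stride 1 and equal input/output spatial size.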
inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
auto&& device_prop = cuda::current_device_prop();
int flt_smem_w = (param.flt_w + 3) / 4 * 4;
int flt_smem_h = 3;
int flt_reg_per_thread =
flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
int ow = param.out_w > 64 ? 64 : param.out_w;
int src_smem_w = ow + flt_smem_w - 1;
int src_smem_h = flt_smem_h + param.flt_h - 1;
int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread + src_reg_per_thread +
out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
return false;
}
return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
param.src_w == param.out_w;
}
} // anonymous namespace

bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
return false;
}
if (args.src_layout->dtype != args.filter_layout->dtype ||
args.src_layout->dtype != dtype::Float32()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;

auto param = chanwise::Param::from_fwd_args(args);
auto&& fm = args.filter_meta;
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.src_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && !fm.should_flip &&
is_available_depthwise_large_filter(param);
}

size_t ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes(
const SizeArgs& args) const {
auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
return dst_layout.span().dist_byte();
}
return 0;
}

void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) const {
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}};
TensorND conv_dst_tensor = *args.dst_tensor;
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
conv_dst_tensor = TensorND{bundle.get(0), conv_dst_tensor.layout};
conv_dst_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
conv_dst_tensor.layout.dtype);
}
{
auto kparam = chanwise::Param::from_fwd_args(args);
auto stream = cuda_stream(args.handle);
switch (args.src_layout->dtype.enumv()) {
case DTypeEnum::Float32:
chanwise::run_fwd_depthwise_large_filter(
conv_dst_tensor.ptr<float>(), args.src_tensor->ptr<float>(),
args.filter_tensor->ptr<float>(), kparam, stream);
break;
default:
megdnn_assert_internal(0);
}
}
handle_bias_and_nonlinear(
args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
args.bias_tensor);
}

// vim: syntax=cpp.doxygen

dnn/src/cuda/conv_bias/opr_impl.h (+1, -0)

@@ -45,6 +45,7 @@ public:
class AlgoCUDNNConvBiasActivation;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoDepthwiseLargeFilter;
class AlgoChanwise8x8x32;
class AlgoCUDNNConv;
class AlgoFallbackNCHWQS8;


dnn/test/cuda/conv_bias.cpp (+116, -0)

@@ -695,6 +695,59 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_SMALL) {
}
}

TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
auto run = [&checker](size_t n, size_t g, size_t h, size_t fh) {
param::ConvBias cur_param;
cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
cur_param.sparse = ConvBias::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32());

cur_param.pad_h = cur_param.pad_w = fh / 2;
cur_param.stride_h = cur_param.stride_w = 1;
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 5);
run(4, 8, 64, 7);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
}

TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) {
require_compute_capability(6, 1);
Checker<ConvBiasForward> checker(handle_cuda());
@@ -1474,6 +1527,69 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
run_bench(256, 512, 7, 7, 512, 3, 3, 1, 1, 1000);
run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000);
}

TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));

ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;

using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP;
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;

bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, out{batch, g, ho, wo};

float bandwidth = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;

auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
"%.2fms, "
"perf: %.2f Tops bandwidth: %.2fGB/s.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, ops, bandwidth * 4 / time_in_ms);
};

run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
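
// Sanity check on the reported numbers (illustrative): for the last case,
// inp = {64, 384, 32, 32} with a 31x31 filter, one run performs
// 2 * 64 * 384 * 32 * 32 * 31 * 31 ~= 4.84e10 FLOPs, so a 10 ms run would
// correspond to roughly 4.84 TFLOPS.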
#endif
#endif



dnn/test/cuda/convolution.cpp (+77, -0)

@@ -901,6 +901,43 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
run(32, 64, 64, 56, 56, 1, 1, 0);
}

TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
new OprProxy<ConvolutionBackwardData>{true}};
size_t RUNS = 10;
bench.set_proxy(proxy).set_times(RUNS);

auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t PH) {
bench.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bench.set_param(param);
bench.proxy()->target_execution_policy.algo.reset();
TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
filter{{g, 1, 1, FH, FH}, dtype::Float32()};
TensorLayout dst;
{
auto&& opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
(flo / (time_ms_fp32 * 1e9)));
};

run(64, 384, 384, 32, 32, 31, 1, 15);
}

TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
@@ -1065,6 +1102,46 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) {
run(32, 512, 1024, 14, 14, 1, 2, 0);
run(32, 64, 64, 56, 56, 1, 1, 0);
}

TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
new OprProxy<ConvolutionBackwardFilter>{true}};
size_t RUNS = 10;
bench.set_proxy(proxy).set_times(RUNS);

bench.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFTv7.6.3"));

auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t PH) {
bench.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bench.set_param(param);
bench.proxy()->target_execution_policy.algo.reset();
TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
filter{{g, 1, 1, FH, FH}, dtype::Float32()};
TensorLayout dst;
{
auto&& opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp32 = bench.execl({src, dst, filter}) / RUNS;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
(flo / (time_ms_fp32 * 1e9)));
};
run(64, 384, 384, 32, 32, 31, 1, 15);
}

#endif

#undef CUDNN_VERSION_STRING

