Browse Source

fix(cuda): conv algo heuristic choose

GitOrigin-RevId: 95c5e7d627
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
3228fb75a5
2 changed files with 6 additions and 6 deletions
  1. +3 -3  dnn/src/cuda/conv_bias/opr_impl.cpp
  2. +3 -3  dnn/src/cuda/convolution/opr_impl.cpp

dnn/src/cuda/conv_bias/opr_impl.cpp (+3 -3) · View File

@@ -148,9 +148,9 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
  //! choose for large kernel cases
  size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1];
  size_t hi = src[2], wi = src[3];
- const bool prefer_dnn_lk_implbmm =
- hi <= 2 * fh && wi <= 2 * fw && wi < 32 && hi <= 32;
- const bool prefer_direct_lk = fh > 10 && fw > 10;
+ const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw;
+ //! filter size > 9, choose large kernel cases
+ const bool prefer_direct_lk = fh > 9 && fw > 9;
  //! avoid bad case in cudnn, check dnn chanwise impl first
  if (is_chanwise) {
  if (prefer_dnn_lk_implbmm) {


dnn/src/cuda/convolution/opr_impl.cpp (+3 -3) · View File

@@ -119,10 +119,10 @@ ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::
  size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1];
  size_t ho = diff[2], wo = diff[3];
  const bool prefer_dnn_lk_implbmm = args.filter_meta.format == Param::Format::NCHW &&
- ho <= 2 * fh && wo <= 2 * fw && ho < 32 &&
- wo < 32;
+ ho <= 2 * fh && wo <= 2 * fw;
+ //! filter size > 9, choose large kernel cases
  const bool prefer_direct_lk =
- args.filter_meta.format == Param::Format::NCHW && fh > 10 && fw > 10;
+ args.filter_meta.format == Param::Format::NCHW && fh > 9 && fw > 9;
  if (prefer_dnn_lk_implbmm) {
  #if CUDA_VERSION >= 10020
  if (sm_algo_pack.implbmm_nchw_hmma[0].is_available_attribute(


Loading…
Cancel
Save