|
|
@@ -148,7 +148,9 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( |
|
|
|
//! choose an algorithm for large-kernel cases |
|
|
|
size_t fh = args.filter_meta.spatial[0], fw = args.filter_meta.spatial[1]; |
|
|
|
size_t hi = src[2], wi = src[3]; |
|
|
|
const bool prefer_dnn_lk_implbmm = hi <= 2 * fh && wi <= 2 * fw; |
|
|
|
const bool prefer_dnn_lk_implbmm = |
|
|
|
hi <= 2 * fh && wi <= 2 * fw && wi < 32 && hi <= 32; |
|
|
|
const bool prefer_direct_lk = fh > 10 && fw > 10; |
|
|
|
//! avoid bad cases in cuDNN: check the dnn chanwise impl first |
|
|
|
if (is_chanwise) { |
|
|
|
if (prefer_dnn_lk_implbmm) { |
|
|
@@ -160,6 +162,11 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( |
|
|
|
if (sm_algo_pack.f32_implicit_bmm[0].is_available_attribute( |
|
|
|
args, positive_attr, negative_attr, workspace_limit_in_bytes)) |
|
|
|
return &sm_algo_pack.f32_implicit_bmm[0]; |
|
|
|
} else if ( |
|
|
|
prefer_direct_lk && |
|
|
|
sm_algo_pack.depthwise_large_filter.is_available_attribute( |
|
|
|
args, positive_attr, negative_attr, workspace_limit_in_bytes)) { |
|
|
|
return &sm_algo_pack.depthwise_large_filter; |
|
|
|
} else if (prefer_dnn_chanwise) { |
|
|
|
if (sm_algo_pack.chanwise.is_available_attribute( |
|
|
|
args, positive_attr, negative_attr, workspace_limit_in_bytes)) |
|
|
|