|
|
@@ -297,6 +297,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( |
|
|
|
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, |
|
|
|
const AlgoAttribute& positive_attr, |
|
|
|
const AlgoAttribute& negative_attr) { |
|
|
|
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto algo_data_type = param.deduce_algo_data_type(); |
|
|
|
auto suggest_category_order = suggest_algo_category_order(param); |
|
|
|
for (auto category : suggest_category_order) { |
|
|
@@ -346,7 +349,7 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( |
|
|
|
param().format == Param::Format::NCHW32 || |
|
|
|
param().format == Param::Format::NCHW64) { |
|
|
|
spatial_pos = 2; |
|
|
|
} else if (param().format == Param::Format::NHWC) { |
|
|
|
} else if (param().format == Param::Format::NHWC || param().format == Param::Format::NHWCD4) { |
|
|
|
spatial_pos = 1; |
|
|
|
} else { |
|
|
|
megdnn_assert(0, "invalid conv format %d", |
|
|
@@ -497,6 +500,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc( |
|
|
|
|
|
|
|
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( |
|
|
|
const NCBKernSizeParam& param, size_t workspace_size) { |
|
|
|
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (auto algo = get_algorithm_from_desc(execution_policy().algo)) { |
|
|
|
return algo; |
|
|
|
} |
|
|
|