Browse Source

fix(mge/convbias): make fallback convbias support nhwcd4 layout

GitOrigin-RevId: 1c306f867d
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
ea70d99b4d
2 changed files with 12 additions and 14 deletions
  1. +5
    -13
      dnn/src/common/relayout_format.cpp
  2. +7
    -1
      dnn/src/fallback/conv_bias/opr_impl.cpp

+ 5
- 13
dnn/src/common/relayout_format.cpp View File

@@ -382,19 +382,11 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {

if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 &&
(
handle()->type() != Handle::HandleType::NAIVE)) {
#if MEGDNN_ENABLE_MANGLING
megdnn_throw(
"Only naive and opencl handle support "
"Image2DPack4TensorFormat, try build with debug for get more "
"info");
#else
megdnn_throw(
"Only naive and opencl handle support "
"Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 "
"and also export CUDA_VISIBLE_DEVICES=\'\' at CUDA env"
"to enable naive handle");
#endif
handle()->type() != Handle::HandleType::NAIVE &&
handle()->type() != Handle::HandleType::X86)) {
megdnn_throw(
"Dump with Image2DPack4TensorFormat is not available on CUDA compnode, "
"try export CUDA_VISIBLE_DEVICES=\'\'");
}
#undef CHECK_SRC
}


+ 7
- 1
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -297,6 +297,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
return nullptr;
}
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
@@ -346,7 +349,7 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW64) {
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
} else if (param().format == Param::Format::NHWC || param().format == Param::Format::NHWCD4) {
spatial_pos = 1;
} else {
megdnn_assert(0, "invalid conv format %d",
@@ -497,6 +500,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(

ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) {
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
return nullptr;
}
if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
return algo;
}


Loading…
Cancel
Save