From 3bd40887b62ae8aa6fc2ca158c891898d3202733 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 11 Feb 2022 14:37:14 +0800 Subject: [PATCH] feat(mgb/opr): add NHWC support for AdaptivePooling GitOrigin-RevId: b23e37ac23764085568d4c303619eec10a0b4867 --- dnn/src/common/adaptive_pooling.cpp | 17 +++++++++++++++-- src/opr/impl/dnn/adaptive_pooling.cpp | 23 ++++++++++++++++------- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/dnn/src/common/adaptive_pooling.cpp b/dnn/src/common/adaptive_pooling.cpp index 92916400..88862e1e 100644 --- a/dnn/src/common/adaptive_pooling.cpp +++ b/dnn/src/common/adaptive_pooling.cpp @@ -6,8 +6,21 @@ namespace megdnn { param::Pooling AdaptivePoolingBase::deduce_pooling_param( const TensorLayout& src, const TensorLayout& dst) { - megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW); - size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], OW = dst.shape[3]; + auto param_format = param().format; + size_t IH, IW, OH, OW; + if (param_format == param::AdaptivePooling::Format::NCHW) { + IH = src.shape[2]; + IW = src.shape[3]; + OH = dst.shape[2]; + OW = dst.shape[3]; + } else if (param_format == param::AdaptivePooling::Format::NHWC) { + IH = src.shape[1]; + IW = src.shape[2]; + OH = dst.shape[1]; + OW = dst.shape[2]; + } else { + megdnn_throw("AdaptivePooling only support NCHW or NHWC format"); + } param::Pooling ret; ret.mode = param().mode; diff --git a/src/opr/impl/dnn/adaptive_pooling.cpp b/src/opr/impl/dnn/adaptive_pooling.cpp index 1d8f402f..17c8088d 100644 --- a/src/opr/impl/dnn/adaptive_pooling.cpp +++ b/src/opr/impl/dnn/adaptive_pooling.cpp @@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( "shape mismatch for AdaptivePooling: src=%s, out2d=%s", src.to_string().c_str(), oshp2d.to_string().c_str()); - mgb_assert( - param().format == Param::Format::NCHW, "AdaptivePooling only support NCHW"); - dest.ndim = 4; - dest.shape[0] = src.shape[0]; - dest.shape[1] = src.shape[1]; - dest.shape[2] = oshp2d.shape[0]; - dest.shape[3] = oshp2d.shape[1]; + auto param_format = param().format; + if (param_format == Param::Format::NCHW) { + dest.ndim = 4; + dest.shape[0] = src.shape[0]; + dest.shape[1] = src.shape[1]; + dest.shape[2] = oshp2d.shape[0]; + dest.shape[3] = oshp2d.shape[1]; + } else if (param_format == Param::Format::NHWC) { + dest.ndim = 4; + dest.shape[0] = src.shape[0]; + dest.shape[1] = oshp2d.shape[0]; + dest.shape[2] = oshp2d.shape[1]; + dest.shape[3] = src.shape[3]; + } else { + mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); + } } size_t AdaptivePoolingForward::get_workspace_size_bytes(