diff --git a/dnn/src/common/adaptive_pooling.cpp b/dnn/src/common/adaptive_pooling.cpp index 92916400..88862e1e 100644 --- a/dnn/src/common/adaptive_pooling.cpp +++ b/dnn/src/common/adaptive_pooling.cpp @@ -6,8 +6,21 @@ namespace megdnn { param::Pooling AdaptivePoolingBase::deduce_pooling_param( const TensorLayout& src, const TensorLayout& dst) { - megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW); - size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], OW = dst.shape[3]; + auto param_format = param().format; + size_t IH, IW, OH, OW; + if (param_format == param::AdaptivePooling::Format::NCHW) { + IH = src.shape[2]; + IW = src.shape[3]; + OH = dst.shape[2]; + OW = dst.shape[3]; + } else if (param_format == param::AdaptivePooling::Format::NHWC) { + IH = src.shape[1]; + IW = src.shape[2]; + OH = dst.shape[1]; + OW = dst.shape[2]; + } else { + megdnn_throw("AdaptivePooling only support NCHW or NHWC format"); + } param::Pooling ret; ret.mode = param().mode; diff --git a/src/opr/impl/dnn/adaptive_pooling.cpp b/src/opr/impl/dnn/adaptive_pooling.cpp index 1d8f402f..17c8088d 100644 --- a/src/opr/impl/dnn/adaptive_pooling.cpp +++ b/src/opr/impl/dnn/adaptive_pooling.cpp @@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( "shape mismatch for AdaptivePooling: src=%s, out2d=%s", src.to_string().c_str(), oshp2d.to_string().c_str()); - mgb_assert( - param().format == Param::Format::NCHW, "AdaptivePooling only support NCHW"); - dest.ndim = 4; - dest.shape[0] = src.shape[0]; - dest.shape[1] = src.shape[1]; - dest.shape[2] = oshp2d.shape[0]; - dest.shape[3] = oshp2d.shape[1]; + auto param_format = param().format; + if (param_format == Param::Format::NCHW) { + dest.ndim = 4; + dest.shape[0] = src.shape[0]; + dest.shape[1] = src.shape[1]; + dest.shape[2] = oshp2d.shape[0]; + dest.shape[3] = oshp2d.shape[1]; + } else if (param_format == Param::Format::NHWC) { + dest.ndim = 4; + dest.shape[0] = src.shape[0]; + dest.shape[1] = oshp2d.shape[0]; + dest.shape[2] = oshp2d.shape[1]; + dest.shape[3] = src.shape[3]; + } else { + mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); + } } size_t AdaptivePoolingForward::get_workspace_size_bytes(