Browse Source

feat(mgb/opr): add NHWC support for AdaptivePooling

GitOrigin-RevId: b23e37ac23
release-1.10
Megvii Engine Team 3 years ago
parent
commit
3bd40887b6
2 changed files with 31 additions and 9 deletions
  1. +15
    -2
      dnn/src/common/adaptive_pooling.cpp
  2. +16
    -7
      src/opr/impl/dnn/adaptive_pooling.cpp

+ 15
- 2
dnn/src/common/adaptive_pooling.cpp View File

@@ -6,8 +6,21 @@ namespace megdnn {

param::Pooling AdaptivePoolingBase::deduce_pooling_param(
const TensorLayout& src, const TensorLayout& dst) {
megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW);
size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], OW = dst.shape[3];
auto param_format = param().format;
size_t IH, IW, OH, OW;
if (param_format == param::AdaptivePooling::Format::NCHW) {
IH = src.shape[2];
IW = src.shape[3];
OH = dst.shape[2];
OW = dst.shape[3];
} else if (param_format == param::AdaptivePooling::Format::NHWC) {
IH = src.shape[1];
IW = src.shape[2];
OH = dst.shape[1];
OW = dst.shape[2];
} else {
megdnn_throw("AdaptivePooling only support NCHW or NHWC format");
}

param::Pooling ret;
ret.mode = param().mode;


+ 16
- 7
src/opr/impl/dnn/adaptive_pooling.cpp View File

@@ -43,13 +43,22 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
"shape mismatch for AdaptivePooling: src=%s, out2d=%s",
src.to_string().c_str(), oshp2d.to_string().c_str());

mgb_assert(
param().format == Param::Format::NCHW, "AdaptivePooling only support NCHW");
dest.ndim = 4;
dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1];
auto param_format = param().format;
if (param_format == Param::Format::NCHW) {
dest.ndim = 4;
dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = oshp2d.shape[1];
} else if (param_format == Param::Format::NHWC) {
dest.ndim = 4;
dest.shape[0] = src.shape[0];
dest.shape[1] = oshp2d.shape[0];
dest.shape[2] = oshp2d.shape[1];
dest.shape[3] = src.shape[3];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
}
}

size_t AdaptivePoolingForward::get_workspace_size_bytes(


Loading…
Cancel
Save