diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 62cfd861..4e7e1f34 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -1066,7 +1066,31 @@ py::object _adaptive_pool2d_cpp( py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) { py::object shape_hdl = py::reinterpret_borrow(shape_val_hdl); py::list shps(0); - if (!PyTuple_Check(shape_val_hdl.ptr())) { + auto mode_string = pool_mode_hdl.cast(); + ::megdnn::param::AdaptivePooling::Mode pool_mode = + ::megdnn::param::AdaptivePooling::Mode::MAX; + if (mode_string.compare(std::string("AVERAGE")) == 0) { + pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE; + } + std::shared_ptr op; + std::vector p; + auto pool_format = ::megdnn::param::AdaptivePooling::Format::NCHW; + auto inp_format = getattr(inp_hdl, "format").cast(); + if (inp_format == "nhwc") { + pool_format = ::megdnn::param::AdaptivePooling::Format::NHWC; + } + if (TensorWrapper::try_cast(shape_val_hdl.ptr())) { + std::vector shp; + op = AdaptivePooling::make(pool_mode, pool_format, shp); + py::object Op = py::cast(op); + p.resize(3); + p[0] = Op.ptr(); + p[1] = inp_hdl.ptr(); + p[2] = shape_val_hdl.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[0]; + } else if (!PyTuple_Check(shape_val_hdl.ptr())) { shps.append(PyLong_AsLong(shape_val_hdl.ptr())); shps.append(PyLong_AsLong(shape_val_hdl.ptr())); @@ -1078,19 +1102,11 @@ py::object _adaptive_pool2d_cpp( } catch (py::error_already_set& err) { shape_tuple = py::reinterpret_borrow(shape_hdl); } - auto mode_string = pool_mode_hdl.cast(); - ::megdnn::param::AdaptivePooling::Mode pool_mode = - ::megdnn::param::AdaptivePooling::Mode::MAX; - if (mode_string.compare(std::string("AVERAGE")) == 0) { - pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE; - } + auto [shape, fastpath] = tuple2vector(shape_tuple); fastpath &= enable_fastpath(inp_hdl); - std::shared_ptr op; - std::vector p; py::object shape_tensor; - op = AdaptivePooling::make( - pool_mode, ::megdnn::param::AdaptivePooling::Format::NCHW, shape); + op = AdaptivePooling::make(pool_mode, pool_format, shape); if (fastpath) { p.resize(2); } else { diff --git a/imperative/src/impl/ops/adaptive_pooling.cpp b/imperative/src/impl/ops/adaptive_pooling.cpp index fbd972fc..5701b9c5 100644 --- a/imperative/src/impl/ops/adaptive_pooling.cpp +++ b/imperative/src/impl/ops/adaptive_pooling.cpp @@ -39,6 +39,7 @@ std::tuple, bool> infer_output_attrs_fallible( const dt_int32* oshp2d = nullptr; dst_layout.ndim = 4u; + bool tshp1n = false; if (nr_inp == 1) { oshp2d = pool.shape.data(); } else { @@ -51,17 +52,18 @@ std::tuple, bool> infer_output_attrs_fallible( "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually", tshp.layout.ndim); oshp2d = tshp.value.ptr(); + tshp1n = tshp.layout.total_nr_elems() == 1; } auto param_format = pool.param().format; if (param_format == opr::AdaptivePooling::Param::Format::NCHW) { dst_layout[0] = src.layout[0]; dst_layout[1] = src.layout[1]; dst_layout[2] = oshp2d[0]; - dst_layout[3] = oshp2d[1]; + dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1]; } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { dst_layout[0] = src.layout[0]; dst_layout[1] = oshp2d[0]; - dst_layout[2] = oshp2d[1]; + dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1]; dst_layout[3] = src.layout[3]; } else { mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); @@ -83,8 +85,10 @@ SmallVector apply_on_physical_tensor( if (!validated) { dst_layout.ndim = src_layout.ndim; const dt_int32* oshp2d = nullptr; + bool tshp1n = false; if (inputs.size() == 2) { auto&& tshp_nd = inputs[1]; + tshp1n = inputs[1]->layout().total_nr_elems() == 1; oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr(); } else { oshp2d = pool.shape.data(); @@ -93,11 +97,11 @@ SmallVector apply_on_physical_tensor( dst_layout[0] = src_layout[0]; dst_layout[1] = src_layout[1]; dst_layout[2] = oshp2d[0]; - dst_layout[3] = oshp2d[1]; + dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1]; } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { dst_layout[0] = src_layout[0]; dst_layout[1] = oshp2d[0]; - dst_layout[2] = oshp2d[1]; + dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1]; dst_layout[3] = src_layout[3]; } else { mgb_throw( diff --git a/src/opr/impl/dnn/adaptive_pooling.cpp b/src/opr/impl/dnn/adaptive_pooling.cpp index 17c8088d..fe6e0c1a 100644 --- a/src/opr/impl/dnn/adaptive_pooling.cpp +++ b/src/opr/impl/dnn/adaptive_pooling.cpp @@ -39,22 +39,23 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); auto src = shpinfo.shape_inp_shp.at(0); mgb_assert( - src.ndim == 4 && oshp2d.ndim == 2, + src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1), "shape mismatch for AdaptivePooling: src=%s, out2d=%s", src.to_string().c_str(), oshp2d.to_string().c_str()); auto param_format = param().format; + bool tshp1n = oshp2d.ndim == 1; if (param_format == Param::Format::NCHW) { dest.ndim = 4; dest.shape[0] = src.shape[0]; dest.shape[1] = src.shape[1]; dest.shape[2] = oshp2d.shape[0]; - dest.shape[3] = oshp2d.shape[1]; + dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; } else if (param_format == Param::Format::NHWC) { dest.ndim = 4; dest.shape[0] = src.shape[0]; dest.shape[1] = oshp2d.shape[0]; - dest.shape[2] = oshp2d.shape[1]; + dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; dest.shape[3] = src.shape[3]; } else { mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");