diff --git a/imperative/src/impl/ops/adaptive_pooling.cpp b/imperative/src/impl/ops/adaptive_pooling.cpp
index 0521ce1d..fbd972fc 100644
--- a/imperative/src/impl/ops/adaptive_pooling.cpp
+++ b/imperative/src/impl/ops/adaptive_pooling.cpp
@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
     }
 
+    const dt_int32* oshp2d = nullptr;
     dst_layout.ndim = 4u;
     if (nr_inp == 1) {
-        dst_layout[0] = src.layout[0];
-        dst_layout[1] = src.layout[1];
-        dst_layout[2] = pool.shape[0];
-        dst_layout[3] = pool.shape[1];
+        oshp2d = pool.shape.data();
     } else {
         auto&& tshp = inputs[1];
         if (tshp.value.empty()) {
@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
                 tshp.layout.ndim == 1,
                 "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
                 tshp.layout.ndim);
+        oshp2d = tshp.value.ptr<dt_int32>();
+    }
+    auto param_format = pool.param().format;
+    if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
         dst_layout[0] = src.layout[0];
         dst_layout[1] = src.layout[1];
-        auto* ptr = tshp.value.ptr<dt_int32>();
-        dst_layout[2] = ptr[0];
-        dst_layout[3] = ptr[1];
+        dst_layout[2] = oshp2d[0];
+        dst_layout[3] = oshp2d[1];
+    } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+        dst_layout[0] = src.layout[0];
+        dst_layout[1] = oshp2d[0];
+        dst_layout[2] = oshp2d[1];
+        dst_layout[3] = src.layout[3];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
     }
     dst_layout.init_contiguous_stride();
     return {{{dst_layout, src.comp_node}}, true};
@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     using TensorND = megdnn::TensorND;
     auto&& src_layout = inputs[0]->layout();
     TensorLayout dst_layout = output_descs[0].layout;
+    auto param_format = pool.format;
     if (!validated) {
-        TensorShape tshp;
         dst_layout.ndim = src_layout.ndim;
-        dst_layout[0] = src_layout[0];
-        dst_layout[1] = src_layout[1];
+        const dt_int32* oshp2d = nullptr;
         if (inputs.size() == 2) {
             auto&& tshp_nd = inputs[1];
-            cg::copy_tensor_value_to_shape(
-                    tshp, tshp_nd->get_value().proxy_to_default_cpu());
-            dst_layout[2] = tshp[0];
-            dst_layout[3] = tshp[1];
+            oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
         } else {
-            dst_layout[2] = pool.shape[0];
-            dst_layout[3] = pool.shape[1];
+            oshp2d = pool.shape.data();
+        }
+        if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = src_layout[1];
+            dst_layout[2] = oshp2d[0];
+            dst_layout[3] = oshp2d[1];
+        } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = oshp2d[0];
+            dst_layout[2] = oshp2d[1];
+            dst_layout[3] = src_layout[3];
+        } else {
+            mgb_throw(
+                    MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
         }
         dst_layout.init_contiguous_stride();
     }
-    size_t IH = src_layout[2], IW = src_layout[3], OH = dst_layout[2],
-           OW = dst_layout[3];
+    size_t IH, IW, OH, OW;
+    if (param_format == param::AdaptivePooling::Format::NCHW) {
+        IH = src_layout[2];
+        IW = src_layout[3];
+        OH = dst_layout[2];
+        OW = dst_layout[3];
+    } else if (param_format == param::AdaptivePooling::Format::NHWC) {
+        IH = src_layout[1];
+        IW = src_layout[2];
+        OH = dst_layout[1];
+        OW = dst_layout[2];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+    }
     DnnOprCaller<megdnn::AdaptivePoolingForward> dnn_opr(cn);
     auto&& param = dnn_opr.op->param();
     param.mode = pool.mode;
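The hunks above reduce to one rule: N always passes through from the source, C passes through from wherever the format keeps it (axis 1 for NCHW, axis 3 for NHWC), and the 2-d target supplies (OH, OW) on the remaining axes. A standalone sketch of that mapping, with illustrative names (Format, adaptive_pooling_out_shape) rather than MegEngine types:

#include <array>
#include <cstdint>
#include <stdexcept>

enum class Format { NCHW, NHWC };

// Mirror of the inference above: copy the batch/channel axes from src and
// take the spatial extent from the 2-element target shape oshp2d = {OH, OW}.
std::array<std::int32_t, 4> adaptive_pooling_out_shape(
        const std::array<std::int32_t, 4>& src, const std::int32_t* oshp2d,
        Format format) {
    if (format == Format::NCHW)
        return {src[0], src[1], oshp2d[0], oshp2d[1]};
    if (format == Format::NHWC)
        return {src[0], oshp2d[0], oshp2d[1], src[3]};
    throw std::invalid_argument("AdaptivePooling only supports NCHW or NHWC");
}

int main() {
    const std::array<std::int32_t, 4> src{2, 3, 32, 32};  // NCHW: N=2, C=3, 32x32
    const std::int32_t target[2] = {7, 7};
    auto out = adaptive_pooling_out_shape(src, target, Format::NCHW);
    // with Format::NHWC the same inputs would yield {2, 7, 7, 32}
    return out == std::array<std::int32_t, 4>{2, 3, 7, 7} ? 0 : 1;
}

Keeping N and C untouched and consuming only (OH, OW) from the target is what lets a single 2-d target shape drive both formats in infer_output_attrs_fallible and apply_on_physical_tensor.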
diff --git a/imperative/src/impl/transformations/format.cpp b/imperative/src/impl/transformations/format.cpp
index 77b4773f..454b1422 100644
--- a/imperative/src/impl/transformations/format.cpp
+++ b/imperative/src/impl/transformations/format.cpp
@@ -105,7 +105,7 @@ std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape)
     } else {
         mgb_throw(
                 MegBrainError,
-                "Unsupported shape ndim %u in convert NCHW shape to NHWC.",
+                "Unsupported shape ndim %lu in convert NCHW shape to NHWC.",
                 shape.size());
     }
 }
@@ -184,7 +184,8 @@ ValueRefList reshape_rule(
         // output is still NHWC format
         auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
         auto outputs = imperative::apply(
-                *Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
+                *Reshape::make(op.axis, nhwc_shape),
+                {t.unwrap_input(inputs[0])});
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
@@ -395,12 +396,17 @@ ValueRefList batchnorm_rule(
     return identity_rule_helper(op, inputs, t);
 }
 
-ValueRefList checknonfinite_rule(
-        const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
+ValueRefList adaptive_pooling_rule(
+        const AdaptivePooling& op, Span<ValueRef>& inputs, const bool& auto_convert,
         const FormatTransformation& t) {
-    auto&& inputs_ = t.unwrap_inputs(inputs);
-    auto&& outputs_ = imperative::apply(op, inputs_);
-    return t.wrap_outputs(outputs_);
+    auto&& inp_format = inputs[0].cast(t.value_type()).format();
+    if (inp_format == FT::NHWC) {
+        auto&& new_param = op.param();
+        new_param.format = AdaptivePooling::Format::NHWC;
+        auto new_op = AdaptivePooling::make(new_param, op.shape);
+        return identity_rule_helper(*new_op, inputs, t);
+    }
+    return identity_rule_helper(op, inputs, t);
 }
 
 // clang-format off
@@ -417,7 +423,6 @@ ValueRefList checknonfinite_rule(
         cb(Identity)
 
 #define FOREACH_FORMAT_OP(cb) \
-    cb(AdaptivePooling) \
     cb(WarpAffine) \
     cb(Resize)
 
@@ -494,7 +499,7 @@ struct FormatRuleRegistry {
         register_format_rule(setsubtensor_rule);
         register_format_rule(concat_rule);
        register_format_rule(batchnorm_rule);
-        register_format_rule(checknonfinite_rule);
+        register_format_rule(adaptive_pooling_rule);
         FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
@@ -506,7 +511,7 @@ struct FormatRuleRegistry {
 
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
-    //mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
+    // mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
     if (auto* apply_op = op.as<ApplyOp>()) {
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
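The registry change in the last hunks moves AdaptivePooling out of the generic FOREACH_FORMAT_OP handling and into a dedicated rule looked up by the op's runtime type, so the rule can rewrite param.format to NHWC before the op is applied. A minimal sketch of that type-keyed dispatch pattern, using stand-in types (Op, Rule, format_rules here are hypothetical, not the FormatTransformation API):

#include <functional>
#include <iostream>
#include <typeindex>
#include <unordered_map>

struct Op { virtual ~Op() = default; };
struct AdaptivePooling : Op {};

using Rule = std::function<void(const Op&)>;
std::unordered_map<std::type_index, Rule> format_rules;

// Register a callback keyed by the concrete op type, analogous to what
// register_format_rule does upstream with dyn_typeinfo().
template <typename T>
void register_format_rule(void (*rule)(const T&)) {
    format_rules[typeid(T)] = [rule](const Op& op) {
        rule(static_cast<const T&>(op));
    };
}

void adaptive_pooling_rule(const AdaptivePooling&) {
    std::cout << "rewrite param.format to match the input's format\n";
}

int main() {
    register_format_rule(adaptive_pooling_rule);
    const AdaptivePooling op;
    auto iter = format_rules.find(typeid(op));  // cf. find(dyn_typeinfo())
    if (iter != format_rules.end())
        iter->second(op);  // ops without a rule fall back to a default path
}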