From 8f44d6ea60c1c4517c27b1cbcb2e9e43092d1794 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Sep 2020 17:44:33 +0800 Subject: [PATCH] fix(mgb): fix optpass fail at transform NCHW to NCHW4 when input dtype is float GitOrigin-RevId: 3c2c68b11cbf2dd28ae61745c26cc276e74c0761 --- src/gopt/impl/tensor_reformat.cpp | 31 ++++++++++++++++++++++++++++++- src/gopt/test/inference.cpp | 2 +- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index c6ad8bd1..eebe13d8 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -1415,6 +1415,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var, return new_var; } +//! FIXME: All float oprs do not support NCHW4. Supports it in the future plz. std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { MIDOUT_B("EnableNCHW4Pass::make") auto ret = std::make_unique(); @@ -1467,6 +1468,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { auto replace_conv_opr = [trans_nchw4, conv_format]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } mgb_assert(opr->input().size() == new_inp.size()); auto& conv_opr = opr->cast_final_safe(); if (conv_opr.param().format != @@ -1503,6 +1508,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { src_to_nchw4_mode]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } mgb_assert(opr->input().size() == new_inp.size()); auto& batch_conv_bias_opr = opr->cast_final_safe(); @@ -1580,6 +1589,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { src_to_nchw4_mode]( OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } mgb_assert(opr->input().size() == new_inp.size()); auto& conv_bias_opr = opr->cast_final_safe(); if (conv_bias_opr.param().format != @@ -1647,6 +1660,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_elemwise_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } mgb_assert(opr->input().size() == new_inp.size()); bool has_inp_changed = false; for (size_t i = 0; i < opr->input().size(); i++) { @@ -1691,6 +1708,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_pooling_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } using Param = opr::PoolingForward::Param; using Format = Param::Format; mgb_assert(opr->input().size() == new_inp.size()); @@ -1716,6 +1737,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_resize_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } using Param = opr::ResizeForward::Param; using Format = Param::Format; mgb_assert(opr->input().size() == new_inp.size()); @@ -1738,6 +1763,10 @@ std::unique_ptr EnableNCHW4Pass::make_nchw4_converter() { }; auto replace_warp_perspective_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { + if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { + return serialization::copy_opr_shallow(*opr, new_inp, + opr->config()); + } using Param = opr::WarpPerspective::Param; using Format = Param::Format; mgb_assert(opr->input().size() == new_inp.size()); @@ -3127,4 +3156,4 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const { MIDOUT_E } -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index 44989430..05892bd8 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -2816,7 +2816,7 @@ TEST(TestGoptInference, ConvertFormatNCHW4) { unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); } - ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4, + ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, find_opr(y_opt).param().format); graph->compile({{y_opt, {}}})