From c20d4cc6dc5865993d0318fe7562b1c7b6c8d79f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Aug 2020 10:57:51 +0800 Subject: [PATCH] feat(dnn): fix opt pass nchw44 can not dump resnet GitOrigin-RevId: 28e5c37f53349d482b191751923b5a4b05b0633d --- src/gopt/impl/tensor_reformat.cpp | 14 +++++++++++++- src/gopt/test/inference.cpp | 15 +++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index bae9498f..c6ad8bd1 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -1815,6 +1815,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, return new_var; } +static inline TensorShape nchwxx_shape_2_nchw_shape( + const TensorShape& origin_shape) { + mgb_assert(origin_shape.ndim == 5); + TensorShape result = origin_shape; + result[1] *= result[4]; + result.ndim = 4; + return result; +} + template static inline bool nchw_nchwxx_valid( const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size, @@ -1847,7 +1856,10 @@ static inline bool nchw_nchwxx_valid( megdnn::ConvBiasForward::BiasMode bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS; if (std::is_same::value) { - auto& bias_shape = new_inp[2]->shape(); + TensorShape bias_shape = new_inp[2]->shape(); + if (bias_shape.ndim == 5) { + bias_shape = nchwxx_shape_2_nchw_shape(bias_shape); + } if (bias_shape.ndim == 0) { bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS; } else if (bias_shape.eq_shape(dst_node->shape())) { diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp index f8ad6450..c6a7d795 100644 --- a/src/gopt/test/inference.cpp +++ b/src/gopt/test/inference.cpp @@ -3069,12 +3069,18 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { //! Dense param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; - auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), + auto w4 = mkcvar("w4", {16, 32, 3, 3}), b4 = mkcvar("b4", {1, 16, 1, 1}), conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {}, OperatorNodeConfig("conv4")); - - auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), - conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {}, + auto w4_1 = mkcvar("w4_1", {16, 32, 1, 1}), + b4_1 = mkcvar("b4_1", {2, 16, 4, 4}), + conv4_1 = + opr::ConvBias::make(conv3_3, w4_1, b4_1, param_conv_bias_pad0, + {}, OperatorNodeConfig("conv4_1")); + auto conv4_add = conv4 + conv4_1; + + auto w5 = mkcvar("w5", {6, 16, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), + conv5 = opr::ConvBias::make(conv4_add, w5, b5, param_conv_bias, {}, OperatorNodeConfig("conv5")); auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {}, @@ -3082,6 +3088,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { SymbolVar y_opt; auto options = gopt::OptimizeForInferenceOptions{}; + options.enable_fuse_conv_bias_nonlinearity(); options.enable_nchw44(); unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);