From eab7ab0530856d5296074862b5cb2e6c953659bb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 20 Aug 2020 10:59:12 +0800
Subject: [PATCH] fix(gopt): gen nchw_nchw44 when kernel is optimized

GitOrigin-RevId: 89083be200041dcbcbe12d64f9c67667bcf101b2
---
 src/gopt/impl/tensor_reformat.cpp | 49 ++++++++++++++++++++++++++++-----------
 src/gopt/test/inference.cpp       | 47 ++++++++++++++++++++++++++++++-------
 2 files changed, 74 insertions(+), 22 deletions(-)

diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index c9cfbcdd..b88fabc8 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -1811,7 +1811,15 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     }
     return new_var;
 }
-
+//! true iff the optimized nchw_nchw44 conv kernels cover this shape
+static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic,
+                                     const size_t pack_c_size, const size_t fh,
+                                     const size_t fw, const size_t stride_h,
+                                     const size_t stride_w) {
+    return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
+           stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
+           (fh == 2 || fh == 3 || fh == 5 || fh == 7);
+}
 void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
@@ -1848,15 +1856,19 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
              weight_to_nchwxx_mode_group,
              weight_to_nchwxx_mode_chan, hybrid_nchw_nchwxx](
                     const megdnn::param::Convolution::Sparse conv_mode,
-                    const VarNode* filter) -> TestFilterResult {
+                    const VarNode* filter, const size_t stride_h,
+                    const size_t stride_w) -> TestFilterResult {
         TestFilterResult ret{TransType::TRANS_NONE, {}};
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
-            size_t IC = filter->shape()[1];
             size_t OC = filter->shape()[0];
+            size_t IC = filter->shape()[1];
+            size_t FH = filter->shape()[2];
+            size_t FW = filter->shape()[3];
             if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                 ret.first = TransType::TRANS_PURE_NCHWXX;
                 ret.second = weight_to_nchwxx_mode_dense;
-            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
+            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
+                                         stride_w)) {
                 ret.first = TransType::TRANS_HYBIRD_NCHWXX;
                 ret.second = hybrid_nchw_nchwxx;
             }
@@ -1883,7 +1895,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_opr.param().format ==
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1]);
+        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1],
+                                          conv_opr.param().stride_h,
+                                          conv_opr.param().stride_w);
         //! can not trans to nchwxx
         if (is_trans.first == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -1957,8 +1971,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        auto is_trans =
-                test_trans_nchwxx(conv_bias_opr.param().sparse, new_inp[1]);
+        auto is_trans = test_trans_nchwxx(
+                conv_bias_opr.param().sparse, new_inp[1],
+                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
         //! can not trans to nchwxx
         if (is_trans.first == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2199,17 +2214,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     constexpr size_t pack_c_size = 4_z;
     auto test_trans_nchw44_dot =
             [](const megdnn::param::Convolution::Sparse conv_mode,
-               const VarNode* filter) -> TestTransResult {
+               const VarNode* filter, const size_t stride_h,
+               const size_t stride_w) -> TestTransResult {
         TestTransResult ret{TransType::TRANS_NONE, {}, {}};
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
-            size_t IC = filter->shape()[1];
             size_t OC = filter->shape()[0];
+            size_t IC = filter->shape()[1];
+            size_t FH = filter->shape()[2];
+            size_t FW = filter->shape()[3];
             if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                 ret.trans_type = TransType::TRANS_PURE_NCHWXX;
                 ret.relayout_mod =
                         RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
                 ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
-            } else if (IC < pack_c_size && OC % pack_c_size == 0) {
+            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
+                                         stride_w)) {
                 ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
                 ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
                 ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
@@ -2241,8 +2260,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to "
                    "NCHW44_DOT");
-        auto is_trans =
-                test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
+        auto is_trans = test_trans_nchw44_dot(
+                conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
+                conv_opr.param().stride_w);
         //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
            mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2315,8 +2335,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        auto is_trans =
-                test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
+        auto is_trans = test_trans_nchw44_dot(
+                conv_bias_opr.param().sparse, new_inp[1],
+                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
         //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
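With this change, the dense branch of both converters (test_trans_nchwxx
above and test_trans_nchw44_dot below) reduces to the same three-way
decision. The following standalone sketch restates that logic outside the
pass; it is an illustration only, with a hypothetical name (classify_dense)
and pack_c_size fixed to 4, not code from this patch:

#include <stddef.h>

enum class TransType { TRANS_PURE_NCHWXX, TRANS_HYBIRD_NCHWXX, TRANS_NONE };

// Same condition as nchw_nchwxx_valid in tensor_reformat.cpp above.
static bool nchw_nchwxx_valid(size_t oc, size_t ic, size_t pack_c_size,
                              size_t fh, size_t fw, size_t stride_h,
                              size_t stride_w) {
    return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
           stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
           (fh == 2 || fh == 3 || fh == 5 || fh == 7);
}

// Classify a dense conv with filter shape {OC, IC, FH, FW}.
static TransType classify_dense(size_t oc, size_t ic, size_t fh, size_t fw,
                                size_t stride_h, size_t stride_w) {
    constexpr size_t pack_c_size = 4;
    if (ic % pack_c_size == 0 && oc % pack_c_size == 0)
        return TransType::TRANS_PURE_NCHWXX;   // both channel ends packable
    if (nchw_nchwxx_valid(oc, ic, pack_c_size, fh, fw, stride_h, stride_w))
        return TransType::TRANS_HYBIRD_NCHWXX; // NCHW input, NCHW44 output
    return TransType::TRANS_NONE;              // leave the conv in plain NCHW
}

Before this patch the hybrid branch only required IC < pack_c_size and
OC % pack_c_size == 0, so dense convs whose filter size or stride has no
optimized nchw_nchw44 kernel (for example 4x4 filters, or stride 4) were
converted anyway; the added filter-size and stride checks keep them in NCHW.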
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 6a8c42a4..a8bbb9ff 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2976,11 +2976,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     auto host_x = gen({2, 3, 16, 16}, cn);

     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    //! Hybrid nchw88 mode
+    //! Hybrid nchw44 mode
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
+    opr::ConvBias::Param param_conv_bias_stride4;
+    param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
          conv1 = opr::Convolution::make(x, w1, param_conv);
+    auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b1", {1, 8, 1, 1}),
+         conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
+    auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
     //! channel wise
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
@@ -3015,22 +3020,35 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
          y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);

-    SymbolVar y_opt;
+    SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt;
     auto options = gopt::OptimizeForInferenceOptions{};
     options.enable_nchw44();
-    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    unpack_vector(gopt::optimize_for_inference(
+                          {y, conv1, conv1_f4, conv1_s4, conv2}, options),
+                  y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt);
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
+              find_opr<opr::Convolution>(conv1_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
+              find_opr<opr::Convolution>(conv1_f4_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(conv2_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
               find_opr<opr::ConvBias>(y_opt).param().format);

-    graph->compile({{y_opt, {}}})
+    graph->compile({{y_opt, {}}, {conv2, {}}})
             ->to_json()
             ->writeto_fpath(
                     output_file("TestGoptInference.ConvertFormatNCHW44.json"));

     HostTensorND host_y_opt, host_y;
+    HostTensorND host_conv1;
     auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_opt, host_y_opt)});
+                                make_callback_copy(y_opt, host_y_opt),
+                                make_callback_copy(conv1, host_conv1)});
+
     func->execute();
     //! maybe go to winograd in x86-32, so set error 1e-1
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
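The three convs probed by this test map directly onto the new predicate. A
minimal standalone check (shapes written as {OC, IC, FH, FW}, pack_c_size =
4; the predicate is copied from the patch so the snippet is self-contained):

#include <assert.h>
#include <stddef.h>

static bool nchw_nchwxx_valid(size_t oc, size_t ic, size_t pack_c_size,
                              size_t fh, size_t fw, size_t stride_h,
                              size_t stride_w) {
    return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
           stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
           (fh == 2 || fh == 3 || fh == 5 || fh == 7);
}

int main() {
    // conv1: w1 {8, 3, 3, 3}, stride 1 -- an optimized kernel exists
    assert(nchw_nchwxx_valid(8, 3, 4, 3, 3, 1, 1));   // converted to NCHW44
    // conv1_f4: w1_1 {8, 3, 4, 4}, stride 1 -- 4x4 filters are not optimized
    assert(!nchw_nchwxx_valid(8, 3, 4, 4, 4, 1, 1));  // stays NCHW
    // conv1_s4: w1 {8, 3, 3, 3}, stride 4 -- only strides 1 and 2 qualify
    assert(!nchw_nchwxx_valid(8, 3, 4, 3, 3, 4, 4));  // stays NCHW
    return 0;
}

This is exactly what the ASSERT_EQ lines above encode: conv1 is converted to
NCHW44 while conv1_f4 and conv1_s4 keep Format::NCHW.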
@@ -3140,11 +3158,16 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     auto host_x = gen({2, 3, 16, 16}, cn);

     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    //! Hybrid nchw88 mode
+    //! Hybrid nchw44 mode
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
+    opr::ConvBias::Param param_conv_bias_stride4;
+    param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
          conv1 = opr::Convolution::make(x, w1, param_conv);
+    auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b1", {1, 8, 1, 1}),
+         conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
+    auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
     //! channel wise
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
@@ -3179,12 +3202,20 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
          y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);

-    SymbolVar y_opt;
+    SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt;
     auto options = gopt::OptimizeForInferenceOptions{};
     options.enable_nchw44_dot();
-    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4},
+                                               options),
+                  y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt);
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
+              find_opr<opr::Convolution>(conv1_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
+              find_opr<opr::Convolution>(conv1_f4_opt).param().format);
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
               find_opr<opr::ConvBias>(y_opt).param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
               find_opr<opr::Convolution>(y_opt).param().format);
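The NCHW44_DOT test exercises the same predicate through
test_trans_nchw44_dot, so the same three probe convs flip between NCHW44_DOT
and NCHW. For reference, enumerating the predicate over small square filters
and strides shows the full set of hybrid-eligible shapes (standalone sketch;
OC = 8 and IC = 3 are arbitrary values that satisfy the channel conditions):

#include <stddef.h>
#include <cstdio>

static bool nchw_nchwxx_valid(size_t oc, size_t ic, size_t pack_c_size,
                              size_t fh, size_t fw, size_t stride_h,
                              size_t stride_w) {
    return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
           stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
           (fh == 2 || fh == 3 || fh == 5 || fh == 7);
}

int main() {
    // Prints the eight accepted combinations: f in {2, 3, 5, 7}, s in {1, 2}.
    for (size_t f = 1; f <= 8; ++f)
        for (size_t s = 1; s <= 4; ++s)
            if (nchw_nchwxx_valid(8, 3, 4, f, f, s, s))
                std::printf("filter %zux%zu, stride %zu: hybrid ok\n", f, f, s);
    return 0;
}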