From 8070f40aa1896e0bfa0752e2847fa8974d424334 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 15 Jul 2020 20:12:25 +0800
Subject: [PATCH] fix(mgb/gopt): fix gopt nchwxx convert elemwise and reshape

GitOrigin-RevId: 982dee36e111bf4cc25321cf5ee8ec20d14bfce2
---
 src/gopt/impl/tensor_reformat.cpp | 50 +++++------------------
 src/gopt/test/inference.cpp       | 84 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 39 deletions(-)

diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 970c953f..b6eefc15 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -2049,16 +2049,17 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
             return new_opr;
         }
     };
-
-    auto replace_concat_opr = [=](OperatorNodeBase* opr,
-                                  const VarNodeArray& new_inp) {
+    //! When an input changes and all inputs can be converted to nchwxx, this
+    //! opr will run in nchwxx mode; otherwise it will run in nchw mode, for
+    //! example concat and elemwise oprs
+    auto replace_multi_inp_opr = [=](OperatorNodeBase* opr,
+                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         bool has_inp_changed = false;
         bool can_exec_ncwxx = true;
         for (size_t i = 0; i < opr->input().size(); i++) {
             if (new_inp[i]->shape().ndim == 5) {
                 has_inp_changed = true;
-                break;
             } else if (new_inp[i]->shape().ndim == 4) {
                 if (new_inp[i]->shape()[1] % pack_c_size != 0) {
                     can_exec_ncwxx = false;
@@ -2095,36 +2096,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
         }
     };
 
-    auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
-                                    const VarNodeArray& new_inp) {
-        mgb_assert(opr->input().size() == new_inp.size());
-        bool has_inp_changed = false;
-        for (size_t i = 0; i < opr->input().size(); i++) {
-            if (new_inp[i]->shape().ndim == 5) {
-                has_inp_changed = true;
-                break;
-            }
-        }
-        if (has_inp_changed) {
-            auto temp_inp = new_inp;
-            for (size_t i = 0; i < opr->input().size(); i++) {
-                if (new_inp[i]->shape().ndim == 4) {
-                    auto new_var = RelayoutPlaceholder::make(
-                            new_inp[i], src_to_nchwxx_mode);
-                    temp_inp[i] = new_var.node();
-                } else {
-                    mgb_assert((new_inp[i]->shape().ndim == 5) ||
-                               new_inp[i]->shape().is_scalar());
-                }
-            }
-            return serialization::copy_opr_shallow(*opr, temp_inp,
-                                                   opr->config());
-        } else {
-            return serialization::copy_opr_shallow(*opr, new_inp,
-                                                   opr->config());
-        }
-    };
-
     auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
@@ -2146,11 +2117,11 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
-    replace_func[opr::Concat::typeinfo()] = replace_concat_opr;
-    replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_opr;
-    replace_func[opr::TypeCvt::typeinfo()] = replace_elemwise_opr;
-    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_opr;
-    replace_func[opr::PowC::typeinfo()] = replace_elemwise_opr;
+    replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
+    replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;
+    replace_func[opr::TypeCvt::typeinfo()] = replace_multi_inp_opr;
+    replace_func[opr::ElemwiseMultiType::typeinfo()] = replace_multi_inp_opr;
+    replace_func[opr::PowC::typeinfo()] = replace_multi_inp_opr;
     //! not support yet
     replace_func[opr::ConvolutionBackwardData::typeinfo()] =
             relayout_inp_to_nchw;
@@ -2164,6 +2135,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
     replace_func[opr::WarpPerspectiveForward::typeinfo()] =
             relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
+    replace_func[opr::Reshape::typeinfo()] = relayout_inp_to_nchw;
 }
 
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
diff --git a/src/gopt/test/inference.cpp b/src/gopt/test/inference.cpp
index 10aa68f6..7551e45f 100644
--- a/src/gopt/test/inference.cpp
+++ b/src/gopt/test/inference.cpp
@@ -2948,6 +2948,90 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
 }
 
+TEST(TestGoptInference, ConvertFormatNCHW44MultiInput) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp) {
+        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x1 = gen({1, 8, 16, 16}, cn);
+    auto host_x2 = gen({1, 1, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 8, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+
+    auto b = mkvar("b", {1, 1, 16, 16}),
+         y = opr::Elemwise::make({conv1 + b}, opr::Elemwise::Param::Mode::RELU);
+
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nchw44();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::Convolution>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44MultiInput.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    //! may go to winograd on x86-32, so set the error tolerance to 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+}
+
+TEST(TestGoptInference, ConvertFormatNCHW44Reshape) {
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("cpu0");
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkcvar = [&](const char* name, const TensorShape& shp) {
+        return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                .rename(name);
+    };
+
+    auto host_x1 = gen({1, 8, 16, 16}, cn);
+    auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
+    opr::Convolution::Param param_conv;
+    param_conv.pad_h = param_conv.pad_w = 1;
+    auto w1 = mkcvar("w1", {8, 8, 3, 3}),
+         conv1 = opr::Convolution::make(x, w1, param_conv);
+    auto y = opr::Reshape::make(conv1, {8, 16 * 16});
+
+    SymbolVar y_opt;
+    auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_nchw44();
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::Convolution>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44Reshape.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    //! may go to winograd on x86-32, so set the error tolerance to 1e-1
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
+}
+
 TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     HostTensorGenerator<> gen;
     auto cn = CompNode::load("cpu0");