From 599eceac57458428b3fdcecfcfdede52b808c1d2 Mon Sep 17 00:00:00 2001 From: yangyongqiang Date: Fri, 23 Sep 2022 18:14:07 +0000 Subject: [PATCH] !675 bugfix for handle onnx model outputs Merge pull request !675 from yangyongqiang/bugfix_0921 --- metadef | 2 +- parser/common/acl_graph_parser_util.cc | 11 ++---- parser/onnx/onnx_parser.cc | 13 ++++--- .../origin_models/onnx_if_const_intput.onnx | 24 ++++++++++++ .../origin_models/onnx_if_const_intput_gen.py | 45 ++++++++++++++++++++++ tests/st/testcase/test_onnx_parser.cc | 10 +++++ 6 files changed, 90 insertions(+), 15 deletions(-) create mode 100644 tests/st/testcase/origin_models/onnx_if_const_intput.onnx create mode 100644 tests/st/testcase/origin_models/onnx_if_const_intput_gen.py diff --git a/metadef b/metadef index 599fbd9..f1af97e 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 599fbd9d7f9509b7673af90e186817b5a75ad547 +Subproject commit f1af97e1c9ce9164901d4e719d3acaa1b8597d14 diff --git a/parser/common/acl_graph_parser_util.cc b/parser/common/acl_graph_parser_util.cc index 2af798c..c10ed57 100644 --- a/parser/common/acl_graph_parser_util.cc +++ b/parser/common/acl_graph_parser_util.cc @@ -514,15 +514,10 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr if (!default_out_nodes.empty()) { for (size_t i = 0; i < default_out_nodes.size(); ++i) { ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); - if (out_node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, - {"out_nodes", default_out_nodes[i].first}); - GELOGE(domi::FAILED, "[Check][Param] Can not find out_nodes(%zu) (%s) in graph.", - i, default_out_nodes[i].first.c_str()); - return domi::FAILED; + if (out_node != nullptr) { + output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); + GELOGD("Get default output node:%s.", out_node->GetName().c_str()); } - output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); - GELOGD("Get default output node:%s.", out_node->GetName().c_str()); } return domi::SUCCESS; } diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index a3ba4ca..0872b39 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -672,12 +672,13 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: } Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector &input_ops) { - // subgraph might not have input, or isolated const nodes exist in the graph, - // we use constant nodes as the start nodes of graph - for (int i = 0; i < onnx_graph.node_size(); i++) { - ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); - if (node->op_type() == kOpTypeConstant) { - input_node_names_.emplace_back(node->name()); + if (input_node_names_.empty()) { + // subgraph might not have input, we use constant nodes as the start nodes of the graph, + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + if (node->op_type() == kOpTypeConstant) { + input_node_names_.emplace_back(node->name()); + } } } for (auto in_name : input_node_names_) { diff --git a/tests/st/testcase/origin_models/onnx_if_const_intput.onnx b/tests/st/testcase/origin_models/onnx_if_const_intput.onnx new file mode 100644 index 0000000..5c1db0e --- /dev/null +++ b/tests/st/testcase/origin_models/onnx_if_const_intput.onnx @@ -0,0 +1,24 @@ +:ß +¡ +X"If*K + else_branch29 +else_out"Constant else_bodyb +else_out + + + *K + then_branch29 +then_out"Constant then_bodyb +then_out + + +  + Y"Constantif_modelZ +X + + +b +Y + + +B \ No newline at end of file diff --git a/tests/st/testcase/origin_models/onnx_if_const_intput_gen.py b/tests/st/testcase/origin_models/onnx_if_const_intput_gen.py new file mode 100644 index 0000000..1d2eaa0 --- /dev/null +++ b/tests/st/testcase/origin_models/onnx_if_const_intput_gen.py @@ -0,0 +1,45 @@ +import os +import numpy as np +import onnx + +def gen_onnx(): + X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [5]) + Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [5]) + then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) + else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) + + const_out_node = onnx.helper.make_node("Constant", inputs=[], outputs=["Y"]) + + then_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["then_out"]) + else_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["else_out"]) + + then_body = onnx.helper.make_graph( + [then_const_node], + "then_body", + [], + [then_out] + ) + + else_body = onnx.helper.make_graph( + [else_const_node], + "else_body", + [], + [else_out] + ) + + if_node = onnx.helper.make_node("If", inputs=["X"], outputs=[], then_branch=then_body, else_branch=else_body) + + graph_def = onnx.helper.make_graph( + [if_node, const_out_node], + "if_model", + [X], + [Y] + ) + + model_def = onnx.helper.make_model(graph_def) + model_def.opset_import[0].version=11 + onnx.save(model_def, "onnx_if_const_intput.onnx") + print(model_def) + +if __name__ == "__main__": + gen_onnx() \ No newline at end of file diff --git a/tests/st/testcase/test_onnx_parser.cc b/tests/st/testcase/test_onnx_parser.cc index 459abed..799f2d6 100644 --- a/tests/st/testcase/test_onnx_parser.cc +++ b/tests/st/testcase/test_onnx_parser.cc @@ -174,4 +174,14 @@ TEST_F(STestOnnxParser, onnx_parser_const_data_type) { EXPECT_EQ(ret, GRAPH_SUCCESS); } +TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/origin_models/onnx_if_const_intput.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, GRAPH_SUCCESS); +} + } // namespace ge