Browse Source

!676 sync parser to master 20220923

Merge pull request !676 from zhangfan/ge_dev
pull/684/MERGE
zhangfan Gitee 2 years ago
parent
commit
1eca41e8bf
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 90 additions and 15 deletions
  1. +1
    -1
      metadef
  2. +3
    -8
      parser/common/acl_graph_parser_util.cc
  3. +7
    -6
      parser/onnx/onnx_parser.cc
  4. +24
    -0
      tests/st/testcase/origin_models/onnx_if_const_intput.onnx
  5. +45
    -0
      tests/st/testcase/origin_models/onnx_if_const_intput_gen.py
  6. +10
    -0
      tests/st/testcase/test_onnx_parser.cc

+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit 599fbd9d7f9509b7673af90e186817b5a75ad547
Subproject commit f1af97e1c9ce9164901d4e719d3acaa1b8597d14

+ 3
- 8
parser/common/acl_graph_parser_util.cc View File

@@ -514,15 +514,10 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr
if (!default_out_nodes.empty()) {
for (size_t i = 0; i < default_out_nodes.size(); ++i) {
ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first);
if (out_node == nullptr) {
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"},
{"out_nodes", default_out_nodes[i].first});
GELOGE(domi::FAILED, "[Check][Param] Can not find out_nodes(%zu) (%s) in graph.",
i, default_out_nodes[i].first.c_str());
return domi::FAILED;
if (out_node != nullptr) {
output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second));
GELOGD("Get default output node:%s.", out_node->GetName().c_str());
}
output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second));
GELOGD("Get default output node:%s.", out_node->GetName().c_str());
}
return domi::SUCCESS;
}


+ 7
- 6
parser/onnx/onnx_parser.cc View File

@@ -672,12 +672,13 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
}

Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) {
// subgraph might not have input, or isolated const nodes exist in the graph,
// we use constant nodes as the start nodes of graph
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
if (node->op_type() == kOpTypeConstant) {
input_node_names_.emplace_back(node->name());
if (input_node_names_.empty()) {
// subgraph might not have input, we use constant nodes as the start nodes of the graph,
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
if (node->op_type() == kOpTypeConstant) {
input_node_names_.emplace_back(node->name());
}
}
}
for (auto in_name : input_node_names_) {


+ 24
- 0
tests/st/testcase/origin_models/onnx_if_const_intput.onnx View File

@@ -0,0 +1,24 @@
:ß
¡
X"If*K
else_branch29
else_out"Constant else_bodyb
else_out


 *K
then_branch29
then_out"Constant then_bodyb
then_out


 
Y"Constantif_modelZ
X


b
Y


B

+ 45
- 0
tests/st/testcase/origin_models/onnx_if_const_intput_gen.py View File

@@ -0,0 +1,45 @@
import os
import numpy as np
import onnx

def gen_onnx():
X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [5])
Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [5])
then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5])
else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5])

const_out_node = onnx.helper.make_node("Constant", inputs=[], outputs=["Y"])
then_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["then_out"])
else_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["else_out"])

then_body = onnx.helper.make_graph(
[then_const_node],
"then_body",
[],
[then_out]
)

else_body = onnx.helper.make_graph(
[else_const_node],
"else_body",
[],
[else_out]
)

if_node = onnx.helper.make_node("If", inputs=["X"], outputs=[], then_branch=then_body, else_branch=else_body)

graph_def = onnx.helper.make_graph(
[if_node, const_out_node],
"if_model",
[X],
[Y]
)

model_def = onnx.helper.make_model(graph_def)
model_def.opset_import[0].version=11
onnx.save(model_def, "onnx_if_const_intput.onnx")
print(model_def)

if __name__ == "__main__":
gen_onnx()

+ 10
- 0
tests/st/testcase/test_onnx_parser.cc View File

@@ -174,4 +174,14 @@ TEST_F(STestOnnxParser, onnx_parser_const_data_type) {
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_if_const_intput.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}

} // namespace ge

Loading…
Cancel
Save