Merge pull request !675 from yangyongqiang/bugfix_0921pull/676/head
@@ -1 +1 @@ | |||||
Subproject commit 599fbd9d7f9509b7673af90e186817b5a75ad547 | |||||
Subproject commit f1af97e1c9ce9164901d4e719d3acaa1b8597d14 |
@@ -514,15 +514,10 @@ domi::Status AclGrphParseUtil::GetDefaultOutInfo(ge::ComputeGraphPtr &compute_gr | |||||
if (!default_out_nodes.empty()) { | if (!default_out_nodes.empty()) { | ||||
for (size_t i = 0; i < default_out_nodes.size(); ++i) { | for (size_t i = 0; i < default_out_nodes.size(); ++i) { | ||||
ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); | ge::NodePtr out_node = compute_graph->FindNode(default_out_nodes[i].first); | ||||
if (out_node == nullptr) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
{"out_nodes", default_out_nodes[i].first}); | |||||
GELOGE(domi::FAILED, "[Check][Param] Can not find out_nodes(%zu) (%s) in graph.", | |||||
i, default_out_nodes[i].first.c_str()); | |||||
return domi::FAILED; | |||||
if (out_node != nullptr) { | |||||
output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); | |||||
GELOGD("Get default output node:%s.", out_node->GetName().c_str()); | |||||
} | } | ||||
output_nodes_info.push_back(std::make_pair(out_node, default_out_nodes[i].second)); | |||||
GELOGD("Get default output node:%s.", out_node->GetName().c_str()); | |||||
} | } | ||||
return domi::SUCCESS; | return domi::SUCCESS; | ||||
} | } | ||||
@@ -672,12 +672,13 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
} | } | ||||
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | ||||
// subgraph might not have input, or isolated const nodes exist in the graph, | |||||
// we use constant nodes as the start nodes of graph | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
if (node->op_type() == kOpTypeConstant) { | |||||
input_node_names_.emplace_back(node->name()); | |||||
if (input_node_names_.empty()) { | |||||
// subgraph might not have input, we use constant nodes as the start nodes of the graph, | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
if (node->op_type() == kOpTypeConstant) { | |||||
input_node_names_.emplace_back(node->name()); | |||||
} | |||||
} | } | ||||
} | } | ||||
for (auto in_name : input_node_names_) { | for (auto in_name : input_node_names_) { | ||||
@@ -0,0 +1,24 @@ | |||||
:ß | |||||
¡ | |||||
X"If*K | |||||
else_branch29 | |||||
else_out"Constant else_bodyb | |||||
else_out | |||||
*K | |||||
then_branch29 | |||||
then_out"Constant then_bodyb | |||||
then_out | |||||
Y"Constantif_modelZ | |||||
X | |||||
b | |||||
Y | |||||
B |
@@ -0,0 +1,45 @@ | |||||
import os | |||||
import numpy as np | |||||
import onnx | |||||
def gen_onnx(): | |||||
X = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [5]) | |||||
Y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [5]) | |||||
then_out = onnx.helper.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, [5]) | |||||
else_out = onnx.helper.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, [5]) | |||||
const_out_node = onnx.helper.make_node("Constant", inputs=[], outputs=["Y"]) | |||||
then_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["then_out"]) | |||||
else_const_node = onnx.helper.make_node("Constant", inputs=[], outputs=["else_out"]) | |||||
then_body = onnx.helper.make_graph( | |||||
[then_const_node], | |||||
"then_body", | |||||
[], | |||||
[then_out] | |||||
) | |||||
else_body = onnx.helper.make_graph( | |||||
[else_const_node], | |||||
"else_body", | |||||
[], | |||||
[else_out] | |||||
) | |||||
if_node = onnx.helper.make_node("If", inputs=["X"], outputs=[], then_branch=then_body, else_branch=else_body) | |||||
graph_def = onnx.helper.make_graph( | |||||
[if_node, const_out_node], | |||||
"if_model", | |||||
[X], | |||||
[Y] | |||||
) | |||||
model_def = onnx.helper.make_model(graph_def) | |||||
model_def.opset_import[0].version=11 | |||||
onnx.save(model_def, "onnx_if_const_intput.onnx") | |||||
print(model_def) | |||||
if __name__ == "__main__": | |||||
gen_onnx() |
@@ -174,4 +174,14 @@ TEST_F(STestOnnxParser, onnx_parser_const_data_type) { | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | EXPECT_EQ(ret, GRAPH_SUCCESS); | ||||
} | } | ||||
TEST_F(STestOnnxParser, onnx_parser_if_node_with_const_input) { | |||||
std::string case_dir = __FILE__; | |||||
case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
std::string model_file = case_dir + "/origin_models/onnx_if_const_intput.onnx"; | |||||
std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
ge::Graph graph; | |||||
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |