From 2bb272f3e9527799e6e5d8f45ef0a17722e48029 Mon Sep 17 00:00:00 2001 From: su-junwei3 Date: Mon, 7 Mar 2022 11:04:47 +0800 Subject: [PATCH] bug fix for if node --- parser/onnx/onnx_parser.cc | 13 +++++++++++-- parser/onnx/onnx_util.cc | 4 ++++ parser/onnx/onnx_util.h | 1 + parser/onnx/subgraph_adapter/if_subgraph_adapter.cc | 12 ++++++++---- parser/onnx/subgraph_adapter/if_subgraph_adapter.h | 6 ++++-- parser/onnx/subgraph_adapter/subgraph_adapter.h | 4 +++- 6 files changed, 31 insertions(+), 9 deletions(-) diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 4e3fb04..764f8a4 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -240,7 +240,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra if (node->GetOpDesc() == nullptr) { continue; } - node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); + node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName())); } auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); @@ -750,6 +750,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( while (!onnx_graph_tasks.empty()) { ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); onnx_graph_tasks.pop(); + std::string graph_name; + for (const auto &graph_iter : name_to_onnx_graph) { + if (graph_iter.second == onnx_graph) { + graph_name = graph_iter.first; + break; + } + } + for (int i = 0; i < onnx_graph->node_size(); i++) { ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); if (node_proto->name().empty()) { @@ -767,7 +775,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph( } std::vector onnx_graphs; std::map name_to_onnx_subgraph; - if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { + if (subgraph_adapter->AdaptAndFindAllSubgraphs( + node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) { GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); return FAILED; diff --git a/parser/onnx/onnx_util.cc b/parser/onnx/onnx_util.cc index d9f036a..040b134 100644 --- a/parser/onnx/onnx_util.cc +++ b/parser/onnx/onnx_util.cc @@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig const std::string &parent_node_name, std::string &unique_subgraph_name) { unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; } + +std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) { + return graph_name + "/" + node_name; +} } // namespace ge diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h index 6fab6fc..1cd8814 100644 --- a/parser/onnx/onnx_util.h +++ b/parser/onnx/onnx_util.h @@ -54,6 +54,7 @@ class OnnxUtil { static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, const std::string &parent_node_name, std::string &unique_subgraph_name); + static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name); }; } // namespace ge diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc index cf26a98..37df217 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc @@ -27,12 +27,12 @@ const int kIfNodeAttrSize = 2; } // namespace domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, - std::map &name_to_onnx_graph) { + std::map &name_to_onnx_graph, const std::string &parent_graph_name) { GE_CHECK_NOTNULL(parent_node); GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), parent_node->op_type().c_str()); - auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph); + auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name); if (ret != SUCCESS) { GELOGE(ret, "[Parse][Node] Parse if node failed."); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); @@ -44,7 +44,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, - std::map &name_to_onnx_graph) { + std::map &name_to_onnx_graph, const std::string &parent_graph_name) { if (parent_node->attribute_size() != kIfNodeAttrSize) { GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); @@ -67,7 +67,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( return FAILED; } std::string unique_subgraph_name; - OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name); + std::string node_name = parent_node->name(); + if (!parent_graph_name.empty()) { + node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name); + } + OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name); GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); name_to_onnx_graph[unique_subgraph_name] = onnx_graph; diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h index 9b6f1e5..ff2f6e6 100644 --- a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h @@ -26,11 +26,13 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { public: domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, - std::map &name_to_onnx_graph) override; + std::map &name_to_onnx_graph, + const std::string &parent_graph_name = "") override; private: domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, - std::map &name_to_onnx_graph); + std::map &name_to_onnx_graph, + const std::string &parent_graph_name); domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set &all_inputs) const; void AddInputNodeForGraph(const std::set &all_inputs, ge::onnx::GraphProto &onnx_graph) const; void AddInputForParentNode(const std::set &all_inputs, ge::onnx::NodeProto &parent_node) const; diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter.h b/parser/onnx/subgraph_adapter/subgraph_adapter.h index 502e40d..ad9eb1a 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter.h @@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter { /// @return FAILED Parse failed virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, std::vector &onnx_graphs, - std::map &name_to_onnx_graph) { + std::map &name_to_onnx_graph, + const std::string &parent_graph_name = "") { (void)parent_op; (void)onnx_graphs; (void)name_to_onnx_graph; + (void)parent_graph_name; return domi::SUCCESS; } };