Browse Source

!478 嵌套if 转模型失败问题修复

Merge pull request !478 from 苏俊伟/ge_dev
pull/480/head
i-robot Gitee 3 years ago
parent
commit
7ea0e2a80e
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 31 additions and 9 deletions
  1. +11
    -2
      parser/onnx/onnx_parser.cc
  2. +4
    -0
      parser/onnx/onnx_util.cc
  3. +1
    -0
      parser/onnx/onnx_util.h
  4. +8
    -4
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  5. +4
    -2
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  6. +3
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter.h

+ 11
- 2
parser/onnx/onnx_parser.cc View File

@@ -240,7 +240,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra
if (node->GetOpDesc() == nullptr) { if (node->GetOpDesc() == nullptr) {
continue; continue;
} }
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName());
node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName()));
} }


auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph);
@@ -750,6 +750,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
while (!onnx_graph_tasks.empty()) { while (!onnx_graph_tasks.empty()) {
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front();
onnx_graph_tasks.pop(); onnx_graph_tasks.pop();
std::string graph_name;
for (const auto &graph_iter : name_to_onnx_graph) {
if (graph_iter.second == onnx_graph) {
graph_name = graph_iter.first;
break;
}
}

for (int i = 0; i < onnx_graph->node_size(); i++) { for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
if (node_proto->name().empty()) { if (node_proto->name().empty()) {
@@ -767,7 +775,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
} }
std::vector<ge::onnx::GraphProto *> onnx_graphs; std::vector<ge::onnx::GraphProto *> onnx_graphs;
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph;
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) {
if (subgraph_adapter->AdaptAndFindAllSubgraphs(
node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) {
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str());
return FAILED; return FAILED;


+ 4
- 0
parser/onnx/onnx_util.cc View File

@@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig
const std::string &parent_node_name, std::string &unique_subgraph_name) { const std::string &parent_node_name, std::string &unique_subgraph_name) {
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name;
} }

std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) {
return graph_name + "/" + node_name;
}
} // namespace ge } // namespace ge

+ 1
- 0
parser/onnx/onnx_util.h View File

@@ -54,6 +54,7 @@ class OnnxUtil {
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type);
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name); const std::string &parent_node_name, std::string &unique_subgraph_name);
static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name);
}; };
} // namespace ge } // namespace ge




+ 8
- 4
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -27,12 +27,12 @@ const int kIfNodeAttrSize = 2;
} // namespace } // namespace
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
GE_CHECK_NOTNULL(parent_node); GE_CHECK_NOTNULL(parent_node);
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str()); parent_node->op_type().c_str());


auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph);
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed."); GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
@@ -44,7 +44,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(


domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
if (parent_node->attribute_size() != kIfNodeAttrSize) { if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
@@ -67,7 +67,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
return FAILED; return FAILED;
} }
std::string unique_subgraph_name; std::string unique_subgraph_name;
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name);
std::string node_name = parent_node->name();
if (!parent_graph_name.empty()) {
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name);
}
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name);
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str());
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); ge::onnx::GraphProto *onnx_graph = attribute->mutable_g();
name_to_onnx_graph[unique_subgraph_name] = onnx_graph; name_to_onnx_graph[unique_subgraph_name] = onnx_graph;


+ 4
- 2
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -26,11 +26,13 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
public: public:
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node,
std::vector<ge::onnx::GraphProto *> &onnx_graphs, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override;
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") override;


private: private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name);
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const; domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const; void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const;
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const; void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const;


+ 3
- 1
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter {
/// @return FAILED Parse failed /// @return FAILED Parse failed
virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op,
std::vector<ge::onnx::GraphProto *> &onnx_graphs, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") {
(void)parent_op; (void)parent_op;
(void)onnx_graphs; (void)onnx_graphs;
(void)name_to_onnx_graph; (void)name_to_onnx_graph;
(void)parent_graph_name;
return domi::SUCCESS; return domi::SUCCESS;
} }
}; };


Loading…
Cancel
Save