Browse Source

!478 嵌套if 转模型失败问题修复

Merge pull request !478 from 苏俊伟/ge_dev
pull/480/head
i-robot Gitee 3 years ago
parent
commit
7ea0e2a80e
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 31 additions and 9 deletions
  1. +11
    -2
      parser/onnx/onnx_parser.cc
  2. +4
    -0
      parser/onnx/onnx_util.cc
  3. +1
    -0
      parser/onnx/onnx_util.h
  4. +8
    -4
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  5. +4
    -2
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  6. +3
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter.h

+ 11
- 2
parser/onnx/onnx_parser.cc View File

@@ -240,7 +240,7 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra
if (node->GetOpDesc() == nullptr) {
continue;
}
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName());
node->GetOpDesc()->SetName(OnnxUtil::GenUniqueNodeName(sub_graph->GetName(), node->GetName()));
}

auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph);
@@ -750,6 +750,14 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
while (!onnx_graph_tasks.empty()) {
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front();
onnx_graph_tasks.pop();
std::string graph_name;
for (const auto &graph_iter : name_to_onnx_graph) {
if (graph_iter.second == onnx_graph) {
graph_name = graph_iter.first;
break;
}
}

for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
if (node_proto->name().empty()) {
@@ -767,7 +775,8 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(
}
std::vector<ge::onnx::GraphProto *> onnx_graphs;
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph;
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) {
if (subgraph_adapter->AdaptAndFindAllSubgraphs(
node_proto, onnx_graphs, name_to_onnx_subgraph, graph_name) != SUCCESS) {
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str());
return FAILED;


+ 4
- 0
parser/onnx/onnx_util.cc View File

@@ -45,4 +45,8 @@ void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &orig
const std::string &parent_node_name, std::string &unique_subgraph_name) {
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name;
}

std::string OnnxUtil::GenUniqueNodeName(const std::string &graph_name, const std::string &node_name) {
return graph_name + "/" + node_name;
}
} // namespace ge

+ 1
- 0
parser/onnx/onnx_util.h View File

@@ -54,6 +54,7 @@ class OnnxUtil {
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type);
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name);
static std::string GenUniqueNodeName(const std::string &graph_name, const std::string &node_name);
};
} // namespace ge



+ 8
- 4
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -27,12 +27,12 @@ const int kIfNodeAttrSize = 2;
} // namespace
domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
GE_CHECK_NOTNULL(parent_node);
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str());

auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph);
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph, parent_graph_name);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
@@ -44,7 +44,7 @@ domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(

domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph, const std::string &parent_graph_name) {
if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
@@ -67,7 +67,11 @@ domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs(
return FAILED;
}
std::string unique_subgraph_name;
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name);
std::string node_name = parent_node->name();
if (!parent_graph_name.empty()) {
node_name = OnnxUtil::GenUniqueNodeName(parent_graph_name, node_name);
}
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, node_name, unique_subgraph_name);
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str());
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g();
name_to_onnx_graph[unique_subgraph_name] = onnx_graph;


+ 4
- 2
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -26,11 +26,13 @@ class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
public:
domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node,
std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override;
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") override;

private:
domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name);
domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs) const;
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph) const;
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node) const;


+ 3
- 1
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -49,10 +49,12 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter {
/// @return FAILED Parse failed
virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op,
std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph,
const std::string &parent_graph_name = "") {
(void)parent_op;
(void)onnx_graphs;
(void)name_to_onnx_graph;
(void)parent_graph_name;
return domi::SUCCESS;
}
};


Loading…
Cancel
Save