From 418b11e7c1fdf90ebf941588f1e4616be3669ab3 Mon Sep 17 00:00:00 2001 From: guopeian Date: Fri, 21 Oct 2022 09:26:47 +0000 Subject: [PATCH] =?UTF-8?q?!712=20=E6=8F=90=E4=BE=9B1=E5=AF=B9=E5=A4=9A?= =?UTF-8?q?=E6=9E=84=E9=80=A0=E7=9A=84=E5=9B=BE=E5=8F=AF=E4=BB=A5=E5=B8=A6?= =?UTF-8?q?=E5=AD=90=E5=9B=BE=E7=9A=84=E5=8A=9F=E8=83=BD=20Merge=20pull=20?= =?UTF-8?q?request=20!712=20from=20guopeian/fix=5Fif?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- parser/common/parser_utils.cc | 57 +++++++++++++++++++++++++------------------ parser/onnx/onnx_parser.cc | 22 ++++++++--------- 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/parser/common/parser_utils.cc b/parser/common/parser_utils.cc index 7b4e886..ba4f2ea 100644 --- a/parser/common/parser_utils.cc +++ b/parser/common/parser_utils.cc @@ -32,16 +32,17 @@ namespace { bool HasOneNonDataNode(const ComputeGraphPtr &graph) { GE_CHECK_NOTNULL(graph); int32_t non_data_nums = 0; - for (const auto& n : graph->GetDirectNode()) { - if (n->GetType() != parser::DATA) { + for (const auto& node : graph->GetDirectNode()) { + if (node->GetType() != parser::DATA) { non_data_nums++; } } - GELOGD("graph has non data node num is %d", non_data_nums); + GELOGD("Graph has non data node num is %d", non_data_nums); return (non_data_nums == 1); } Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, + const ComputeGraphPtr &sub_compute_graph, const NodePtr &new_node, bool no_need_change_name) { GE_CHECK_NOTNULL(node); @@ -60,35 +61,43 @@ Status HandleNewOp(const NodePtr &node, new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); } op_desc->SetName(new_name); - bool ret = ge::AttrUtils::SetListStr(op_desc, - ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, - std::move(std::vector{node->GetName()})); - if (!ret) { + std::vector node_name_vec = { node->GetName() }; + if (!ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, + std::move(node_name_vec))) { GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); } - GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); + // handle control op + const auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + for (size_t i = 0UL; i < sub_graph_names.size(); i++) { + auto branch_graph = sub_compute_graph->GetSubgraph(sub_graph_names[i]); + GE_CHECK_NOTNULL(branch_graph); + branch_graph->SetParentNode(new_node); + branch_graph->SetParentGraph(compute_graph); + compute_graph->AddSubGraph(branch_graph); + } + GELOGD("Handle new node[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); return SUCCESS; } } Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &output_mapping) { - GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); - for (const auto &gn : graph.GetDirectNode()) { - NodePtr n = NodeAdapter::GNode2Node(gn); - GE_CHECK_NOTNULL(n); + GELOGD("Begin to run ParserUtils::ExpandOneToManyGraph."); + for (const auto &ge_node : graph.GetDirectNode()) { + NodePtr node = NodeAdapter::GNode2Node(ge_node); + GE_CHECK_NOTNULL(node); std::string ori_type; - (void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); + (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); domi::ParseOpToGraphFunc parse_op_to_graph_func = - domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type); + domi::OpRegistry::Instance()->GetParseOpToGraphFunc(node->GetType(), ori_type); if (parse_op_to_graph_func == nullptr) { GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.", - n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); + node->GetName().c_str(), node->GetType().c_str(), ori_type.c_str()); continue; } GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.", - n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); + node->GetName().c_str(), node->GetType().c_str(), ori_type.c_str()); Graph subgraph("one_to_many_graph"); - Operator op = OpDescUtils::CreateOperatorFromNode(n); + Operator op = OpDescUtils::CreateOperatorFromNode(node); Status ret = parse_op_to_graph_func(op, subgraph); if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", "Get one to many graph failed for op:%s.", GetOperatorName(op).c_str()); @@ -96,14 +105,14 @@ Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &outp GetOperatorName(op).c_str()); return FAILED; } - ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping); + ret = ExpandNodeToSubgraph(subgraph, node, graph, output_mapping); if (ret != SUCCESS) { GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", GetOperatorName(op).c_str()); return FAILED; } } - GELOGD("run ParserUtils::ExpandOneToManyGraph success."); + GELOGD("Run ParserUtils::ExpandOneToManyGraph success."); return SUCCESS; } @@ -117,10 +126,10 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n // add subgraph node to graph. bool no_need_change_name = HasOneNonDataNode(sub_compute_graph); std::vector input_nodes; - for (const auto &n : sub_compute_graph->GetDirectNode()) { - auto new_node = compute_graph->AddNode(n); + for (const auto &sub_node : sub_compute_graph->GetDirectNode()) { + auto new_node = compute_graph->AddNode(sub_node); GE_CHECK_NOTNULL(new_node); - if (HandleNewOp(node, compute_graph, new_node, no_need_change_name) != SUCCESS) { + if (HandleNewOp(node, compute_graph, sub_compute_graph, new_node, no_need_change_name) != SUCCESS) { GELOGE(FAILED, "[Handle][NewOp][%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); return FAILED; } @@ -240,7 +249,7 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, const std::vector> &out_node_index, OutputMapping &output_mapping) { GE_CHECK_NOTNULL(node); - GELOGD("The size of out node is %zu", out_node_index.size()); + GELOGD("The size of output node is %zu", out_node_index.size()); for (size_t index = 0; index < out_node_index.size(); index++) { auto node_out_anchor = node->GetOutDataAnchor(index); if (node_out_anchor == nullptr) { @@ -249,7 +258,7 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, NodePtr out_node = out_node_index[index].first; int32_t out_index = out_node_index[index].second; - GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); + GELOGD("Begin to handle output node: %s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); std::string key = GenOutputKey({node->GetName(), index}); output_mapping[key] = std::make_pair(out_node->GetName(), out_index); auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index fecef89..f3769e8 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -121,14 +121,14 @@ graphStatus aclgrphParseONNX(const char *model_file, GELOGE(ret, "[Parse][ModelFile] %s failed, graph %s.", model_file, ParserUtils::GetGraphName(graph).c_str()); return ge::FAILED; } - GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str()); + GELOGI("Parse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { GELOGE(ge::FAILED, "[Invoke][HandleAfterParse] failed."); return ge::FAILED; } - GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); + GELOGI("Call aclgrphParse to parse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); return ge::SUCCESS; } @@ -151,13 +151,13 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, GELOGE(ret, "[Parser][Graph] %s failed.", ParserUtils::GetGraphName(graph).c_str()); return ge::FAILED; } - GELOGI("Parser graph %s success.", ParserUtils::GetGraphName(graph).c_str()); + GELOGI("Parse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); if (HandleAfterParse(acl_graph_parse_util, parser_params, graph) != ge::SUCCESS) { GELOGE(ge::FAILED, "[Invoke][HandleAfterParse] failed."); return ge::FAILED; } - GELOGI("AclgrphParse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); + GELOGI("Call aclgrphParse to parse graph %s success.", ParserUtils::GetGraphName(graph).c_str()); return ge::SUCCESS; } } // namespace ge @@ -179,7 +179,7 @@ struct ParseArg { }; Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque &args) { - GELOGI("Gen subgraph parse tasks start"); + GELOGI("Generate subgraph parse tasks start"); for (auto &node : parent_graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -200,7 +200,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque args.push_back({nullptr, node, unique_subgraph_name, i}); } } - GELOGI("Gen subgraph parse tasks end"); + GELOGI("Generate subgraph parse tasks end"); return SUCCESS; } @@ -239,8 +239,9 @@ Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_gra } } - GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), - arg.parent_node->GetType().c_str()); + GELOGD("Post process for node %s with type %s in subgraph %s ", arg.graph_name.c_str(), + arg.parent_node->GetType().c_str(), + arg.parent_node->GetName().c_str()); // Refresh node_name in subgraph for (const ge::NodePtr &node : sub_graph->GetDirectNode()) { @@ -890,9 +891,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model if (arg.onnx_graph == nullptr) { std::map::const_iterator itr = name_to_onnx_graph.find(arg.graph_name); if (itr == name_to_onnx_graph.end()) { - GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); - REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); - return FAILED; + GELOGI("Graph: %s is subgraph from plugin, no need parser", arg.graph_name.c_str()); + continue; } arg.onnx_graph = itr->second; }