diff --git a/parser/onnx/CMakeLists.txt b/parser/onnx/CMakeLists.txt index 77cdcf1..304ae82 100644 --- a/parser/onnx/CMakeLists.txt +++ b/parser/onnx/CMakeLists.txt @@ -8,7 +8,9 @@ set(SRC_LIST "onnx_parser.cc" "onnx_data_parser.cc" "onnx_util.cc" - "onnx_constant_parser.cc" + "onnx_constant_parser.cc" + "subgraph_adapter/if_subgraph_adapter.cc" + "subgraph_adapter/subgraph_adapter_factory.cc" ) protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) @@ -31,6 +33,7 @@ target_compile_definitions(fmk_onnx_parser PRIVATE target_include_directories(fmk_onnx_parser PRIVATE ${CMAKE_CURRENT_LIST_DIR} + ${CMAKE_CURRENT_LIST_DIR}/subgraph_adapter ${PARSER_DIR} ${PARSER_DIR}/inc ${PARSER_DIR}/parser diff --git a/parser/onnx/onnx_data_parser.cc b/parser/onnx/onnx_data_parser.cc index 29f966a..367eb32 100644 --- a/parser/onnx/onnx_data_parser.cc +++ b/parser/onnx/onnx_data_parser.cc @@ -35,6 +35,11 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str()); return FAILED; } + // Subgraph data operator don't need parse input shape + // the shape mappings from parent node input + if (IsSubgraphDataOp()) { + return SUCCESS; + } if (ParseInputFromUser(op_def) != SUCCESS) { GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str()); @@ -72,15 +77,23 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & // Get attr t:'input_tensor' form NodeProto int64_t data_type = 1; int64_t index = 0; + is_subgraph_data_op_ = false; for (auto it : node->attribute()) { if (it.name() == ge::kAttrNameInput) { data_type = ParseInputTensor(it); } else if (it.name() == ge::kAttrNameIndex) { index = it.i(); GELOGI("The node has attribute with index: %ld", index); + } else if (it.name() == ge::kAttrNameIsSubgraphOp) { + is_subgraph_data_op_ = true; } } + op_def.SetAttr(ge::ATTR_NAME_INDEX, index); + if (IsSubgraphDataOp()) { + return SUCCESS; + } + // Trans onnx type to ge type DataType type = OnnxUtil::ConvertOnnxDataType(data_type); if (type == ge::DataType::DT_UNDEFINED) { @@ -88,7 +101,6 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & return FAILED; } op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast(type)); - op_def.SetAttr(ge::ATTR_NAME_INDEX, index); return SUCCESS; } diff --git a/parser/onnx/onnx_data_parser.h b/parser/onnx/onnx_data_parser.h index 9650af6..fdc59e2 100644 --- a/parser/onnx/onnx_data_parser.h +++ b/parser/onnx/onnx_data_parser.h @@ -32,11 +32,17 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser { Status ParseInputFromUser(const ge::Operator &op_def); + bool IsSubgraphDataOp() { + return is_subgraph_data_op_; + } + int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute); std::vector model_input_dims_v_; std::vector user_input_dims_v_; + + bool is_subgraph_data_op_; }; } // namespace ge diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index 8745fb9..39b4bc8 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -17,6 +17,7 @@ #include "onnx_parser.h" #include #include +#include #include "common/convert/pb2json.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" @@ -37,6 +38,9 @@ #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" +#include "subgraph_adapter/subgraph_adapter_factory.h" namespace ge { graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, @@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file, } GE_CHECK_NOTNULL(model_parser); - // parse caffe model_file to GE graph + // parse onnx model_file to GE graph ge::graphStatus ret = model_parser->Parse(model_file, graph); if (ret != ge::SUCCESS) { GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); @@ -144,18 +148,130 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, namespace ge { namespace { const std::map kOnnxOpMap = { - {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, + {ge::kOpTypeInput, ge::parser::DATA}, + {ge::kOpTypeConstant, ge::parser::CONSTANT} }; -const char* const MATMULV2 = "MatMulV2"; -const std::vector kNoNeedUpdateFormat = {MATMULV2}; const int64_t kDimValue = 1; + +struct ParseArg { + ge::onnx::GraphProto *onnx_graph; + ge::NodePtr parent_node; + std::string graph_name; + uint32_t subgraph_index; +}; + +Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque &args) { + GELOGI("Gen subgraph parse tasks start"); + for (auto &node : parent_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { + auto i = subgraph_name_to_index.second; + auto subgraph_iname = subgraph_name_to_index.first; + if (subgraph_iname.empty()) { + GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str()); + continue; + } + + // change the graph name to ensure it is unique in GE + std::string unique_subgraph_name; + OnnxUtil::GenUniqueSubgraphName(i, subgraph_iname, node->GetName(), unique_subgraph_name); + + GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s", + node->GetName().c_str(), i, unique_subgraph_name.c_str()); + args.push_back({nullptr, node, unique_subgraph_name, i}); + } + } + GELOGI("Gen subgraph parse tasks end"); + return SUCCESS; +} + +Status BuildLinkForChildAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) { + if (arg.parent_node == nullptr) { + return SUCCESS; + } + auto parent_node = arg.parent_node; + auto index = arg.subgraph_index; + auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Set][Subgraph] Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), + parent_node->GetName().c_str(), index); + REPORT_CALL_ERROR("E19999", "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), + parent_node->GetName().c_str(), index); + return ret; + } + return SUCCESS; } -Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, - std::map &initializer_name_tensor) { - if (onnx_graph.input_size() == 0) { +Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) { + if (arg.parent_node == nullptr) { + return SUCCESS; + } + std::string op_type = arg.parent_node->GetType(); + std::string op_name = arg.parent_node->GetName(); + domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; + auto post_func = + domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); + if (post_func == nullptr) { + GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); + if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { + GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); + return SUCCESS; + } + } + + GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), + arg.parent_node->GetType().c_str()); + + // Refresh node_name in subgraph + for (const ge::NodePtr &node : sub_graph->GetDirectNode()) { + if (node->GetOpDesc() == nullptr) { + continue; + } + node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); + } + + auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); + Status ret = FAILED; + if (post_func != nullptr) { + ret = post_func(arg.graph_name, graph); + } else if (parse_func_v2 != nullptr) { + ret = parse_func_v2(arg.graph_name.c_str(), graph); + } + if (ret != SUCCESS) { + GELOGE(FAILED, "[PostProcess][Subgraph]Failed to post-process subgraph %s on node %s type %s", + arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to post-process subgraph %s on node %s type %s", + arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); + return FAILED; + } + return SUCCESS; +} +} + +Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto &onnx_graph) { + if (onnx_graph.output_size() == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E16001"); + GELOGE(FAILED, "[Parse][Output] Onnx graph:%s has zero output", onnx_graph.name().c_str()); + REPORT_INNER_ERROR("E19999", "Onnx graph:%s has zero output", onnx_graph.name().c_str()); + return FAILED; + } + + // get output value info map + for (int i = 0; i < onnx_graph.output_size(); i++) { + ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); + GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); + output_node_names_.emplace_back(value_info.name()); + } + return SUCCESS; +} + +Status OnnxModelParser::ParseInput(const std::map &initializer_name_tensor, + bool is_subgraph, ge::onnx::GraphProto &onnx_graph) { + if (!is_subgraph && onnx_graph.input_size() == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E16001"); - GELOGE(FAILED, "Onnx graph has zero input"); + GELOGE(FAILED, "[Parse][Input] Root onnx graph:%s has zero input", onnx_graph.name().c_str()); + REPORT_INNER_ERROR("E19999", "Root onnx graph:%s has zero input", onnx_graph.name().c_str()); return FAILED; } @@ -207,6 +323,11 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); attribute_index->set_name(ge::kAttrNameIndex); attribute_index->set_i(data_index++); + // add subgraph attr + if (is_subgraph) { + attribute = input_node->add_attribute(); + attribute->set_name(ge::kAttrNameIsSubgraphOp); + } input_node_names_.emplace_back(value_info.name()); } return SUCCESS; @@ -319,7 +440,8 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot op = ge::OperatorFactory::CreateOperator(node_name, op_type); if (op.GetName() != node_name) { ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); - GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + GELOGE(INTERNAL_ERROR, "[Creat][Op] IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); + REPORT_INNER_ERROR("E19999", "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); return INTERNAL_ERROR; } @@ -428,7 +550,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: std::string op_type; Status status = AdapterOpType(node_proto, ori_type, op_type); if (status != SUCCESS) { - GELOGE(status, "Adapter op type for ori type %s failed.", ori_type.c_str()); + GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); + REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); return status; } node_proto->set_op_type(ori_type); @@ -438,7 +561,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: ge::Operator op; status = TransNodeToOperator(node_proto, op, op_type); if (status != SUCCESS) { - GELOGE(status, "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); + GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); + REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); return status; } @@ -455,9 +579,14 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: return status; } + GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), + op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); + + ge::graphStatus graph_status = graph.AddOp(op); if (graph_status != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str()); + GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", op.GetName().c_str()); return FAILED; } name_operator_[op.GetName()] = op; @@ -473,18 +602,54 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: return SUCCESS; } -Status OnnxModelParser::GetGraphInputs(std::vector &input_ops) { +Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector &input_ops) { + if (input_node_names_.empty()) { + // subgraph might not have input, we use constant nodes as the start nodes of graph + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); + if (node->op_type() == kOpTypeConstant) { + input_node_names_.emplace_back(node->name()); + } + } + } for (auto in_name : input_node_names_) { auto in_op = name_operator_.find(in_name); if (in_op == name_operator_.end()) { - GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", + GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", in_name.c_str()); + REPORT_INNER_ERROR("E19999", "Model assigned input node name: %s can not find in graph.", + in_name.c_str()); return PARAM_INVALID; } input_ops.emplace_back(in_op->second); GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); } + return SUCCESS; +} +Status OnnxModelParser::GetGraphOutputs(std::vector>> &output_ops) { + for (auto output_name : output_node_names_) { + auto itr = outputs_map_.find(output_name); + if (itr == outputs_map_.end()) { + GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); + REPORT_INNER_ERROR( "E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); + return PARAM_INVALID; + } + + std::vector> node_names_indexes = itr->second; + for (const auto &node_name_index : node_names_indexes) { + auto node_name = node_name_index.first; + auto out_op_itr = name_operator_.find(node_name); + if (out_op_itr == name_operator_.end()) { + GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); + REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); + return PARAM_INVALID; + } + int index = node_name_index.second; + output_ops.emplace_back(out_op_itr->second, vector{static_cast(index)}); + GELOGI("out node index %d, node:%s", index, node_name.c_str()); + } + } return SUCCESS; } @@ -515,19 +680,146 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge:: return SUCCESS; } -Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { +void OnnxModelParser::ClearMembers() { + name_operator_.clear(); + input_node_names_.clear(); + output_node_names_.clear(); + inputs_map_.clear(); + outputs_map_.clear(); +} + +Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, + std::map &name_to_onnx_graph) { + std::queue onnx_graph_tasks; + int index = 0; + onnx_graph_tasks.push(&root_onnx_graph); + + while (!onnx_graph_tasks.empty()) { + ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); + onnx_graph_tasks.pop(); + for (int i = 0; i < onnx_graph->node_size(); i++) { + ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); + if (node_proto->name().empty()) { + std::string node_name = node_proto->op_type() + "_" + to_string(index++); + node_proto->set_name(node_name); + } + GELOGD("adapt op name:%s, op type:%s", node_proto->name().c_str(), node_proto->op_type().c_str()); + + SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance(); + GE_CHECK_NOTNULL(factory); + std::shared_ptr subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type()); + if(subgraph_adapter == nullptr) { + GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str()); + continue; + } + std::vector onnx_graphs; + std::map name_to_onnx_subgraph; + if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { + GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); + REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); + return FAILED; + } + + for (const auto &onnx_graph : onnx_graphs) { + onnx_graph_tasks.push(onnx_graph); + } + for (const auto &itr : name_to_onnx_subgraph) { + name_to_onnx_graph.emplace(itr.first, itr.second); + } + } + } + return SUCCESS; +} + +Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) { if (!onnx_model.has_graph()) { ErrorManager::GetInstance().ATCReportErrMessage("E16004"); GELOGE(PARAM_INVALID, "Onnx model do not has graph."); return FAILED; } - ge::onnx::GraphProto onnx_graph = onnx_model.graph(); + std::map name_to_onnx_graph; + std::deque tasks; + ge::onnx::GraphProto root_onnx_graph = onnx_model.graph(); + + auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "[AdaptAndFind][OnnxGraph]adapt and find all onnx graph failed, root graph:%s.", + root_onnx_graph.name().c_str()); + return FAILED; + } auto opset_import = onnx_model.opset_import(); for (auto it : opset_import) { domain_verseion_[it.domain()] = it.version(); GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); } + std::string root_graph_name = root_graph.GetName().empty() ? "default_graph" : root_graph.GetName(); + tasks.push_back({&root_onnx_graph, nullptr, root_graph_name, 0}); + + while (!tasks.empty()) { + ParseArg arg = tasks.front(); + tasks.pop_front(); + bool is_subgraph = (arg.parent_node != nullptr) ? true : false; + + if (arg.onnx_graph == nullptr) { + auto itr = name_to_onnx_graph.find(arg.graph_name); + if (itr == name_to_onnx_graph.end()) { + GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); + REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); + return FAILED; + } + arg.onnx_graph = itr->second; + } + + ge::onnx::GraphProto *onnx_graph = arg.onnx_graph; + ge::Graph tmp_graph(arg.graph_name); + ret = ModelParseToGraphImpl(is_subgraph, *onnx_graph, tmp_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Parse][Model] Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); + REPORT_INNER_ERROR("E19999", "Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); + return ret; + } + // To get the result for root graph + if (!is_subgraph) { + root_graph = tmp_graph; + } + + ge::ComputeGraphPtr cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph); + GE_CHECK_NOTNULL(cur_compute_graph); + + ret = PostOpProcessForSubgraph(arg, cur_compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[PostProcess][Subgraph]Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); + return ret; + } + + ret = BuildLinkForChildAndParentGraph(cur_compute_graph, arg); + if (ret != SUCCESS) { + GELOGE(ret, "[BuildLink][Graph] Build link for child graph:%s and parent graph failed.", + cur_compute_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Build link for child graph:%s and parent graph failed.", + cur_compute_graph->GetName().c_str()); + return ret; + } + + ret = GenSubgraphParseTasks(cur_compute_graph, tasks); + if (ret != SUCCESS) { + GELOGE(ret, "[Generate][Task] Failed to gen tasks on graph %s for next iteration", + cur_compute_graph->GetName().c_str()); + REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph %s for next iteration", + cur_compute_graph->GetName().c_str()); + return ret; + } + + } + UpdateDataFormat(root_graph); + return SUCCESS; +} + +Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { + + ClearMembers(); // 2. Get all inializer. std::map initializer_name_tensor; @@ -541,7 +833,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model // 3. Parse Input from graph. GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); - Status ret = ParseInput(onnx_graph, initializer_name_tensor); + + Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph); if (ret != SUCCESS) { GELOGE(ret, "Parse input for onnx failed."); return ret; @@ -555,6 +848,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } + ret = ParseOutput(onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Parse][Output] Parse output for onnx failed."); + return ret; + } + // 5. Update node name for node do not has name. ret = UpdateAllNodeName(onnx_graph); if (ret != SUCCESS) { @@ -582,6 +881,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } + std::vector op_names; + graph.GetAllOpName(op_names); + GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); + // 8. Set all operator input. ret = SetOperatorInputs(); if (ret != SUCCESS) { @@ -589,22 +892,27 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model return ret; } - std::vector op_names; - graph.GetAllOpName(op_names); - GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); - // 9. Construct graph. std::vector input_ops; - - ret = GetGraphInputs(input_ops); + ret = GetGraphInputs(onnx_graph, input_ops); if (ret != SUCCESS) { GELOGE(ret, "Get graph inputs failed."); return ret; } graph.SetInputs(input_ops); + // root graph needn't set outputs. + if(is_subgraph) { + std::vector>> output_ops; + ret = GetGraphOutputs(output_ops); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][Outputs]Get graph outputs failed."); + return ret; + } + graph.SetOutputs(output_ops); + } + GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); - UpdateDataFormat(graph); GELOGI("Onnx model parser success."); return SUCCESS; diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index 45adf7c..b28494b 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -72,8 +72,10 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { private: Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); - Status ParseInput(ge::onnx::GraphProto &onnx_graph, - std::map &initializer_name_tensor); + Status ParseInput(const std::map &initializer_name_tensor, + bool is_subgraph, ge::onnx::GraphProto &onnx_graph); + + Status ParseOutput(ge::onnx::GraphProto &onnx_graph); Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, std::map &initializer_name_tensor); @@ -90,7 +92,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status SetOperatorInputs(); - Status GetGraphInputs(std::vector &input_ops); + Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector &input_ops); + + Status GetGraphOutputs(std::vector>> &outputs); Status Prechecker(ge::onnx::GraphProto &onnx_graph); @@ -100,8 +104,15 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); + Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); + void UpdateDataFormat(ge::Graph &graph); + void ClearMembers(); + + Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, + std::map &name_to_onnx_graph); + std::map ori_to_om_type_; std::map domain_verseion_; @@ -110,6 +121,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { std::vector input_node_names_; + std::vector output_node_names_; + std::map>> inputs_map_; std::map>> outputs_map_; diff --git a/parser/onnx/onnx_util.cc b/parser/onnx/onnx_util.cc index d42ab39..72ba260 100644 --- a/parser/onnx/onnx_util.cc +++ b/parser/onnx/onnx_util.cc @@ -60,4 +60,9 @@ int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) { return ge::DataType::DT_UNDEFINED; } } + +void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, + const std::string &parent_node_name, std::string &unique_subgraph_name) { + unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; +} } // namespace ge diff --git a/parser/onnx/onnx_util.h b/parser/onnx/onnx_util.h index 259ed42..c8ffa8c 100644 --- a/parser/onnx/onnx_util.h +++ b/parser/onnx/onnx_util.h @@ -45,6 +45,7 @@ namespace ge { const char *const kAttrNameValue = "value"; const char *const kAttrNameInput = "input_tensor"; const char *const kAttrNameIndex = "index"; +const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; const char *const kOpTypeConstant = "Constant"; const char *const kOpTypeInput = "Input"; @@ -52,6 +53,8 @@ class OnnxUtil { public: static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); static int64_t CaculateDataSize(int64_t onnx_data_type); + static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, + const std::string &parent_node_name, std::string &unique_subgraph_name); }; } // namespace ge diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc new file mode 100644 index 0000000..0265ed2 --- /dev/null +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "if_subgraph_adapter.h" +#include "subgraph_adapter_factory.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge{ +namespace { +const std::map kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; +const int kIfNodeAttrSize = 2; +} +Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, + std::vector &onnx_graphs, + std::map &name_to_onnx_graph) { + GE_CHECK_NOTNULL(parent_node); + GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), + parent_node->op_type().c_str()); + + auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[Parse][Node] Parse if node failed."); + REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); + return ret; + } + + return SUCCESS; +} + +Status IfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, + std::vector &onnx_graphs, + std::map &name_to_onnx_graph) { + if (parent_node->attribute_size() != kIfNodeAttrSize) { + GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); + REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); + return FAILED; + } + + GELOGD("node attribute size:%d.", parent_node->attribute_size()); + std::set all_inputs; + // for onnx graph, the first attribute may be else branch and the second attribute may be then branch + for (int i = 0; i < parent_node->attribute_size(); i++) { + ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); + GE_CHECK_NOTNULL(attribute); + std::string attr_name = attribute->name(); + auto itr = kAttrNameToIndex.find(attr_name); + if (itr == kAttrNameToIndex.end()) { + GELOGE(FAILED, "[Parse][Attribute] Invalid attribute name:%s, it should be then_branch or else_branch.", + attr_name.c_str()); + REPORT_INNER_ERROR("E19999", "Invalid attribute name:%s, it should be then_branch or else_branch.", + attr_name.c_str()); + return FAILED; + } + std::string unique_subgraph_name; + OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name); + GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); + ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); + name_to_onnx_graph[unique_subgraph_name] = onnx_graph; + onnx_graphs.emplace_back(onnx_graph); + + auto ret = GetSubgraphsAllInputs(*onnx_graph, all_inputs); + if (ret != SUCCESS) { + GELOGE(ret, "[Get][Inputs] Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str()); + REPORT_INNER_ERROR("E19999", "Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str()); + return ret; + } + } + + for (auto &onnx_graph : onnx_graphs) { + AddInputNodeForGraph(all_inputs, *onnx_graph); + } + + AddInputForParentNode(all_inputs, *parent_node); + return SUCCESS; +} + +Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, + std::set &all_inputs) { + std::set graph_inputs; + std::set graph_outputs; + for (int i = 0; i < onnx_graph.node_size(); i++) { + ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); + for (int j = 0; j < node_proto->input_size(); j++) { + graph_inputs.emplace(node_proto->input(j)); + } + for (int j = 0; j < node_proto->output_size(); j++) { + graph_outputs.emplace(node_proto->output(j)); + } + } + + for (const auto &input : graph_inputs) { + auto out_iter = graph_outputs.find(input); + if (out_iter == graph_outputs.end()) { + // Record input node need to be constructed + all_inputs.emplace(input); + } + } + + return SUCCESS; +} + +void IfSubgraphAdapter::AddInputNodeForGraph(const std::set &all_inputs, + ge::onnx::GraphProto &onnx_graph) { + for (const auto &input_name : all_inputs) { + ge::onnx::ValueInfoProto *value_info = onnx_graph.add_input(); + value_info->set_name(input_name); + } +} + +void IfSubgraphAdapter::AddInputForParentNode(const std::set &all_inputs, + ge::onnx::NodeProto &parent_node) { + for (const auto &input_name : all_inputs) { + parent_node.add_input(input_name); + } +} +REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); +} // namespace ge diff --git a/parser/onnx/subgraph_adapter/if_subgraph_adapter.h b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h new file mode 100644 index 0000000..103a937 --- /dev/null +++ b/parser/onnx/subgraph_adapter/if_subgraph_adapter.h @@ -0,0 +1,40 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ +#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ + +#include +#include +#include "subgraph_adapter.h" + +using ge::onnx::NodeProto; + +namespace ge { +class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { + public: + Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, std::vector &onnx_graphs, + std::map &name_to_onnx_graph) override; +private: + Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector &onnx_graphs, + std::map &name_to_onnx_graph); + Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set &all_inputs); + void AddInputNodeForGraph(const std::set &all_inputs, ge::onnx::GraphProto &onnx_graph); + void AddInputForParentNode(const std::set &all_inputs, ge::onnx::NodeProto &parent_node); +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter.h b/parser/onnx/subgraph_adapter/subgraph_adapter.h new file mode 100644 index 0000000..afb2f96 --- /dev/null +++ b/parser/onnx/subgraph_adapter/subgraph_adapter.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ +#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include "proto/onnx/ge_onnx.pb.h" +#include "external/register/register_error_codes.h" +#include "framework/omg/parser/parser_types.h" +#include "onnx_util.h" + +using Status = domi::Status; +using namespace ge::parser; + +namespace ge { +class PARSER_FUNC_VISIBILITY SubgraphAdapter { + public: + /// @brief parse params + /// @param [in/out] parent_op parent op + /// @param [in/out] onnx_graph_tasks onnx graph task + /// @param [in/out] name_to_onnx_graph map name to onnx graph + /// @return SUCCESS parse success + /// @return FAILED Parse failed + virtual Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, + std::vector &onnx_graphs, + std::map &name_to_onnx_graph) { + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc new file mode 100644 index 0000000..7632520 --- /dev/null +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "subgraph_adapter_factory.h" +#include "framework/common/debug/ge_log.h" + +namespace ge{ +SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() { + static SubgraphAdapterFactory instance; + return &instance; +} + +std::shared_ptr SubgraphAdapterFactory::CreateSubgraphAdapter( + const std::string &op_type) { + // First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. + auto iter = subgraph_adapter_creator_map_.find(op_type); + if (iter != subgraph_adapter_creator_map_.end()) { + return iter->second(); + } + + GELOGW("SubgraphAdapterFactory::CreateSubgraphAdapter: Not supported type: %s", op_type.c_str()); + return nullptr; +} + +// This function is only called within the constructor of the global SubgraphAdapterRegisterar object, +// and does not involve concurrency, so there is no need to lock it +void SubgraphAdapterFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun) { + std::map *subgraph_adapter_creator_map = &subgraph_adapter_creator_map_; + GELOGD("SubgraphAdapterFactory::RegisterCreator: op type:%s.", type.c_str()); + (*subgraph_adapter_creator_map)[type] = fun; +} +} // namespace ge diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h new file mode 100644 index 0000000..fa023fa --- /dev/null +++ b/parser/onnx/subgraph_adapter/subgraph_adapter_factory.h @@ -0,0 +1,119 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_ +#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include "subgraph_adapter.h" + +namespace ge { +/** + * @brief Used to create OpParser + * + */ +class PARSER_FUNC_VISIBILITY SubgraphAdapterFactory { +public: + /** + * @brief Returns the SubgraphAdapterFactory instance + * @return SubgraphAdapterFactory object + */ + static SubgraphAdapterFactory* Instance(); + + /** + * @brief Create SubgraphAdapter based on input type + * @param [in] op_type Op type + * @return Created SubgraphAdapter + */ + std::shared_ptr CreateSubgraphAdapter(const std::string &op_type); + + +protected: + /** + * @brief SubgraphAdapter creation function + * @return Created SubgraphAdapter + */ + // typedef shared_ptr (*CREATOR_FUN)(void); + using CREATOR_FUN = std::function(void)>; + + /** + * @brief Factory instances can only be created automatically, not new methods, so the constructor is not public. + */ + SubgraphAdapterFactory() {} + + /** + * @brief Register creation function + * @param [in] type Op type + * @param [in] fun OpParser creation function + */ + void RegisterCreator(const std::string &type, CREATOR_FUN fun); + +private: + std::map subgraph_adapter_creator_map_; // lint !e1073 + + friend class SubgraphAdapterRegisterar; +}; + +/** + * @brief For registering Creator functions for different types of subgraph adapter + * + */ +class PARSER_FUNC_VISIBILITY SubgraphAdapterRegisterar { +public: + /** + * @brief Constructor + * @param [in] op_type Op type + * @param [in] fun Creator function corresponding to Subgrap adapter + */ + SubgraphAdapterRegisterar(const std::string &op_type, SubgraphAdapterFactory::CREATOR_FUN fun) { + SubgraphAdapterFactory::Instance()->RegisterCreator(op_type, fun); + } + ~SubgraphAdapterRegisterar() {} +}; + +/** + * @brief SubgraphAdapter Registration Macro + * @param [in] op_type Op type + * @param [in] clazz SubgraphAdapter implementation class + */ +#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ + std::shared_ptr Creator_##op_type##_Subgraph_Adapter() { \ + std::shared_ptr ptr(new (std::nothrow) clazz()); \ + if (ptr == nullptr) { \ + GELOGW("MakeShared failed, result is nullptr."); \ + } \ + return std::shared_ptr(ptr); \ + } \ + ge::SubgraphAdapterRegisterar g_##op_type##_Subgraph_Adapter_Creator(op_type, \ + Creator_##op_type##_Subgraph_Adapter) +} // namespace ge + +#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_