@@ -8,7 +8,9 @@ set(SRC_LIST | |||||
"onnx_parser.cc" | "onnx_parser.cc" | ||||
"onnx_data_parser.cc" | "onnx_data_parser.cc" | ||||
"onnx_util.cc" | "onnx_util.cc" | ||||
"onnx_constant_parser.cc" | |||||
"onnx_constant_parser.cc" | |||||
"subgraph_adapter/if_subgraph_adapter.cc" | |||||
"subgraph_adapter/subgraph_adapter_factory.cc" | |||||
) | ) | ||||
protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
@@ -31,6 +33,7 @@ target_compile_definitions(fmk_onnx_parser PRIVATE | |||||
target_include_directories(fmk_onnx_parser PRIVATE | target_include_directories(fmk_onnx_parser PRIVATE | ||||
${CMAKE_CURRENT_LIST_DIR} | ${CMAKE_CURRENT_LIST_DIR} | ||||
${CMAKE_CURRENT_LIST_DIR}/subgraph_adapter | |||||
${PARSER_DIR} | ${PARSER_DIR} | ||||
${PARSER_DIR}/inc | ${PARSER_DIR}/inc | ||||
${PARSER_DIR}/parser | ${PARSER_DIR}/parser | ||||
@@ -35,6 +35,11 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) | |||||
GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str()); | GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
// Subgraph data operator don't need parse input shape | |||||
// the shape mappings from parent node input | |||||
if (IsSubgraphDataOp()) { | |||||
return SUCCESS; | |||||
} | |||||
if (ParseInputFromUser(op_def) != SUCCESS) { | if (ParseInputFromUser(op_def) != SUCCESS) { | ||||
GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str()); | GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str()); | ||||
@@ -72,15 +77,23 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & | |||||
// Get attr t:'input_tensor' form NodeProto | // Get attr t:'input_tensor' form NodeProto | ||||
int64_t data_type = 1; | int64_t data_type = 1; | ||||
int64_t index = 0; | int64_t index = 0; | ||||
is_subgraph_data_op_ = false; | |||||
for (auto it : node->attribute()) { | for (auto it : node->attribute()) { | ||||
if (it.name() == ge::kAttrNameInput) { | if (it.name() == ge::kAttrNameInput) { | ||||
data_type = ParseInputTensor(it); | data_type = ParseInputTensor(it); | ||||
} else if (it.name() == ge::kAttrNameIndex) { | } else if (it.name() == ge::kAttrNameIndex) { | ||||
index = it.i(); | index = it.i(); | ||||
GELOGI("The node has attribute with index: %ld", index); | GELOGI("The node has attribute with index: %ld", index); | ||||
} else if (it.name() == ge::kAttrNameIsSubgraphOp) { | |||||
is_subgraph_data_op_ = true; | |||||
} | } | ||||
} | } | ||||
op_def.SetAttr(ge::ATTR_NAME_INDEX, index); | |||||
if (IsSubgraphDataOp()) { | |||||
return SUCCESS; | |||||
} | |||||
// Trans onnx type to ge type | // Trans onnx type to ge type | ||||
DataType type = OnnxUtil::ConvertOnnxDataType(data_type); | DataType type = OnnxUtil::ConvertOnnxDataType(data_type); | ||||
if (type == ge::DataType::DT_UNDEFINED) { | if (type == ge::DataType::DT_UNDEFINED) { | ||||
@@ -88,7 +101,6 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator & | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast<int64_t>(type)); | op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast<int64_t>(type)); | ||||
op_def.SetAttr(ge::ATTR_NAME_INDEX, index); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -32,11 +32,17 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser { | |||||
Status ParseInputFromUser(const ge::Operator &op_def); | Status ParseInputFromUser(const ge::Operator &op_def); | ||||
bool IsSubgraphDataOp() { | |||||
return is_subgraph_data_op_; | |||||
} | |||||
int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute); | int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute); | ||||
std::vector<int64_t> model_input_dims_v_; | std::vector<int64_t> model_input_dims_v_; | ||||
std::vector<int64_t> user_input_dims_v_; | std::vector<int64_t> user_input_dims_v_; | ||||
bool is_subgraph_data_op_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -17,6 +17,7 @@ | |||||
#include "onnx_parser.h" | #include "onnx_parser.h" | ||||
#include <algorithm> | #include <algorithm> | ||||
#include <iostream> | #include <iostream> | ||||
#include <queue> | |||||
#include "common/convert/pb2json.h" | #include "common/convert/pb2json.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
@@ -37,6 +38,9 @@ | |||||
#include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
#include "register/op_registry.h" | #include "register/op_registry.h" | ||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "subgraph_adapter/subgraph_adapter_factory.h" | |||||
namespace ge { | namespace ge { | ||||
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util, | ||||
@@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file, | |||||
} | } | ||||
GE_CHECK_NOTNULL(model_parser); | GE_CHECK_NOTNULL(model_parser); | ||||
// parse caffe model_file to GE graph | |||||
// parse onnx model_file to GE graph | |||||
ge::graphStatus ret = model_parser->Parse(model_file, graph); | ge::graphStatus ret = model_parser->Parse(model_file, graph); | ||||
if (ret != ge::SUCCESS) { | if (ret != ge::SUCCESS) { | ||||
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | ||||
@@ -144,18 +148,130 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const std::map<std::string, std::string> kOnnxOpMap = { | const std::map<std::string, std::string> kOnnxOpMap = { | ||||
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, | |||||
{ge::kOpTypeInput, ge::parser::DATA}, | |||||
{ge::kOpTypeConstant, ge::parser::CONSTANT} | |||||
}; | }; | ||||
const char* const MATMULV2 = "MatMulV2"; | |||||
const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2}; | |||||
const int64_t kDimValue = 1; | const int64_t kDimValue = 1; | ||||
struct ParseArg { | |||||
ge::onnx::GraphProto *onnx_graph; | |||||
ge::NodePtr parent_node; | |||||
std::string graph_name; | |||||
uint32_t subgraph_index; | |||||
}; | |||||
Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) { | |||||
GELOGI("Gen subgraph parse tasks start"); | |||||
for (auto &node : parent_graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { | |||||
auto i = subgraph_name_to_index.second; | |||||
auto subgraph_iname = subgraph_name_to_index.first; | |||||
if (subgraph_iname.empty()) { | |||||
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
// change the graph name to ensure it is unique in GE | |||||
std::string unique_subgraph_name; | |||||
OnnxUtil::GenUniqueSubgraphName(i, subgraph_iname, node->GetName(), unique_subgraph_name); | |||||
GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s", | |||||
node->GetName().c_str(), i, unique_subgraph_name.c_str()); | |||||
args.push_back({nullptr, node, unique_subgraph_name, i}); | |||||
} | |||||
} | |||||
GELOGI("Gen subgraph parse tasks end"); | |||||
return SUCCESS; | |||||
} | |||||
Status BuildLinkForChildAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) { | |||||
if (arg.parent_node == nullptr) { | |||||
return SUCCESS; | |||||
} | |||||
auto parent_node = arg.parent_node; | |||||
auto index = arg.subgraph_index; | |||||
auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Set][Subgraph] Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), | |||||
parent_node->GetName().c_str(), index); | |||||
REPORT_CALL_ERROR("E19999", "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(), | |||||
parent_node->GetName().c_str(), index); | |||||
return ret; | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) { | |||||
if (onnx_graph.input_size() == 0) { | |||||
Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) { | |||||
if (arg.parent_node == nullptr) { | |||||
return SUCCESS; | |||||
} | |||||
std::string op_type = arg.parent_node->GetType(); | |||||
std::string op_name = arg.parent_node->GetName(); | |||||
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; | |||||
auto post_func = | |||||
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type); | |||||
if (post_func == nullptr) { | |||||
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) { | |||||
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str()); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), | |||||
arg.parent_node->GetType().c_str()); | |||||
// Refresh node_name in subgraph | |||||
for (const ge::NodePtr &node : sub_graph->GetDirectNode()) { | |||||
if (node->GetOpDesc() == nullptr) { | |||||
continue; | |||||
} | |||||
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName()); | |||||
} | |||||
auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph); | |||||
Status ret = FAILED; | |||||
if (post_func != nullptr) { | |||||
ret = post_func(arg.graph_name, graph); | |||||
} else if (parse_func_v2 != nullptr) { | |||||
ret = parse_func_v2(arg.graph_name.c_str(), graph); | |||||
} | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "[PostProcess][Subgraph]Failed to post-process subgraph %s on node %s type %s", | |||||
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Failed to post-process subgraph %s on node %s type %s", | |||||
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto &onnx_graph) { | |||||
if (onnx_graph.output_size() == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E16001"); | |||||
GELOGE(FAILED, "[Parse][Output] Onnx graph:%s has zero output", onnx_graph.name().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Onnx graph:%s has zero output", onnx_graph.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
// get output value info map | |||||
for (int i = 0; i < onnx_graph.output_size(); i++) { | |||||
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i); | |||||
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str()); | |||||
output_node_names_.emplace_back(value_info.name()); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxModelParser::ParseInput(const std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor, | |||||
bool is_subgraph, ge::onnx::GraphProto &onnx_graph) { | |||||
if (!is_subgraph && onnx_graph.input_size() == 0) { | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E16001"); | ErrorManager::GetInstance().ATCReportErrMessage("E16001"); | ||||
GELOGE(FAILED, "Onnx graph has zero input"); | |||||
GELOGE(FAILED, "[Parse][Input] Root onnx graph:%s has zero input", onnx_graph.name().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Root onnx graph:%s has zero input", onnx_graph.name().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -207,6 +323,11 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); | ge::onnx::AttributeProto *attribute_index = input_node->add_attribute(); | ||||
attribute_index->set_name(ge::kAttrNameIndex); | attribute_index->set_name(ge::kAttrNameIndex); | ||||
attribute_index->set_i(data_index++); | attribute_index->set_i(data_index++); | ||||
// add subgraph attr | |||||
if (is_subgraph) { | |||||
attribute = input_node->add_attribute(); | |||||
attribute->set_name(ge::kAttrNameIsSubgraphOp); | |||||
} | |||||
input_node_names_.emplace_back(value_info.name()); | input_node_names_.emplace_back(value_info.name()); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -319,7 +440,8 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot | |||||
op = ge::OperatorFactory::CreateOperator(node_name, op_type); | op = ge::OperatorFactory::CreateOperator(node_name, op_type); | ||||
if (op.GetName() != node_name) { | if (op.GetName() != node_name) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); | ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type}); | ||||
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); | |||||
GELOGE(INTERNAL_ERROR, "[Creat][Op] IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
@@ -428,7 +550,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
std::string op_type; | std::string op_type; | ||||
Status status = AdapterOpType(node_proto, ori_type, op_type); | Status status = AdapterOpType(node_proto, ori_type, op_type); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "Adapter op type for ori type %s failed.", ori_type.c_str()); | |||||
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str()); | |||||
return status; | return status; | ||||
} | } | ||||
node_proto->set_op_type(ori_type); | node_proto->set_op_type(ori_type); | ||||
@@ -438,7 +561,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
ge::Operator op; | ge::Operator op; | ||||
status = TransNodeToOperator(node_proto, op, op_type); | status = TransNodeToOperator(node_proto, op, op_type); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | |||||
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str()); | |||||
return status; | return status; | ||||
} | } | ||||
@@ -455,9 +579,14 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
return status; | return status; | ||||
} | } | ||||
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), | |||||
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); | |||||
ge::graphStatus graph_status = graph.AddOp(op); | ge::graphStatus graph_status = graph.AddOp(op); | ||||
if (graph_status != ge::GRAPH_SUCCESS) { | if (graph_status != ge::GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str()); | |||||
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", op.GetName().c_str()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
name_operator_[op.GetName()] = op; | name_operator_[op.GetName()] = op; | ||||
@@ -473,18 +602,54 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) { | |||||
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) { | |||||
if (input_node_names_.empty()) { | |||||
// subgraph might not have input, we use constant nodes as the start nodes of graph | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i); | |||||
if (node->op_type() == kOpTypeConstant) { | |||||
input_node_names_.emplace_back(node->name()); | |||||
} | |||||
} | |||||
} | |||||
for (auto in_name : input_node_names_) { | for (auto in_name : input_node_names_) { | ||||
auto in_op = name_operator_.find(in_name); | auto in_op = name_operator_.find(in_name); | ||||
if (in_op == name_operator_.end()) { | if (in_op == name_operator_.end()) { | ||||
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.", | |||||
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.", | |||||
in_name.c_str()); | in_name.c_str()); | ||||
REPORT_INNER_ERROR("E19999", "Model assigned input node name: %s can not find in graph.", | |||||
in_name.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
input_ops.emplace_back(in_op->second); | input_ops.emplace_back(in_op->second); | ||||
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); | GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str()); | ||||
} | } | ||||
return SUCCESS; | |||||
} | |||||
Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops) { | |||||
for (auto output_name : output_node_names_) { | |||||
auto itr = outputs_map_.find(output_name); | |||||
if (itr == outputs_map_.end()) { | |||||
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||||
REPORT_INNER_ERROR( "E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::vector<std::pair<std::string, int>> node_names_indexes = itr->second; | |||||
for (const auto &node_name_index : node_names_indexes) { | |||||
auto node_name = node_name_index.first; | |||||
auto out_op_itr = name_operator_.find(node_name); | |||||
if (out_op_itr == name_operator_.end()) { | |||||
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
int index = node_name_index.second; | |||||
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)}); | |||||
GELOGI("out node index %d, node:%s", index, node_name.c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -515,19 +680,146 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge:: | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) { | |||||
void OnnxModelParser::ClearMembers() { | |||||
name_operator_.clear(); | |||||
input_node_names_.clear(); | |||||
output_node_names_.clear(); | |||||
inputs_map_.clear(); | |||||
outputs_map_.clear(); | |||||
} | |||||
Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
std::queue<ge::onnx::GraphProto *> onnx_graph_tasks; | |||||
int index = 0; | |||||
onnx_graph_tasks.push(&root_onnx_graph); | |||||
while (!onnx_graph_tasks.empty()) { | |||||
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front(); | |||||
onnx_graph_tasks.pop(); | |||||
for (int i = 0; i < onnx_graph->node_size(); i++) { | |||||
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i); | |||||
if (node_proto->name().empty()) { | |||||
std::string node_name = node_proto->op_type() + "_" + to_string(index++); | |||||
node_proto->set_name(node_name); | |||||
} | |||||
GELOGD("adapt op name:%s, op type:%s", node_proto->name().c_str(), node_proto->op_type().c_str()); | |||||
SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance(); | |||||
GE_CHECK_NOTNULL(factory); | |||||
std::shared_ptr<SubgraphAdapter> subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type()); | |||||
if(subgraph_adapter == nullptr) { | |||||
GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str()); | |||||
continue; | |||||
} | |||||
std::vector<ge::onnx::GraphProto *> onnx_graphs; | |||||
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph; | |||||
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) { | |||||
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str()); | |||||
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str()); | |||||
return FAILED; | |||||
} | |||||
for (const auto &onnx_graph : onnx_graphs) { | |||||
onnx_graph_tasks.push(onnx_graph); | |||||
} | |||||
for (const auto &itr : name_to_onnx_subgraph) { | |||||
name_to_onnx_graph.emplace(itr.first, itr.second); | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) { | |||||
if (!onnx_model.has_graph()) { | if (!onnx_model.has_graph()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E16004"); | ErrorManager::GetInstance().ATCReportErrMessage("E16004"); | ||||
GELOGE(PARAM_INVALID, "Onnx model do not has graph."); | GELOGE(PARAM_INVALID, "Onnx model do not has graph."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
ge::onnx::GraphProto onnx_graph = onnx_model.graph(); | |||||
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_graph; | |||||
std::deque<ParseArg> tasks; | |||||
ge::onnx::GraphProto root_onnx_graph = onnx_model.graph(); | |||||
auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "[AdaptAndFind][OnnxGraph]adapt and find all onnx graph failed, root graph:%s.", | |||||
root_onnx_graph.name().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto opset_import = onnx_model.opset_import(); | auto opset_import = onnx_model.opset_import(); | ||||
for (auto it : opset_import) { | for (auto it : opset_import) { | ||||
domain_verseion_[it.domain()] = it.version(); | domain_verseion_[it.domain()] = it.version(); | ||||
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); | GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version()); | ||||
} | } | ||||
std::string root_graph_name = root_graph.GetName().empty() ? "default_graph" : root_graph.GetName(); | |||||
tasks.push_back({&root_onnx_graph, nullptr, root_graph_name, 0}); | |||||
while (!tasks.empty()) { | |||||
ParseArg arg = tasks.front(); | |||||
tasks.pop_front(); | |||||
bool is_subgraph = (arg.parent_node != nullptr) ? true : false; | |||||
if (arg.onnx_graph == nullptr) { | |||||
auto itr = name_to_onnx_graph.find(arg.graph_name); | |||||
if (itr == name_to_onnx_graph.end()) { | |||||
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
arg.onnx_graph = itr->second; | |||||
} | |||||
ge::onnx::GraphProto *onnx_graph = arg.onnx_graph; | |||||
ge::Graph tmp_graph(arg.graph_name); | |||||
ret = ModelParseToGraphImpl(is_subgraph, *onnx_graph, tmp_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Parse][Model] Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Model parse to graph failed, graph name:%s.", arg.graph_name.c_str()); | |||||
return ret; | |||||
} | |||||
// To get the result for root graph | |||||
if (!is_subgraph) { | |||||
root_graph = tmp_graph; | |||||
} | |||||
ge::ComputeGraphPtr cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph); | |||||
GE_CHECK_NOTNULL(cur_compute_graph); | |||||
ret = PostOpProcessForSubgraph(arg, cur_compute_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[PostProcess][Subgraph]Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
ret = BuildLinkForChildAndParentGraph(cur_compute_graph, arg); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[BuildLink][Graph] Build link for child graph:%s and parent graph failed.", | |||||
cur_compute_graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Build link for child graph:%s and parent graph failed.", | |||||
cur_compute_graph->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
ret = GenSubgraphParseTasks(cur_compute_graph, tasks); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Generate][Task] Failed to gen tasks on graph %s for next iteration", | |||||
cur_compute_graph->GetName().c_str()); | |||||
REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph %s for next iteration", | |||||
cur_compute_graph->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
} | |||||
UpdateDataFormat(root_graph); | |||||
return SUCCESS; | |||||
} | |||||
Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | |||||
ClearMembers(); | |||||
// 2. Get all inializer. | // 2. Get all inializer. | ||||
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor; | ||||
@@ -541,7 +833,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
// 3. Parse Input from graph. | // 3. Parse Input from graph. | ||||
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size()); | ||||
Status ret = ParseInput(onnx_graph, initializer_name_tensor); | |||||
Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Parse input for onnx failed."); | GELOGE(ret, "Parse input for onnx failed."); | ||||
return ret; | return ret; | ||||
@@ -555,6 +848,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
return ret; | return ret; | ||||
} | } | ||||
ret = ParseOutput(onnx_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Parse][Output] Parse output for onnx failed."); | |||||
return ret; | |||||
} | |||||
// 5. Update node name for node do not has name. | // 5. Update node name for node do not has name. | ||||
ret = UpdateAllNodeName(onnx_graph); | ret = UpdateAllNodeName(onnx_graph); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -582,6 +881,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<string> op_names; | |||||
graph.GetAllOpName(op_names); | |||||
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); | |||||
// 8. Set all operator input. | // 8. Set all operator input. | ||||
ret = SetOperatorInputs(); | ret = SetOperatorInputs(); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -589,22 +892,27 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model | |||||
return ret; | return ret; | ||||
} | } | ||||
std::vector<string> op_names; | |||||
graph.GetAllOpName(op_names); | |||||
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size()); | |||||
// 9. Construct graph. | // 9. Construct graph. | ||||
std::vector<ge::Operator> input_ops; | std::vector<ge::Operator> input_ops; | ||||
ret = GetGraphInputs(input_ops); | |||||
ret = GetGraphInputs(onnx_graph, input_ops); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Get graph inputs failed."); | GELOGE(ret, "Get graph inputs failed."); | ||||
return ret; | return ret; | ||||
} | } | ||||
graph.SetInputs(input_ops); | graph.SetInputs(input_ops); | ||||
// root graph needn't set outputs. | |||||
if(is_subgraph) { | |||||
std::vector<std::pair<Operator, std::vector<size_t>>> output_ops; | |||||
ret = GetGraphOutputs(output_ops); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Get][Outputs]Get graph outputs failed."); | |||||
return ret; | |||||
} | |||||
graph.SetOutputs(output_ops); | |||||
} | |||||
GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); | GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); | ||||
UpdateDataFormat(graph); | |||||
GELOGI("Onnx model parser success."); | GELOGI("Onnx model parser success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -72,8 +72,10 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
private: | private: | ||||
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | ||||
Status ParseInput(ge::onnx::GraphProto &onnx_graph, | |||||
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor); | |||||
Status ParseInput(const std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor, | |||||
bool is_subgraph, ge::onnx::GraphProto &onnx_graph); | |||||
Status ParseOutput(ge::onnx::GraphProto &onnx_graph); | |||||
Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | Status ParseInitializer(ge::onnx::GraphProto &onnx_graph, | ||||
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor); | std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor); | ||||
@@ -90,7 +92,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status SetOperatorInputs(); | Status SetOperatorInputs(); | ||||
Status GetGraphInputs(std::vector<ge::Operator> &input_ops); | |||||
Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops); | |||||
Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs); | |||||
Status Prechecker(ge::onnx::GraphProto &onnx_graph); | Status Prechecker(ge::onnx::GraphProto &onnx_graph); | ||||
@@ -100,8 +104,15 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph); | ||||
Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph); | |||||
void UpdateDataFormat(ge::Graph &graph); | void UpdateDataFormat(ge::Graph &graph); | ||||
void ClearMembers(); | |||||
Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||||
std::map<std::string, std::string> ori_to_om_type_; | std::map<std::string, std::string> ori_to_om_type_; | ||||
std::map<std::string, int64_t> domain_verseion_; | std::map<std::string, int64_t> domain_verseion_; | ||||
@@ -110,6 +121,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||||
std::vector<std::string> input_node_names_; | std::vector<std::string> input_node_names_; | ||||
std::vector<std::string> output_node_names_; | |||||
std::map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_; | std::map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_; | ||||
std::map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_; | std::map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_; | ||||
@@ -60,4 +60,9 @@ int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) { | |||||
return ge::DataType::DT_UNDEFINED; | return ge::DataType::DT_UNDEFINED; | ||||
} | } | ||||
} | } | ||||
void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | |||||
const std::string &parent_node_name, std::string &unique_subgraph_name) { | |||||
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -45,6 +45,7 @@ namespace ge { | |||||
const char *const kAttrNameValue = "value"; | const char *const kAttrNameValue = "value"; | ||||
const char *const kAttrNameInput = "input_tensor"; | const char *const kAttrNameInput = "input_tensor"; | ||||
const char *const kAttrNameIndex = "index"; | const char *const kAttrNameIndex = "index"; | ||||
const char *const kAttrNameIsSubgraphOp = "is_subgraph_op"; | |||||
const char *const kOpTypeConstant = "Constant"; | const char *const kOpTypeConstant = "Constant"; | ||||
const char *const kOpTypeInput = "Input"; | const char *const kOpTypeInput = "Input"; | ||||
@@ -52,6 +53,8 @@ class OnnxUtil { | |||||
public: | public: | ||||
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type); | ||||
static int64_t CaculateDataSize(int64_t onnx_data_type); | static int64_t CaculateDataSize(int64_t onnx_data_type); | ||||
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name, | |||||
const std::string &parent_node_name, std::string &unique_subgraph_name); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -0,0 +1,131 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "if_subgraph_adapter.h" | |||||
#include "subgraph_adapter_factory.h" | |||||
#include "common/util.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace ge{ | |||||
namespace { | |||||
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | |||||
const int kIfNodeAttrSize = 2; | |||||
} | |||||
Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
GE_CHECK_NOTNULL(parent_node); | |||||
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | |||||
parent_node->op_type().c_str()); | |||||
auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Parse][Node] Parse if node failed."); | |||||
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str()); | |||||
return ret; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status IfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
if (parent_node->attribute_size() != kIfNodeAttrSize) { | |||||
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||||
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | |||||
return FAILED; | |||||
} | |||||
GELOGD("node attribute size:%d.", parent_node->attribute_size()); | |||||
std::set<std::string> all_inputs; | |||||
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch | |||||
for (int i = 0; i < parent_node->attribute_size(); i++) { | |||||
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i); | |||||
GE_CHECK_NOTNULL(attribute); | |||||
std::string attr_name = attribute->name(); | |||||
auto itr = kAttrNameToIndex.find(attr_name); | |||||
if (itr == kAttrNameToIndex.end()) { | |||||
GELOGE(FAILED, "[Parse][Attribute] Invalid attribute name:%s, it should be then_branch or else_branch.", | |||||
attr_name.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Invalid attribute name:%s, it should be then_branch or else_branch.", | |||||
attr_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
std::string unique_subgraph_name; | |||||
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name); | |||||
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str()); | |||||
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g(); | |||||
name_to_onnx_graph[unique_subgraph_name] = onnx_graph; | |||||
onnx_graphs.emplace_back(onnx_graph); | |||||
auto ret = GetSubgraphsAllInputs(*onnx_graph, all_inputs); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Get][Inputs] Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str()); | |||||
REPORT_INNER_ERROR("E19999", "Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str()); | |||||
return ret; | |||||
} | |||||
} | |||||
for (auto &onnx_graph : onnx_graphs) { | |||||
AddInputNodeForGraph(all_inputs, *onnx_graph); | |||||
} | |||||
AddInputForParentNode(all_inputs, *parent_node); | |||||
return SUCCESS; | |||||
} | |||||
Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | |||||
std::set<std::string> &all_inputs) { | |||||
std::set<std::string> graph_inputs; | |||||
std::set<std::string> graph_outputs; | |||||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | |||||
for (int j = 0; j < node_proto->input_size(); j++) { | |||||
graph_inputs.emplace(node_proto->input(j)); | |||||
} | |||||
for (int j = 0; j < node_proto->output_size(); j++) { | |||||
graph_outputs.emplace(node_proto->output(j)); | |||||
} | |||||
} | |||||
for (const auto &input : graph_inputs) { | |||||
auto out_iter = graph_outputs.find(input); | |||||
if (out_iter == graph_outputs.end()) { | |||||
// Record input node need to be constructed | |||||
all_inputs.emplace(input); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void IfSubgraphAdapter::AddInputNodeForGraph(const std::set<std::string> &all_inputs, | |||||
ge::onnx::GraphProto &onnx_graph) { | |||||
for (const auto &input_name : all_inputs) { | |||||
ge::onnx::ValueInfoProto *value_info = onnx_graph.add_input(); | |||||
value_info->set_name(input_name); | |||||
} | |||||
} | |||||
void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_inputs, | |||||
ge::onnx::NodeProto &parent_node) { | |||||
for (const auto &input_name : all_inputs) { | |||||
parent_node.add_input(input_name); | |||||
} | |||||
} | |||||
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter); | |||||
} // namespace ge |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ | |||||
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ | |||||
#include <set> | |||||
#include <string> | |||||
#include "subgraph_adapter.h" | |||||
using ge::onnx::NodeProto; | |||||
namespace ge { | |||||
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | |||||
public: | |||||
Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override; | |||||
private: | |||||
Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||||
Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs); | |||||
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph); | |||||
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_ |
@@ -0,0 +1,61 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ | |||||
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ | |||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <map> | |||||
#include <vector> | |||||
#include "proto/onnx/ge_onnx.pb.h" | |||||
#include "external/register/register_error_codes.h" | |||||
#include "framework/omg/parser/parser_types.h" | |||||
#include "onnx_util.h" | |||||
using Status = domi::Status; | |||||
using namespace ge::parser; | |||||
namespace ge { | |||||
class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||||
public: | |||||
/// @brief parse params | |||||
/// @param [in/out] parent_op parent op | |||||
/// @param [in/out] onnx_graph_tasks onnx graph task | |||||
/// @param [in/out] name_to_onnx_graph map name to onnx graph | |||||
/// @return SUCCESS parse success | |||||
/// @return FAILED Parse failed | |||||
virtual Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||||
std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
return domi::SUCCESS; | |||||
} | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_ |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "subgraph_adapter_factory.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace ge{ | |||||
SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() { | |||||
static SubgraphAdapterFactory instance; | |||||
return &instance; | |||||
} | |||||
std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter( | |||||
const std::string &op_type) { | |||||
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter. | |||||
auto iter = subgraph_adapter_creator_map_.find(op_type); | |||||
if (iter != subgraph_adapter_creator_map_.end()) { | |||||
return iter->second(); | |||||
} | |||||
GELOGW("SubgraphAdapterFactory::CreateSubgraphAdapter: Not supported type: %s", op_type.c_str()); | |||||
return nullptr; | |||||
} | |||||
// This function is only called within the constructor of the global SubgraphAdapterRegisterar object, | |||||
// and does not involve concurrency, so there is no need to lock it | |||||
void SubgraphAdapterFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun) { | |||||
std::map<std::string, CREATOR_FUN> *subgraph_adapter_creator_map = &subgraph_adapter_creator_map_; | |||||
GELOGD("SubgraphAdapterFactory::RegisterCreator: op type:%s.", type.c_str()); | |||||
(*subgraph_adapter_creator_map)[type] = fun; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,119 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_ | |||||
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_ | |||||
#if defined(_MSC_VER) | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY _declspec(dllexport) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#else | |||||
#ifdef FUNC_VISIBILITY | |||||
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) | |||||
#else | |||||
#define PARSER_FUNC_VISIBILITY | |||||
#endif | |||||
#endif | |||||
#include <map> | |||||
#include <functional> | |||||
#include "subgraph_adapter.h" | |||||
namespace ge { | |||||
/** | |||||
* @brief Used to create OpParser | |||||
* | |||||
*/ | |||||
class PARSER_FUNC_VISIBILITY SubgraphAdapterFactory { | |||||
public: | |||||
/** | |||||
* @brief Returns the SubgraphAdapterFactory instance | |||||
* @return SubgraphAdapterFactory object | |||||
*/ | |||||
static SubgraphAdapterFactory* Instance(); | |||||
/** | |||||
* @brief Create SubgraphAdapter based on input type | |||||
* @param [in] op_type Op type | |||||
* @return Created SubgraphAdapter | |||||
*/ | |||||
std::shared_ptr<SubgraphAdapter> CreateSubgraphAdapter(const std::string &op_type); | |||||
protected: | |||||
/** | |||||
* @brief SubgraphAdapter creation function | |||||
* @return Created SubgraphAdapter | |||||
*/ | |||||
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void); | |||||
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>; | |||||
/** | |||||
* @brief Factory instances can only be created automatically, not new methods, so the constructor is not public. | |||||
*/ | |||||
SubgraphAdapterFactory() {} | |||||
/** | |||||
* @brief Register creation function | |||||
* @param [in] type Op type | |||||
* @param [in] fun OpParser creation function | |||||
*/ | |||||
void RegisterCreator(const std::string &type, CREATOR_FUN fun); | |||||
private: | |||||
std::map<std::string, CREATOR_FUN> subgraph_adapter_creator_map_; // lint !e1073 | |||||
friend class SubgraphAdapterRegisterar; | |||||
}; | |||||
/** | |||||
* @brief For registering Creator functions for different types of subgraph adapter | |||||
* | |||||
*/ | |||||
class PARSER_FUNC_VISIBILITY SubgraphAdapterRegisterar { | |||||
public: | |||||
/** | |||||
* @brief Constructor | |||||
* @param [in] op_type Op type | |||||
* @param [in] fun Creator function corresponding to Subgrap adapter | |||||
*/ | |||||
SubgraphAdapterRegisterar(const std::string &op_type, SubgraphAdapterFactory::CREATOR_FUN fun) { | |||||
SubgraphAdapterFactory::Instance()->RegisterCreator(op_type, fun); | |||||
} | |||||
~SubgraphAdapterRegisterar() {} | |||||
}; | |||||
/** | |||||
* @brief SubgraphAdapter Registration Macro | |||||
* @param [in] op_type Op type | |||||
* @param [in] clazz SubgraphAdapter implementation class | |||||
*/ | |||||
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \ | |||||
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \ | |||||
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \ | |||||
if (ptr == nullptr) { \ | |||||
GELOGW("MakeShared failed, result is nullptr."); \ | |||||
} \ | |||||
return std::shared_ptr<SubgraphAdapter>(ptr); \ | |||||
} \ | |||||
ge::SubgraphAdapterRegisterar g_##op_type##_Subgraph_Adapter_Creator(op_type, \ | |||||
Creator_##op_type##_Subgraph_Adapter) | |||||
} // namespace ge | |||||
#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_ |