Browse Source

onnx if lopp

pull/276/head
陈华 4 years ago
parent
commit
ae16b5707f
12 changed files with 775 additions and 29 deletions
  1. +4
    -1
      parser/onnx/CMakeLists.txt
  2. +13
    -1
      parser/onnx/onnx_data_parser.cc
  3. +6
    -0
      parser/onnx/onnx_data_parser.h
  4. +332
    -24
      parser/onnx/onnx_parser.cc
  5. +16
    -3
      parser/onnx/onnx_parser.h
  6. +5
    -0
      parser/onnx/onnx_util.cc
  7. +3
    -0
      parser/onnx/onnx_util.h
  8. +131
    -0
      parser/onnx/subgraph_adapter/if_subgraph_adapter.cc
  9. +40
    -0
      parser/onnx/subgraph_adapter/if_subgraph_adapter.h
  10. +61
    -0
      parser/onnx/subgraph_adapter/subgraph_adapter.h
  11. +45
    -0
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc
  12. +119
    -0
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.h

+ 4
- 1
parser/onnx/CMakeLists.txt View File

@@ -8,7 +8,9 @@ set(SRC_LIST
"onnx_parser.cc"
"onnx_data_parser.cc"
"onnx_util.cc"
"onnx_constant_parser.cc"
"onnx_constant_parser.cc"
"subgraph_adapter/if_subgraph_adapter.cc"
"subgraph_adapter/subgraph_adapter_factory.cc"
)

protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
@@ -31,6 +33,7 @@ target_compile_definitions(fmk_onnx_parser PRIVATE

target_include_directories(fmk_onnx_parser PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${CMAKE_CURRENT_LIST_DIR}/subgraph_adapter
${PARSER_DIR}
${PARSER_DIR}/inc
${PARSER_DIR}/parser


+ 13
- 1
parser/onnx/onnx_data_parser.cc View File

@@ -35,6 +35,11 @@ Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def)
GELOGE(FAILED, "parse shape of data op %s from model failed", op_def.GetName().c_str());
return FAILED;
}
// Subgraph data operator don't need parse input shape
// the shape mappings from parent node input
if (IsSubgraphDataOp()) {
return SUCCESS;
}

if (ParseInputFromUser(op_def) != SUCCESS) {
GELOGE(FAILED, "parse shape of data op %s from user failed", op_def.GetName().c_str());
@@ -72,15 +77,23 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &
// Get attr t:'input_tensor' form NodeProto
int64_t data_type = 1;
int64_t index = 0;
is_subgraph_data_op_ = false;
for (auto it : node->attribute()) {
if (it.name() == ge::kAttrNameInput) {
data_type = ParseInputTensor(it);
} else if (it.name() == ge::kAttrNameIndex) {
index = it.i();
GELOGI("The node has attribute with index: %ld", index);
} else if (it.name() == ge::kAttrNameIsSubgraphOp) {
is_subgraph_data_op_ = true;
}
}

op_def.SetAttr(ge::ATTR_NAME_INDEX, index);
if (IsSubgraphDataOp()) {
return SUCCESS;
}

// Trans onnx type to ge type
DataType type = OnnxUtil::ConvertOnnxDataType(data_type);
if (type == ge::DataType::DT_UNDEFINED) {
@@ -88,7 +101,6 @@ Status OnnxDataParser::ParseInputFromModel(const Message *op_src, ge::Operator &
return FAILED;
}
op_def.SetAttr(ge::DATA_ATTR_NAME_DATA_TYPE, static_cast<int64_t>(type));
op_def.SetAttr(ge::ATTR_NAME_INDEX, index);

return SUCCESS;
}


+ 6
- 0
parser/onnx/onnx_data_parser.h View File

@@ -32,11 +32,17 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser {

Status ParseInputFromUser(const ge::Operator &op_def);

bool IsSubgraphDataOp() {
return is_subgraph_data_op_;
}

int64_t ParseInputTensor(const ge::onnx::AttributeProto &attribute);

std::vector<int64_t> model_input_dims_v_;

std::vector<int64_t> user_input_dims_v_;

bool is_subgraph_data_op_;
};
} // namespace ge



+ 332
- 24
parser/onnx/onnx_parser.cc View File

@@ -17,6 +17,7 @@
#include "onnx_parser.h"
#include <algorithm>
#include <iostream>
#include <queue>
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
@@ -37,6 +38,9 @@
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"
#include "register/register_fmk_types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "subgraph_adapter/subgraph_adapter_factory.h"

namespace ge {
graphStatus PrepareBeforeParse(AclGrphParseUtil &acl_graph_parse_util,
@@ -95,7 +99,7 @@ graphStatus aclgrphParseONNX(const char *model_file,
}

GE_CHECK_NOTNULL(model_parser);
// parse caffe model_file to GE graph
// parse onnx model_file to GE graph
ge::graphStatus ret = model_parser->Parse(model_file, graph);
if (ret != ge::SUCCESS) {
GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str());
@@ -144,18 +148,130 @@ graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size,
namespace ge {
namespace {
const std::map<std::string, std::string> kOnnxOpMap = {
{ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT},
{ge::kOpTypeInput, ge::parser::DATA},
{ge::kOpTypeConstant, ge::parser::CONSTANT}
};
const char* const MATMULV2 = "MatMulV2";
const std::vector<std::string> kNoNeedUpdateFormat = {MATMULV2};
const int64_t kDimValue = 1;

struct ParseArg {
ge::onnx::GraphProto *onnx_graph;
ge::NodePtr parent_node;
std::string graph_name;
uint32_t subgraph_index;
};

Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque<ParseArg> &args) {
GELOGI("Gen subgraph parse tasks start");
for (auto &node : parent_graph->GetDirectNode()) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) {
auto i = subgraph_name_to_index.second;
auto subgraph_iname = subgraph_name_to_index.first;
if (subgraph_iname.empty()) {
GELOGW("The subgraph index %u of node %s is empty", i, node->GetName().c_str());
continue;
}

// change the graph name to ensure it is unique in GE
std::string unique_subgraph_name;
OnnxUtil::GenUniqueSubgraphName(i, subgraph_iname, node->GetName(), unique_subgraph_name);

GELOGD("Add subgraph parse task to the queue, node %s, index %u, subgraph instance name %s",
node->GetName().c_str(), i, unique_subgraph_name.c_str());
args.push_back({nullptr, node, unique_subgraph_name, i});
}
}
GELOGI("Gen subgraph parse tasks end");
return SUCCESS;
}

Status BuildLinkForChildAndParentGraph(const ge::ComputeGraphPtr &sub_graph, const ParseArg &arg) {
if (arg.parent_node == nullptr) {
return SUCCESS;
}
auto parent_node = arg.parent_node;
auto index = arg.subgraph_index;
auto ret = ge::NodeUtils::SetSubgraph(*parent_node, index, sub_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Set][Subgraph] Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(),
parent_node->GetName().c_str(), index);
REPORT_CALL_ERROR("E19999", "Failed to set subgraph %s to node %s index %u", sub_graph->GetName().c_str(),
parent_node->GetName().c_str(), index);
return ret;
}
return SUCCESS;
}

Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor) {
if (onnx_graph.input_size() == 0) {
Status PostOpProcessForSubgraph(const ParseArg &arg, ge::ComputeGraphPtr sub_graph) {
if (arg.parent_node == nullptr) {
return SUCCESS;
}
std::string op_type = arg.parent_node->GetType();
std::string op_name = arg.parent_node->GetName();
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
auto post_func =
domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type);
if (post_func == nullptr) {
GELOGW("The subgraph post func for node %s type %s is null", op_name.c_str(), op_type.c_str());
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(op_type, parse_func_v2) != SUCCESS || parse_func_v2 == nullptr) {
GELOGW("The subgraph post func v2 for node %s type %s is null", op_name.c_str(), op_type.c_str());
return SUCCESS;
}
}

GELOGD("Post process for subgraph %s node %s type %s", arg.graph_name.c_str(), arg.parent_node->GetName().c_str(),
arg.parent_node->GetType().c_str());

// Refresh node_name in subgraph
for (const ge::NodePtr &node : sub_graph->GetDirectNode()) {
if (node->GetOpDesc() == nullptr) {
continue;
}
node->GetOpDesc()->SetName(sub_graph->GetName() + "/" + node->GetName());
}

auto graph = ge::GraphUtils::CreateGraphFromComputeGraph(sub_graph);
Status ret = FAILED;
if (post_func != nullptr) {
ret = post_func(arg.graph_name, graph);
} else if (parse_func_v2 != nullptr) {
ret = parse_func_v2(arg.graph_name.c_str(), graph);
}
if (ret != SUCCESS) {
GELOGE(FAILED, "[PostProcess][Subgraph]Failed to post-process subgraph %s on node %s type %s",
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str());
REPORT_CALL_ERROR("E19999", "Failed to post-process subgraph %s on node %s type %s",
arg.graph_name.c_str(), arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str());
return FAILED;
}
return SUCCESS;
}
}

Status OnnxModelParser::ParseOutput(ge::onnx::GraphProto &onnx_graph) {
if (onnx_graph.output_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E16001");
GELOGE(FAILED, "[Parse][Output] Onnx graph:%s has zero output", onnx_graph.name().c_str());
REPORT_INNER_ERROR("E19999", "Onnx graph:%s has zero output", onnx_graph.name().c_str());
return FAILED;
}

// get output value info map
for (int i = 0; i < onnx_graph.output_size(); i++) {
ge::onnx::ValueInfoProto value_info = onnx_graph.output(i);
GELOGI("The index of %d output name : %s.", i, value_info.name().c_str());
output_node_names_.emplace_back(value_info.name());
}
return SUCCESS;
}

Status OnnxModelParser::ParseInput(const std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor,
bool is_subgraph, ge::onnx::GraphProto &onnx_graph) {
if (!is_subgraph && onnx_graph.input_size() == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E16001");
GELOGE(FAILED, "Onnx graph has zero input");
GELOGE(FAILED, "[Parse][Input] Root onnx graph:%s has zero input", onnx_graph.name().c_str());
REPORT_INNER_ERROR("E19999", "Root onnx graph:%s has zero input", onnx_graph.name().c_str());
return FAILED;
}

@@ -207,6 +323,11 @@ Status OnnxModelParser::ParseInput(ge::onnx::GraphProto &onnx_graph,
ge::onnx::AttributeProto *attribute_index = input_node->add_attribute();
attribute_index->set_name(ge::kAttrNameIndex);
attribute_index->set_i(data_index++);
// add subgraph attr
if (is_subgraph) {
attribute = input_node->add_attribute();
attribute->set_name(ge::kAttrNameIsSubgraphOp);
}
input_node_names_.emplace_back(value_info.name());
}
return SUCCESS;
@@ -319,7 +440,8 @@ Status OnnxModelParser::TransNodeToOperator(const ge::onnx::NodeProto *node_prot
op = ge::OperatorFactory::CreateOperator(node_name, op_type);
if (op.GetName() != node_name) {
ErrorManager::GetInstance().ATCReportErrMessage("E16003", {"opname", "optype"}, {node_name, op_type});
GELOGE(INTERNAL_ERROR, "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
GELOGE(INTERNAL_ERROR, "[Creat][Op] IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
REPORT_INNER_ERROR("E19999", "IR for op[%s] optype[%s] is not registered.", node_name.c_str(), op_type.c_str());
return INTERNAL_ERROR;
}

@@ -428,7 +550,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
std::string op_type;
Status status = AdapterOpType(node_proto, ori_type, op_type);
if (status != SUCCESS) {
GELOGE(status, "Adapter op type for ori type %s failed.", ori_type.c_str());
GELOGE(status, "[Adapt][OpType] Adapter op type for ori type %s failed.", ori_type.c_str());
REPORT_CALL_ERROR("E19999", "Adapter op type for ori type %s failed.", ori_type.c_str());
return status;
}
node_proto->set_op_type(ori_type);
@@ -438,7 +561,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
ge::Operator op;
status = TransNodeToOperator(node_proto, op, op_type);
if (status != SUCCESS) {
GELOGE(status, "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
GELOGE(status, "[Trans][Node] Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
REPORT_CALL_ERROR("E19999", "Trans node to operator for %s:%s failed.", node_name.c_str(), op_type.c_str());
return status;
}

@@ -455,9 +579,14 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
return status;
}

GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(),
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize());


ge::graphStatus graph_status = graph.AddOp(op);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "Add op:%s to graph failed.", op.GetName().c_str());
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str());
REPORT_CALL_ERROR("E19999", "Add op:%s to graph failed.", op.GetName().c_str());
return FAILED;
}
name_operator_[op.GetName()] = op;
@@ -473,18 +602,54 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
return SUCCESS;
}

Status OnnxModelParser::GetGraphInputs(std::vector<ge::Operator> &input_ops) {
Status OnnxModelParser::GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops) {
if (input_node_names_.empty()) {
// subgraph might not have input, we use constant nodes as the start nodes of graph
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node = onnx_graph.mutable_node(i);
if (node->op_type() == kOpTypeConstant) {
input_node_names_.emplace_back(node->name());
}
}
}
for (auto in_name : input_node_names_) {
auto in_op = name_operator_.find(in_name);
if (in_op == name_operator_.end()) {
GELOGE(PARAM_INVALID, "Model assigned output node name: %s can not find in graph.",
GELOGE(PARAM_INVALID, "[Get][Inputs] Model assigned input node name: %s can not find in graph.",
in_name.c_str());
REPORT_INNER_ERROR("E19999", "Model assigned input node name: %s can not find in graph.",
in_name.c_str());
return PARAM_INVALID;
}
input_ops.emplace_back(in_op->second);
GELOGI("Model assigned input node name: %s", in_op->second.GetName().c_str());
}
return SUCCESS;
}

Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &output_ops) {
for (auto output_name : output_node_names_) {
auto itr = outputs_map_.find(output_name);
if (itr == outputs_map_.end()) {
GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str());
REPORT_INNER_ERROR( "E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str());
return PARAM_INVALID;
}

std::vector<std::pair<std::string, int>> node_names_indexes = itr->second;
for (const auto &node_name_index : node_names_indexes) {
auto node_name = node_name_index.first;
auto out_op_itr = name_operator_.find(node_name);
if (out_op_itr == name_operator_.end()) {
GELOGE(PARAM_INVALID, "[Get][Operator] Can not find operator: %s in graph.", node_name.c_str());
REPORT_INNER_ERROR("E19999", "Can not find operator: %s in graph.", node_name.c_str());
return PARAM_INVALID;
}
int index = node_name_index.second;
output_ops.emplace_back(out_op_itr->second, vector<size_t>{static_cast<size_t>(index)});
GELOGI("out node index %d, node:%s", index, node_name.c_str());
}
}
return SUCCESS;
}

@@ -515,19 +680,146 @@ Status OnnxModelParser::GetModelFromMemory(const char *data, uint32_t size, ge::
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph) {
void OnnxModelParser::ClearMembers() {
name_operator_.clear();
input_node_names_.clear();
output_node_names_.clear();
inputs_map_.clear();
outputs_map_.clear();
}

Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
std::queue<ge::onnx::GraphProto *> onnx_graph_tasks;
int index = 0;
onnx_graph_tasks.push(&root_onnx_graph);

while (!onnx_graph_tasks.empty()) {
ge::onnx::GraphProto *onnx_graph = onnx_graph_tasks.front();
onnx_graph_tasks.pop();
for (int i = 0; i < onnx_graph->node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph->mutable_node(i);
if (node_proto->name().empty()) {
std::string node_name = node_proto->op_type() + "_" + to_string(index++);
node_proto->set_name(node_name);
}
GELOGD("adapt op name:%s, op type:%s", node_proto->name().c_str(), node_proto->op_type().c_str());

SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance();
GE_CHECK_NOTNULL(factory);
std::shared_ptr<SubgraphAdapter> subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type());
if(subgraph_adapter == nullptr) {
GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str());
continue;
}
std::vector<ge::onnx::GraphProto *> onnx_graphs;
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_subgraph;
if (subgraph_adapter->AdaptAndFindAllSubgraphs(node_proto, onnx_graphs, name_to_onnx_subgraph) != SUCCESS) {
GELOGE(FAILED, "[Adapt][Subgraph] adapt subgraph of node:%s failed.", node_proto->name().c_str());
REPORT_INNER_ERROR("E19999", "adapt subgraph of node:%s failed.", node_proto->name().c_str());
return FAILED;
}

for (const auto &onnx_graph : onnx_graphs) {
onnx_graph_tasks.push(onnx_graph);
}
for (const auto &itr : name_to_onnx_subgraph) {
name_to_onnx_graph.emplace(itr.first, itr.second);
}
}
}
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &root_graph) {
if (!onnx_model.has_graph()) {
ErrorManager::GetInstance().ATCReportErrMessage("E16004");
GELOGE(PARAM_INVALID, "Onnx model do not has graph.");
return FAILED;
}
ge::onnx::GraphProto onnx_graph = onnx_model.graph();
std::map<std::string, ge::onnx::GraphProto *> name_to_onnx_graph;
std::deque<ParseArg> tasks;
ge::onnx::GraphProto root_onnx_graph = onnx_model.graph();

auto ret = AdaptAndFindAllOnnxGraph(root_onnx_graph, name_to_onnx_graph);
if (ret != SUCCESS) {
GELOGE(FAILED, "[AdaptAndFind][OnnxGraph]adapt and find all onnx graph failed, root graph:%s.",
root_onnx_graph.name().c_str());
return FAILED;
}

auto opset_import = onnx_model.opset_import();
for (auto it : opset_import) {
domain_verseion_[it.domain()] = it.version();
GELOGI("Domain: %s, Version: %ld ", it.domain().c_str(), it.version());
}
std::string root_graph_name = root_graph.GetName().empty() ? "default_graph" : root_graph.GetName();
tasks.push_back({&root_onnx_graph, nullptr, root_graph_name, 0});

while (!tasks.empty()) {
ParseArg arg = tasks.front();
tasks.pop_front();
bool is_subgraph = (arg.parent_node != nullptr) ? true : false;

if (arg.onnx_graph == nullptr) {
auto itr = name_to_onnx_graph.find(arg.graph_name);
if (itr == name_to_onnx_graph.end()) {
GELOGE(FAILED, "[Find][OnnxGraph] Can not find onnx graph, graph:%s.", arg.graph_name.c_str());
REPORT_INNER_ERROR("E19999", "Can not find onnx graph, graph:%s.", arg.graph_name.c_str());
return FAILED;
}
arg.onnx_graph = itr->second;
}

ge::onnx::GraphProto *onnx_graph = arg.onnx_graph;
ge::Graph tmp_graph(arg.graph_name);
ret = ModelParseToGraphImpl(is_subgraph, *onnx_graph, tmp_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Model] Model parse to graph failed, graph name:%s.", arg.graph_name.c_str());
REPORT_INNER_ERROR("E19999", "Model parse to graph failed, graph name:%s.", arg.graph_name.c_str());
return ret;
}
// To get the result for root graph
if (!is_subgraph) {
root_graph = tmp_graph;
}

ge::ComputeGraphPtr cur_compute_graph = ge::GraphUtils::GetComputeGraph(tmp_graph);
GE_CHECK_NOTNULL(cur_compute_graph);

ret = PostOpProcessForSubgraph(arg, cur_compute_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[PostProcess][Subgraph]Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Post Op for subgraph:%s failed.", cur_compute_graph->GetName().c_str());
return ret;
}

ret = BuildLinkForChildAndParentGraph(cur_compute_graph, arg);
if (ret != SUCCESS) {
GELOGE(ret, "[BuildLink][Graph] Build link for child graph:%s and parent graph failed.",
cur_compute_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Build link for child graph:%s and parent graph failed.",
cur_compute_graph->GetName().c_str());
return ret;
}

ret = GenSubgraphParseTasks(cur_compute_graph, tasks);
if (ret != SUCCESS) {
GELOGE(ret, "[Generate][Task] Failed to gen tasks on graph %s for next iteration",
cur_compute_graph->GetName().c_str());
REPORT_CALL_ERROR("E19999", "Failed to gen tasks on graph %s for next iteration",
cur_compute_graph->GetName().c_str());
return ret;
}

}
UpdateDataFormat(root_graph);
return SUCCESS;
}

Status OnnxModelParser::ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {

ClearMembers();

// 2. Get all inializer.
std::map<std::string, ge::onnx::TensorProto> initializer_name_tensor;
@@ -541,7 +833,8 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model

// 3. Parse Input from graph.
GELOGI("The size of initializer_name_tensor is %zu ", initializer_name_tensor.size());
Status ret = ParseInput(onnx_graph, initializer_name_tensor);

Status ret = ParseInput(initializer_name_tensor, is_subgraph, onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "Parse input for onnx failed.");
return ret;
@@ -555,6 +848,12 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

ret = ParseOutput(onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Output] Parse output for onnx failed.");
return ret;
}

// 5. Update node name for node do not has name.
ret = UpdateAllNodeName(onnx_graph);
if (ret != SUCCESS) {
@@ -582,6 +881,10 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

std::vector<string> op_names;
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 8. Set all operator input.
ret = SetOperatorInputs();
if (ret != SUCCESS) {
@@ -589,22 +892,27 @@ Status OnnxModelParser::ModelParseToGraph(const ge::onnx::ModelProto &onnx_model
return ret;
}

std::vector<string> op_names;
graph.GetAllOpName(op_names);
GELOGI("After trans node to operator, graph has the size of operator is %zu.", op_names.size());

// 9. Construct graph.
std::vector<ge::Operator> input_ops;

ret = GetGraphInputs(input_ops);
ret = GetGraphInputs(onnx_graph, input_ops);
if (ret != SUCCESS) {
GELOGE(ret, "Get graph inputs failed.");
return ret;
}
graph.SetInputs(input_ops);

// root graph needn't set outputs.
if(is_subgraph) {
std::vector<std::pair<Operator, std::vector<size_t>>> output_ops;
ret = GetGraphOutputs(output_ops);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][Outputs]Get graph outputs failed.");
return ret;
}
graph.SetOutputs(output_ops);
}

GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph));
UpdateDataFormat(graph);

GELOGI("Onnx model parser success.");
return SUCCESS;


+ 16
- 3
parser/onnx/onnx_parser.h View File

@@ -72,8 +72,10 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {
private:
Status ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);

Status ParseInput(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);
Status ParseInput(const std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor,
bool is_subgraph, ge::onnx::GraphProto &onnx_graph);

Status ParseOutput(ge::onnx::GraphProto &onnx_graph);

Status ParseInitializer(ge::onnx::GraphProto &onnx_graph,
std::map<std::string, ge::onnx::TensorProto> &initializer_name_tensor);
@@ -90,7 +92,9 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status SetOperatorInputs();

Status GetGraphInputs(std::vector<ge::Operator> &input_ops);
Status GetGraphInputs(ge::onnx::GraphProto &onnx_graph, std::vector<ge::Operator> &input_ops);

Status GetGraphOutputs(std::vector<std::pair<Operator, std::vector<size_t>>> &outputs);

Status Prechecker(ge::onnx::GraphProto &onnx_graph);
@@ -100,8 +104,15 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

Status ModelParseToGraph(const ge::onnx::ModelProto &onnx_model, ge::Graph &graph);

Status ModelParseToGraphImpl(bool is_subgraph, ge::onnx::GraphProto &onnx_graph, ge::Graph &graph);

void UpdateDataFormat(ge::Graph &graph);

void ClearMembers();

Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);

std::map<std::string, std::string> ori_to_om_type_;

std::map<std::string, int64_t> domain_verseion_;
@@ -110,6 +121,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

std::vector<std::string> input_node_names_;

std::vector<std::string> output_node_names_;

std::map<std::string, std::vector<std::pair<std::string, int>>> inputs_map_;

std::map<std::string, std::vector<std::pair<std::string, int>>> outputs_map_;


+ 5
- 0
parser/onnx/onnx_util.cc View File

@@ -60,4 +60,9 @@ int64_t OnnxUtil::CaculateDataSize(int64_t onnx_data_type) {
return ge::DataType::DT_UNDEFINED;
}
}

void OnnxUtil::GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name) {
unique_subgraph_name = parent_node_name + "_" + std::to_string(subgraph_index) + "_" + original_subgraph_name;
}
} // namespace ge

+ 3
- 0
parser/onnx/onnx_util.h View File

@@ -45,6 +45,7 @@ namespace ge {
const char *const kAttrNameValue = "value";
const char *const kAttrNameInput = "input_tensor";
const char *const kAttrNameIndex = "index";
const char *const kAttrNameIsSubgraphOp = "is_subgraph_op";
const char *const kOpTypeConstant = "Constant";
const char *const kOpTypeInput = "Input";

@@ -52,6 +53,8 @@ class OnnxUtil {
public:
static ge::DataType ConvertOnnxDataType(int64_t onnx_data_type);
static int64_t CaculateDataSize(int64_t onnx_data_type);
static void GenUniqueSubgraphName(int subgraph_index, const std::string &original_subgraph_name,
const std::string &parent_node_name, std::string &unique_subgraph_name);
};
} // namespace ge



+ 131
- 0
parser/onnx/subgraph_adapter/if_subgraph_adapter.cc View File

@@ -0,0 +1,131 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "if_subgraph_adapter.h"
#include "subgraph_adapter_factory.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"

namespace ge{
namespace {
const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}};
const int kIfNodeAttrSize = 2;
}
Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node,
std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
GE_CHECK_NOTNULL(parent_node);
GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(),
parent_node->op_type().c_str());

auto ret = ParseIfNodeSubgraphs(parent_node, onnx_graphs, name_to_onnx_graph);
if (ret != SUCCESS) {
GELOGE(ret, "[Parse][Node] Parse if node failed.");
REPORT_CALL_ERROR("E19999", "[Parse][Node] Parse if node:%s failed.", parent_node->name().c_str());
return ret;
}

return SUCCESS;
}

Status IfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node,
std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
if (parent_node->attribute_size() != kIfNodeAttrSize) {
GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size());
return FAILED;
}

GELOGD("node attribute size:%d.", parent_node->attribute_size());
std::set<std::string> all_inputs;
// for onnx graph, the first attribute may be else branch and the second attribute may be then branch
for (int i = 0; i < parent_node->attribute_size(); i++) {
ge::onnx::AttributeProto *attribute = parent_node->mutable_attribute(i);
GE_CHECK_NOTNULL(attribute);
std::string attr_name = attribute->name();
auto itr = kAttrNameToIndex.find(attr_name);
if (itr == kAttrNameToIndex.end()) {
GELOGE(FAILED, "[Parse][Attribute] Invalid attribute name:%s, it should be then_branch or else_branch.",
attr_name.c_str());
REPORT_INNER_ERROR("E19999", "Invalid attribute name:%s, it should be then_branch or else_branch.",
attr_name.c_str());
return FAILED;
}
std::string unique_subgraph_name;
OnnxUtil::GenUniqueSubgraphName(itr->second, itr->first, parent_node->name(), unique_subgraph_name);
GELOGI("Adapt if node attribute:%s, subgraph name:%s.", attr_name.c_str(), unique_subgraph_name.c_str());
ge::onnx::GraphProto *onnx_graph = attribute->mutable_g();
name_to_onnx_graph[unique_subgraph_name] = onnx_graph;
onnx_graphs.emplace_back(onnx_graph);

auto ret = GetSubgraphsAllInputs(*onnx_graph, all_inputs);
if (ret != SUCCESS) {
GELOGE(ret, "[Get][Inputs] Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str());
REPORT_INNER_ERROR("E19999", "Get subgraph all inputs failed, attr_name:%s.", attr_name.c_str());
return ret;
}
}

for (auto &onnx_graph : onnx_graphs) {
AddInputNodeForGraph(all_inputs, *onnx_graph);
}

AddInputForParentNode(all_inputs, *parent_node);
return SUCCESS;
}

Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph,
std::set<std::string> &all_inputs) {
std::set<std::string> graph_inputs;
std::set<std::string> graph_outputs;
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
for (int j = 0; j < node_proto->input_size(); j++) {
graph_inputs.emplace(node_proto->input(j));
}
for (int j = 0; j < node_proto->output_size(); j++) {
graph_outputs.emplace(node_proto->output(j));
}
}

for (const auto &input : graph_inputs) {
auto out_iter = graph_outputs.find(input);
if (out_iter == graph_outputs.end()) {
// Record input node need to be constructed
all_inputs.emplace(input);
}
}

return SUCCESS;
}

void IfSubgraphAdapter::AddInputNodeForGraph(const std::set<std::string> &all_inputs,
ge::onnx::GraphProto &onnx_graph) {
for (const auto &input_name : all_inputs) {
ge::onnx::ValueInfoProto *value_info = onnx_graph.add_input();
value_info->set_name(input_name);
}
}

void IfSubgraphAdapter::AddInputForParentNode(const std::set<std::string> &all_inputs,
ge::onnx::NodeProto &parent_node) {
for (const auto &input_name : all_inputs) {
parent_node.add_input(input_name);
}
}
REGISTER_SUBGRAPH_ADAPTER_CREATOR(IF, IfSubgraphAdapter);
} // namespace ge

+ 40
- 0
parser/onnx/subgraph_adapter/if_subgraph_adapter.h View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_

#include <set>
#include <string>
#include "subgraph_adapter.h"

using ge::onnx::NodeProto;

namespace ge {
class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter {
public:
Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override;
private:
Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);
Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs);
void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph);
void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node);
};
} // namespace ge

#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_IF_SUBGRAPH_ADAPTER_H_

+ 61
- 0
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY _declspec(dllexport)
#else
#define PARSER_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define PARSER_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <vector>
#include "proto/onnx/ge_onnx.pb.h"
#include "external/register/register_error_codes.h"
#include "framework/omg/parser/parser_types.h"
#include "onnx_util.h"

using Status = domi::Status;
using namespace ge::parser;

namespace ge {
class PARSER_FUNC_VISIBILITY SubgraphAdapter {
public:
/// @brief parse params
/// @param [in/out] parent_op parent op
/// @param [in/out] onnx_graph_tasks onnx graph task
/// @param [in/out] name_to_onnx_graph map name to onnx graph
/// @return SUCCESS parse success
/// @return FAILED Parse failed
virtual Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op,
std::vector<ge::onnx::GraphProto *> &onnx_graphs,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) {
return domi::SUCCESS;
}
};
} // namespace ge

#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_H_

+ 45
- 0
parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "subgraph_adapter_factory.h"
#include "framework/common/debug/ge_log.h"

namespace ge{
SubgraphAdapterFactory* SubgraphAdapterFactory::Instance() {
static SubgraphAdapterFactory instance;
return &instance;
}

std::shared_ptr<SubgraphAdapter> SubgraphAdapterFactory::CreateSubgraphAdapter(
const std::string &op_type) {
// First look for CREATOR_FUN based on OpType, then call CREATOR_FUN to create SubgraphAdapter.
auto iter = subgraph_adapter_creator_map_.find(op_type);
if (iter != subgraph_adapter_creator_map_.end()) {
return iter->second();
}

GELOGW("SubgraphAdapterFactory::CreateSubgraphAdapter: Not supported type: %s", op_type.c_str());
return nullptr;
}

// This function is only called within the constructor of the global SubgraphAdapterRegisterar object,
// and does not involve concurrency, so there is no need to lock it
void SubgraphAdapterFactory::RegisterCreator(const std::string &type, CREATOR_FUN fun) {
std::map<std::string, CREATOR_FUN> *subgraph_adapter_creator_map = &subgraph_adapter_creator_map_;
GELOGD("SubgraphAdapterFactory::RegisterCreator: op type:%s.", type.c_str());
(*subgraph_adapter_creator_map)[type] = fun;
}
} // namespace ge

+ 119
- 0
parser/onnx/subgraph_adapter/subgraph_adapter_factory.h View File

@@ -0,0 +1,119 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_
#define GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_

#if defined(_MSC_VER)
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY _declspec(dllexport)
#else
#define PARSER_FUNC_VISIBILITY
#endif
#else
#ifdef FUNC_VISIBILITY
#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default")))
#else
#define PARSER_FUNC_VISIBILITY
#endif
#endif

#include <map>
#include <functional>
#include "subgraph_adapter.h"

namespace ge {
/**
* @brief Used to create OpParser
*
*/
class PARSER_FUNC_VISIBILITY SubgraphAdapterFactory {
public:
/**
* @brief Returns the SubgraphAdapterFactory instance
* @return SubgraphAdapterFactory object
*/
static SubgraphAdapterFactory* Instance();

/**
* @brief Create SubgraphAdapter based on input type
* @param [in] op_type Op type
* @return Created SubgraphAdapter
*/
std::shared_ptr<SubgraphAdapter> CreateSubgraphAdapter(const std::string &op_type);


protected:
/**
* @brief SubgraphAdapter creation function
* @return Created SubgraphAdapter
*/
// typedef shared_ptr<SubgraphAdapter> (*CREATOR_FUN)(void);
using CREATOR_FUN = std::function<std::shared_ptr<SubgraphAdapter>(void)>;

/**
* @brief Factory instances can only be created automatically, not new methods, so the constructor is not public.
*/
SubgraphAdapterFactory() {}

/**
* @brief Register creation function
* @param [in] type Op type
* @param [in] fun OpParser creation function
*/
void RegisterCreator(const std::string &type, CREATOR_FUN fun);

private:
std::map<std::string, CREATOR_FUN> subgraph_adapter_creator_map_; // lint !e1073

friend class SubgraphAdapterRegisterar;
};

/**
* @brief For registering Creator functions for different types of subgraph adapter
*
*/
class PARSER_FUNC_VISIBILITY SubgraphAdapterRegisterar {
public:
/**
* @brief Constructor
* @param [in] op_type Op type
* @param [in] fun Creator function corresponding to Subgrap adapter
*/
SubgraphAdapterRegisterar(const std::string &op_type, SubgraphAdapterFactory::CREATOR_FUN fun) {
SubgraphAdapterFactory::Instance()->RegisterCreator(op_type, fun);
}
~SubgraphAdapterRegisterar() {}
};

/**
* @brief SubgraphAdapter Registration Macro
* @param [in] op_type Op type
* @param [in] clazz SubgraphAdapter implementation class
*/
#define REGISTER_SUBGRAPH_ADAPTER_CREATOR(op_type, clazz) \
std::shared_ptr<SubgraphAdapter> Creator_##op_type##_Subgraph_Adapter() { \
std::shared_ptr<clazz> ptr(new (std::nothrow) clazz()); \
if (ptr == nullptr) { \
GELOGW("MakeShared failed, result is nullptr."); \
} \
return std::shared_ptr<SubgraphAdapter>(ptr); \
} \
ge::SubgraphAdapterRegisterar g_##op_type##_Subgraph_Adapter_Creator(op_type, \
Creator_##op_type##_Subgraph_Adapter)
} // namespace ge

#endif // GE_PARSER_ONNX_SUBGRAPH_ADAPTER_SUBGRAPH_ADAPTER_FACTORY_H_

Loading…
Cancel
Save