Compare commits

...

15 Commits

Author SHA1 Message Date
  i-robot 6c69b97e86 !339 fixed coverity warning 3 years ago
  i-robot 9462b9675d !342 update submodule 3 years ago
  wuweikang 14d1a77ddf update submodule 3 years ago
  李磊 596b630a4d fixed coverity warning 3 years ago
  i-robot b59a36e241 !338 fix coverity 3 years ago
  i-robot b79ef8ad19 !330 update owners 4 years ago
  i-robot aec3c227ca !328 custom op register 4 years ago
  i-robot c074dfa596 !329 update protobuf to 3.13.0 4 years ago
  李磊 57505df4ab update version of protobuf to v3.13.0 4 years ago
  wqtshg 59ac22dfe4 update owners 4 years ago
  i-robot 978cf3a0df !325 update submodule file 4 years ago
  王涛 4c82774e0c update .gitmodules. 4 years ago
  wjm 29a321b404 fix 4 years ago
  wjm 6c9441a473 fix 4 years ago
  wjm 6e3fa785dc custom op register 4 years ago
22 changed files with 357 additions and 201 deletions
Split View
  1. +1
    -1
      .gitmodules
  2. +6
    -2
      OWNERS
  3. +1
    -1
      cmake/external_libs/protobuf_static.cmake
  4. +1
    -1
      metadef
  5. +2
    -143
      parser/caffe/caffe_parser.cc
  6. +0
    -40
      parser/caffe/caffe_parser.h
  7. +1
    -0
      parser/common/CMakeLists.txt
  8. +170
    -0
      parser/common/convert/message2operator.cc
  9. +38
    -0
      parser/common/convert/message2operator.h
  10. +1
    -1
      parser/func_to_graph/func2graph.py
  11. +19
    -4
      parser/onnx/onnx_custom_parser_adapter.cc
  12. +2
    -0
      parser/onnx/onnx_custom_parser_adapter.h
  13. +1
    -1
      parser/onnx/onnx_data_parser.h
  14. +37
    -5
      parser/onnx/onnx_parser.cc
  15. +2
    -0
      parser/onnx/onnx_parser.h
  16. +1
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter_factory.h
  17. +1
    -0
      parser/tensorflow/graph_optimizer.cc
  18. +1
    -1
      parser/tensorflow/tensorflow_frameworkop_parser.cc
  19. +2
    -0
      tests/ut/parser/CMakeLists.txt
  20. +58
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc
  21. +7
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py
  22. +5
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc

+ 1
- 1
.gitmodules View File

@@ -1,4 +1,4 @@
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = master
branch = r1.5.0

+ 6
- 2
OWNERS View File

@@ -1,7 +1,11 @@
approvers:
- ji_chen
- wqtshg
- ljl0711
- startzgf168
- lbisdaddy
- liyihan123
reviewers:
- xchu42
- sheng-nan
- wqtshg
- wangxiaotian22
- zhangxiaokun9

+ 1
- 1
cmake/external_libs/protobuf_static.cmake View File

@@ -15,7 +15,7 @@ else()
set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.13.0.tar.gz")
set(MD5 "f4489cb88922ad9c58cbe3308d59cee5")
else()
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz")
set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.13.0.tar.gz")
set(MD5 "1a6274bc4a65b55a6fa70e264d796490")
endif ()
endif()


+ 1
- 1
metadef

@@ -1 +1 @@
Subproject commit c6030152c6dc05515115765babb5d64fde649df4
Subproject commit 3ace5b6f10e0af784a1c3211fd769d6e8860e864

+ 2
- 143
parser/caffe/caffe_parser.cc View File

@@ -21,6 +21,7 @@
#include <sstream>
#include <memory>
#include <algorithm>
#include "common/convert/message2operator.h"
#include "parser/common/convert/pb2json.h"
#include "parser/common/acl_graph_parser_util.h"
#include "common/op_map.h"
@@ -202,11 +203,9 @@ const int32_t kAnchorIndexTwo = 2;
const int32_t kAnchorIndexThree = 3;
const int32_t kNumOne = 1;
const size_t kTensorNum = 2;
const int kMaxParseDepth = 5;
const int32_t kMinLineWorldSize = 3;
const int32_t kMaxIdentifier = 536870911; // 2^29 - 1
const int32_t kBase = 10;
const uint32_t kInteval = 2;
const char *const kPython = "Python";
const char *const kProposalLayer = "ProposalLayer";
const char *const kDetectionOutput = "DetectionOutput";
@@ -578,7 +577,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co
return FAILED;
}

if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) {
if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str());
return FAILED;
}
@@ -589,146 +588,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co
return SUCCESS;
}

Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return FAILED;
}

const google::protobuf::Reflection *reflection = message->GetReflection();
GE_CHECK_NOTNULL(reflection);
vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);

for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
}
}
return SUCCESS;
}

Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field,
int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field);
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str());
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str());
return FAILED;
}
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse field: %s success.", field->name().c_str());
return SUCCESS;
}

Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth,
ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
int field_size = reflection->FieldSize(*message, field);
if (field_size <= 0) {
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str());
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}

switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(),
message_json[field->name()], false);
std::string repeated_message_str;
try {
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what());
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string.");
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string.");
return FAILED;
}
(void)ops.SetAttr(field->name(), repeated_message_str);
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse repeated field: %s success.", field->name().c_str());
return SUCCESS;
}

void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) {
auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name);
if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) {


+ 0
- 40
parser/caffe/caffe_parser.h View File

@@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
*/
Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message,
int index, std::vector<ge::Operator> &operators);

/*
* @ingroup domi_omg
* @brief Parse message and set operator attrs
* @param [in] message, message of model
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse message successfully
* @return FAILED parse message failed
*/
Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/*
* @ingroup domi_omg
* @brief Parse repeated field and set operator attrs
* @param [in] reflection, reflection of message
* @param [in] message, message of model
* @param [in] field, field of message
* @param [in/out] depth, depth of recursion
* @param [out] ops, operator saving custom info by vector
* @return SUCCESS parse field successfully
* @return FAILED parse field failed
*/
Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

/**
* @ingroup domi_omg
* @brief Add blob information to the bottom_blobs_map and top_blobs_map_


+ 1
- 0
parser/common/CMakeLists.txt View File

@@ -15,6 +15,7 @@ set(SRC_LIST
"../tensorflow/tensorflow_fusion_op_parser.cc"
"../tensorflow/tensorflow_util.cc"
"convert/pb2json.cc"
"convert/message2operator.cc"
"op_def/ir_pb_converter.cc"
"op_def/defs.cc"
"op_def/op_schema.cc"


+ 170
- 0
parser/common/convert/message2operator.cc View File

@@ -0,0 +1,170 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "message2operator.h"

#include <vector>

#include "common/convert/pb2json.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
namespace {
const int kMaxParseDepth = 5;
const uint32_t kInteval = 2;
} // namespace

Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) {
GE_CHECK_NOTNULL(message);
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return FAILED;
}

const google::protobuf::Reflection *reflection = message->GetReflection();
GE_CHECK_NOTNULL(reflection);
std::vector<const google::protobuf::FieldDescriptor *> field_desc;
reflection->ListFields(*message, &field_desc);

for (auto &field : field_desc) {
GE_CHECK_NOTNULL(field);
if (field->is_repeated()) {
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str());
return FAILED;
}
} else {
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str());
return FAILED;
}
}
}
return SUCCESS;
}

Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
switch (field->cpp_type()) {
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
valuetype value = reflection->Get##method(*message, field); \
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \
(void)ops.SetAttr(field->name(), value); \
break; \
}
CASE_FIELD_TYPE(INT32, Int32, int32_t, d);
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u);
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld);
CASE_FIELD_TYPE(FLOAT, Float, float, f);
CASE_FIELD_TYPE(BOOL, Bool, bool, d);
#undef CASE_FIELD_TYPE
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: {
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field));
int value = reflection->GetEnum(*message, field)->number();
GELOGD("Parse result(%s : %d)", field->name().c_str(), value);
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: {
string value = reflection->GetString(*message, field);
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str());
(void)ops.SetAttr(field->name(), value);
break;
}
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field);
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) {
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str());
return FAILED;
}
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse field: %s success.", field->name().c_str());
return SUCCESS;
}

Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth,
ge::Operator &ops) {
GELOGD("Start to parse field: %s.", field->name().c_str());
int field_size = reflection->FieldSize(*message, field);
if (field_size <= 0) {
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str());
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str());
return FAILED;
}

switch (field->cpp_type()) {
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \
std::vector<valuetype> attr_value; \
for (int i = 0; i < field_size; i++) { \
valuetype value = reflection->GetRepeated##method(*message, field, i); \
attr_value.push_back(value); \
} \
(void)ops.SetAttr(field->name(), attr_value); \
break; \
}
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t);
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t);
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t);
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float);
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool);
CASE_FIELD_TYPE_REPEATED(STRING, String, string);
#undef CASE_FIELD_TYPE_REPEATED
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: {
nlohmann::json message_json;
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()],
false);
std::string repeated_message_str;
try {
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore);
} catch (std::exception &e) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what());
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what());
return FAILED;
} catch (...) {
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string.");
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string.");
return FAILED;
}
(void)ops.SetAttr(field->name(), repeated_message_str);
break;
}
default: {
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}),
std::vector<std::string>({"model", field->name(), "Unsupported field type"}));
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str());
return FAILED;
}
}
GELOGD("Parse repeated field: %s success.", field->name().c_str());
return SUCCESS;
}
} // namespace ge

+ 38
- 0
parser/common/convert/message2operator.h View File

@@ -0,0 +1,38 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PARSER_MESSAGE2OPERATOR_H
#define PARSER_MESSAGE2OPERATOR_H

#include "external/ge/ge_api_error_codes.h"
#include "external/graph/operator.h"
#include "google/protobuf/message.h"

namespace ge {
class Message2Operator {
public:
static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops);

private:
static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);

static Status ParseRepeatedField(const google::protobuf::Reflection *reflection,
const google::protobuf::Message *message,
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops);
};
} // namespace ge
#endif // PARSER_MESSAGE2OPERATOR_H

+ 1
- 1
parser/func_to_graph/func2graph.py View File

@@ -1,4 +1,4 @@
#!/usr/bin/python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#-------------------------------------------------------------------
# Purpose:


+ 19
- 4
parser/onnx/onnx_custom_parser_adapter.cc View File

@@ -15,23 +15,25 @@
*/

#include "parser/onnx/onnx_custom_parser_adapter.h"

#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "parser/common/op_parser_factory.h"
#include "register/op_registry.h"

using domi::ParseParamFunc;
using domi::ONNX;
using domi::ParseParamByOpFunc;
using domi::ParseParamFunc;

namespace ge{
namespace ge {
Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) {
GE_CHECK_NOTNULL(op_src);
const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src);
GE_CHECK_NOTNULL(node_src);
GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str());

ParseParamFunc
custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type());
ParseParamFunc custom_op_parser =
domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type());
GE_CHECK_NOTNULL(custom_op_parser);
if (custom_op_parser(op_src, op_dest) != SUCCESS) {
GELOGE(FAILED, "[Invoke][Custom_Op_Parser] Custom parser params failed.");
@@ -40,5 +42,18 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator
return SUCCESS;
}

Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) {
ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType());
GE_CHECK_NOTNULL(custom_op_parser);

if (custom_op_parser(op_src, op_dest) != SUCCESS) {
GELOGE(FAILED, "[Invoke][Custom_Op_Parser] failed, node name:%s, type:%s", op_src.GetName().c_str(),
op_src.GetOpType().c_str());
return FAILED;
}

return SUCCESS;
}

REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter);
} // namespace ge

+ 2
- 0
parser/onnx/onnx_custom_parser_adapter.h View File

@@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser {
/// @return SUCCESS parse successfully
/// @return FAILED parse failed
Status ParseParams(const Message *op_src, ge::Operator &op_dest) override;

Status ParseParams(const Operator &op_src, Operator &op_dest);
};
} // namespace ge



+ 1
- 1
parser/onnx/onnx_data_parser.h View File

@@ -42,7 +42,7 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser {

std::vector<int64_t> user_input_dims_v_;

bool is_subgraph_data_op_;
bool is_subgraph_data_op_ = false;
};
} // namespace ge



+ 37
- 5
parser/onnx/onnx_parser.cc View File

@@ -18,6 +18,7 @@
#include <algorithm>
#include <iostream>
#include <queue>
#include "common/convert/message2operator.h"
#include "common/convert/pb2json.h"
#include "common/util.h"
#include "common/util/error_manager/error_manager.h"
@@ -36,6 +37,7 @@
#include "parser/common/model_saver.h"
#include "parser/common/parser_utils.h"
#include "parser/common/prototype_pass_manager.h"
#include "parser/onnx/onnx_custom_parser_adapter.h"
#include "parser/onnx/onnx_util.h"
#include "register/op_registry.h"
#include "register/register_fmk_types.h"
@@ -555,6 +557,40 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) {
return SUCCESS;
}

Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op,
std::shared_ptr<OpParser> &op_parser) {
GE_CHECK_NOTNULL(node_proto);
GE_CHECK_NOTNULL(op_parser);
std::string op_type = node_proto->op_type();

Status status = FAILED;
domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type);
if (parse_param_func == nullptr) {
status = op_parser->ParseParams(node_proto, op);
} else {
ge::Operator op_src(node_proto->name(), op_type);
status = Message2Operator::ParseOperatorAttrs(node_proto, 1, op_src);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Auto mapping node:%s(%s) to operator failed",
node_proto->name().c_str(), op_type.c_str());
GELOGE(status, "Node[%s] auto mapping failed.", node_proto->name().c_str());
return status;
}
std::shared_ptr<ge::OnnxCustomParserAdapter> onnx_custom_op_parser =
std::dynamic_pointer_cast<ge::OnnxCustomParserAdapter>(op_parser);
status = onnx_custom_op_parser->ParseParams(op_src, op);
op_src.BreakConnect();
}

if (status != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type});
GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str());
return status;
}

return SUCCESS;
}

Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) {
for (int i = 0; i < onnx_graph.node_size(); i++) {
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i);
@@ -586,11 +622,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
GE_CHECK_NOTNULL(factory);
std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type);
GE_CHECK_NOTNULL(op_parser);
std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser);
GE_CHECK_NOTNULL(onnx_op_parser);
status = onnx_op_parser->ParseParams(node_proto, op);
status = ParseOpParam(node_proto, op, op_parser);
if (status != SUCCESS) {
REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status);
return status;
}
@@ -598,7 +631,6 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(),
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize());


ge::graphStatus graph_status = graph.AddOp(op);
if (graph_status != ge::GRAPH_SUCCESS) {
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str());


+ 2
- 0
parser/onnx/onnx_parser.h View File

@@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser {

void ClearMembers();

Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr<OpParser> &op_parser);

Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph,
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph);



+ 1
- 1
parser/onnx/subgraph_adapter/subgraph_adapter_factory.h View File

@@ -55,7 +55,7 @@ public:
*/
std::shared_ptr<SubgraphAdapter> CreateSubgraphAdapter(const std::string &op_type);

~SubgraphAdapterFactory() = default;
protected:
/**
* @brief SubgraphAdapter creation function


+ 1
- 0
parser/tensorflow/graph_optimizer.cc View File

@@ -1457,6 +1457,7 @@ Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library)

GE_IF_BOOL_EXEC(
AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib;
GE_CHECK_NOTNULL(funcDefBytes.GetData());
string str(reinterpret_cast<char *>(funcDefBytes.GetData()), funcDefBytes.GetSize());
GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC(
funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib)));


+ 1
- 1
parser/tensorflow/tensorflow_frameworkop_parser.cc View File

@@ -75,9 +75,9 @@ Status ParseParams(const Message *op_src, FrameworkOpOperator *op) {
op->TfOpDef(attr_v.s());
} else {
GE_CHK_BOOL_EXEC(type == "_Retval",
GE_DELETE_NEW_SINGLE(pkg_node);
REPORT_INNER_ERROR("E19999", "In NodeDef:%s Attr:opdef is not exist, check invalid",
pkg_node->name().c_str());
GE_DELETE_NEW_SINGLE(pkg_node);
return PARAM_INVALID, "In NodeDef %s Attr opdef is not exist.", pkg_node->name().c_str());
}



+ 2
- 0
tests/ut/parser/CMakeLists.txt View File

@@ -221,6 +221,7 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc"
"${PARSER_DIR}/parser/common/acl_graph_parser_util.cc"
"${PARSER_DIR}/parser/common/convert/pb2json.cc"
"${PARSER_DIR}/parser/common/convert/message2operator.cc"
"${PARSER_DIR}/parser/common/data_op_parser.cc"
"${PARSER_DIR}/parser/common/model_saver.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op.cc"
@@ -305,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)

set(PARSER_UT_FILES
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
"testcase/onnx_parser_testcase/message2operator_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
)



+ 58
- 0
tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "common/convert/message2operator.h"

#include <gtest/gtest.h>

#include "proto/onnx/ge_onnx.pb.h"

namespace ge {
class UtestMessage2Operator : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

TEST_F(UtestMessage2Operator, message_to_operator_success) {
ge::onnx::NodeProto input_node;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
attribute->set_name("attribute");
attribute->set_type(onnx::AttributeProto::AttributeType(1));
attribute->set_f(1.0);
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
attribute_tensor->set_data_type(1);
attribute_tensor->add_dims(4);
ge::Operator op_src("add", "Add");
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestMessage2Operator, message_to_operator_fail) {
ge::onnx::NodeProto input_node;
ge::onnx::AttributeProto *attribute = input_node.add_attribute();
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t();
attribute_tensor->add_double_data(1.00);

ge::Operator op_src("add", "Add");
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src);
EXPECT_EQ(ret, FAILED);

ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src);
EXPECT_EQ(ret, FAILED);
}
} // namespace ge

+ 7
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/if.py View File

@@ -1,3 +1,10 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
#-------------------------------------------------------------------
# Purpose:
# Copyright 2021 Huawei Technologies Co., Ltd. All rights reserved.
#-------------------------------------------------------------------

# Given a bool scalar input cond.
# return constant tensor x if cond is True, otherwise return constant tensor y.
import numpy as np


+ 5
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator&
return SUCCESS;
}

static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
return SUCCESS;
}

Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) {
domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func =
domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX);
@@ -72,6 +76,7 @@ void UtestOnnxParser::RegisterCustomOp() {
"ai.onnx::12::If",
"ai.onnx::13::If"})
.ParseParamsFn(ParseParams)
.ParseParamsByOperatorFn(ParseParamByOpFunc)
.ParseSubgraphPostFn(ParseSubgraphPostFnIf);

REGISTER_CUSTOM_OP("Add")


Loading…
Cancel
Save