@@ -21,6 +21,7 @@ | |||
#include <sstream> | |||
#include <memory> | |||
#include <algorithm> | |||
#include "common/convert/message2operator.h" | |||
#include "parser/common/convert/pb2json.h" | |||
#include "parser/common/acl_graph_parser_util.h" | |||
#include "common/op_map.h" | |||
@@ -578,7 +579,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co | |||
return FAILED; | |||
} | |||
if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) { | |||
if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str()); | |||
return FAILED; | |||
} | |||
@@ -589,146 +590,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co | |||
return SUCCESS; | |||
} | |||
Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { | |||
if (depth > kMaxParseDepth) { | |||
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||
return FAILED; | |||
} | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
GE_CHECK_NOTNULL(reflection); | |||
vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
for (auto &field : field_desc) { | |||
GE_CHECK_NOTNULL(field); | |||
if (field->is_repeated()) { | |||
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} else { | |||
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, | |||
int depth, ge::Operator &ops) { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
switch (field->cpp_type()) { | |||
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ | |||
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||
valuetype value = reflection->Get##method(*message, field); \ | |||
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ | |||
(void)ops.SetAttr(field->name(), value); \ | |||
break; \ | |||
} | |||
CASE_FIELD_TYPE(INT32, Int32, int32_t, d); | |||
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); | |||
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); | |||
CASE_FIELD_TYPE(FLOAT, Float, float, f); | |||
CASE_FIELD_TYPE(BOOL, Bool, bool, d); | |||
#undef CASE_FIELD_TYPE | |||
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { | |||
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); | |||
int value = reflection->GetEnum(*message, field)->number(); | |||
GELOGD("Parse result(%s : %d)", field->name().c_str(), value); | |||
(void)ops.SetAttr(field->name(), value); | |||
break; | |||
} | |||
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { | |||
string value = reflection->GetString(*message, field); | |||
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); | |||
(void)ops.SetAttr(field->name(), value); | |||
break; | |||
} | |||
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
break; | |||
} | |||
default: { | |||
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||
std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("Parse field: %s success.", field->name().c_str()); | |||
return SUCCESS; | |||
} | |||
Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, | |||
ge::Operator &ops) { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
int field_size = reflection->FieldSize(*message, field); | |||
if (field_size <= 0) { | |||
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); | |||
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); | |||
return FAILED; | |||
} | |||
switch (field->cpp_type()) { | |||
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ | |||
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||
vector<valuetype> attr_value; \ | |||
for (int i = 0; i < field_size; i++) { \ | |||
valuetype value = reflection->GetRepeated##method(*message, field, i); \ | |||
attr_value.push_back(value); \ | |||
} \ | |||
(void)ops.SetAttr(field->name(), attr_value); \ | |||
break; \ | |||
} | |||
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); | |||
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); | |||
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); | |||
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); | |||
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); | |||
CASE_FIELD_TYPE_REPEATED(STRING, String, string); | |||
#undef CASE_FIELD_TYPE_REPEATED | |||
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||
nlohmann::json message_json; | |||
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), | |||
message_json[field->name()], false); | |||
std::string repeated_message_str; | |||
try { | |||
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||
} catch (std::exception &e) { | |||
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); | |||
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); | |||
return FAILED; | |||
} catch (...) { | |||
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); | |||
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); | |||
return FAILED; | |||
} | |||
(void)ops.SetAttr(field->name(), repeated_message_str); | |||
break; | |||
} | |||
default: { | |||
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||
std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("Parse repeated field: %s success.", field->name().c_str()); | |||
return SUCCESS; | |||
} | |||
void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { | |||
auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); | |||
if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { | |||
@@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { | |||
*/ | |||
Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, | |||
int index, std::vector<ge::Operator> &operators); | |||
/* | |||
* @ingroup domi_omg | |||
* @brief Parse message and set operator attrs | |||
* @param [in] message, message of model | |||
* @param [in/out] depth, depth of recursion | |||
* @param [out] ops, operator saving custom info | |||
* @return SUCCESS parse message successfully | |||
* @return FAILED parse message failed | |||
*/ | |||
Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); | |||
/* | |||
* @ingroup domi_omg | |||
* @brief Parse field and set operator attrs | |||
* @param [in] reflection, reflection of message | |||
* @param [in] message, message of model | |||
* @param [in] field, field of message | |||
* @param [in/out] depth, depth of recursion | |||
* @param [out] ops, operator saving custom info | |||
* @return SUCCESS parse field successfully | |||
* @return FAILED parse field failed | |||
*/ | |||
Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||
/* | |||
* @ingroup domi_omg | |||
* @brief Parse repeated field and set operator attrs | |||
* @param [in] reflection, reflection of message | |||
* @param [in] message, message of model | |||
* @param [in] field, field of message | |||
* @param [in/out] depth, depth of recursion | |||
* @param [out] ops, operator saving custom info by vector | |||
* @return SUCCESS parse field successfully | |||
* @return FAILED parse field failed | |||
*/ | |||
Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||
/** | |||
* @ingroup domi_omg | |||
* @brief Add blob information to the bottom_blobs_map and top_blobs_map_ | |||
@@ -15,6 +15,7 @@ set(SRC_LIST | |||
"../tensorflow/tensorflow_fusion_op_parser.cc" | |||
"../tensorflow/tensorflow_util.cc" | |||
"convert/pb2json.cc" | |||
"convert/message2operator.cc" | |||
"op_def/ir_pb_converter.cc" | |||
"op_def/defs.cc" | |||
"op_def/op_schema.cc" | |||
@@ -0,0 +1,170 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "message2operator.h" | |||
#include <vector> | |||
#include "common/convert/pb2json.h" | |||
#include "common/util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
namespace ge { | |||
namespace { | |||
const int kMaxParseDepth = 5; | |||
const uint32_t kInteval = 2; | |||
} // namespace | |||
Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { | |||
GE_CHECK_NOTNULL(message); | |||
if (depth > kMaxParseDepth) { | |||
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); | |||
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); | |||
return FAILED; | |||
} | |||
const google::protobuf::Reflection *reflection = message->GetReflection(); | |||
GE_CHECK_NOTNULL(reflection); | |||
std::vector<const google::protobuf::FieldDescriptor *> field_desc; | |||
reflection->ListFields(*message, &field_desc); | |||
for (auto &field : field_desc) { | |||
GE_CHECK_NOTNULL(field); | |||
if (field->is_repeated()) { | |||
if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} else { | |||
if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
switch (field->cpp_type()) { | |||
#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ | |||
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||
valuetype value = reflection->Get##method(*message, field); \ | |||
GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ | |||
(void)ops.SetAttr(field->name(), value); \ | |||
break; \ | |||
} | |||
CASE_FIELD_TYPE(INT32, Int32, int32_t, d); | |||
CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); | |||
CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); | |||
CASE_FIELD_TYPE(FLOAT, Float, float, f); | |||
CASE_FIELD_TYPE(BOOL, Bool, bool, d); | |||
#undef CASE_FIELD_TYPE | |||
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { | |||
GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); | |||
int value = reflection->GetEnum(*message, field)->number(); | |||
GELOGD("Parse result(%s : %d)", field->name().c_str(), value); | |||
(void)ops.SetAttr(field->name(), value); | |||
break; | |||
} | |||
case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { | |||
string value = reflection->GetString(*message, field); | |||
GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); | |||
(void)ops.SetAttr(field->name(), value); | |||
break; | |||
} | |||
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); | |||
if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { | |||
GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
break; | |||
} | |||
default: { | |||
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||
std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("Parse field: %s success.", field->name().c_str()); | |||
return SUCCESS; | |||
} | |||
Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, | |||
ge::Operator &ops) { | |||
GELOGD("Start to parse field: %s.", field->name().c_str()); | |||
int field_size = reflection->FieldSize(*message, field); | |||
if (field_size <= 0) { | |||
REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); | |||
GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); | |||
return FAILED; | |||
} | |||
switch (field->cpp_type()) { | |||
#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ | |||
case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ | |||
std::vector<valuetype> attr_value; \ | |||
for (int i = 0; i < field_size; i++) { \ | |||
valuetype value = reflection->GetRepeated##method(*message, field, i); \ | |||
attr_value.push_back(value); \ | |||
} \ | |||
(void)ops.SetAttr(field->name(), attr_value); \ | |||
break; \ | |||
} | |||
CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); | |||
CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); | |||
CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); | |||
CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); | |||
CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); | |||
CASE_FIELD_TYPE_REPEATED(STRING, String, string); | |||
#undef CASE_FIELD_TYPE_REPEATED | |||
case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { | |||
nlohmann::json message_json; | |||
Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set<string>(), message_json[field->name()], | |||
false); | |||
std::string repeated_message_str; | |||
try { | |||
repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||
} catch (std::exception &e) { | |||
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); | |||
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); | |||
return FAILED; | |||
} catch (...) { | |||
REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); | |||
GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); | |||
return FAILED; | |||
} | |||
(void)ops.SetAttr(field->name(), repeated_message_str); | |||
break; | |||
} | |||
default: { | |||
REPORT_INPUT_ERROR("E11032", std::vector<std::string>({"message_type", "name", "reason"}), | |||
std::vector<std::string>({"model", field->name(), "Unsupported field type"})); | |||
GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); | |||
return FAILED; | |||
} | |||
} | |||
GELOGD("Parse repeated field: %s success.", field->name().c_str()); | |||
return SUCCESS; | |||
} | |||
} // namespace ge |
@@ -0,0 +1,38 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef PARSER_MESSAGE2OPERATOR_H | |||
#define PARSER_MESSAGE2OPERATOR_H | |||
#include "external/ge/ge_api_error_codes.h" | |||
#include "external/graph/operator.h" | |||
#include "google/protobuf/message.h" | |||
namespace ge { | |||
class Message2Operator { | |||
public: | |||
static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); | |||
private: | |||
static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||
static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, | |||
const google::protobuf::Message *message, | |||
const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); | |||
}; | |||
} // namespace ge | |||
#endif // PARSER_MESSAGE2OPERATOR_H |
@@ -15,23 +15,25 @@ | |||
*/ | |||
#include "parser/onnx/onnx_custom_parser_adapter.h" | |||
#include "common/util.h" | |||
#include "framework/common/debug/ge_log.h" | |||
#include "parser/common/op_parser_factory.h" | |||
#include "register/op_registry.h" | |||
using domi::ParseParamFunc; | |||
using domi::ONNX; | |||
using domi::ParseParamByOpFunc; | |||
using domi::ParseParamFunc; | |||
namespace ge{ | |||
namespace ge { | |||
Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { | |||
GE_CHECK_NOTNULL(op_src); | |||
const ge::onnx::NodeProto *node_src = reinterpret_cast<const ge::onnx::NodeProto *>(op_src); | |||
GE_CHECK_NOTNULL(node_src); | |||
GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); | |||
ParseParamFunc | |||
custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); | |||
ParseParamFunc custom_op_parser = | |||
domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); | |||
GE_CHECK_NOTNULL(custom_op_parser); | |||
if (custom_op_parser(op_src, op_dest) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][Custom_Op_Parser] Custom parser params failed."); | |||
@@ -40,5 +42,18 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator | |||
return SUCCESS; | |||
} | |||
Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) { | |||
ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); | |||
GE_CHECK_NOTNULL(custom_op_parser); | |||
if (custom_op_parser(op_src, op_dest) != SUCCESS) { | |||
GELOGE(FAILED, "[Invoke][Custom_Op_Parser] failed, node name:%s, type:%s", op_src.GetName().c_str(), | |||
op_src.GetOpType().c_str()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
} | |||
REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); | |||
} // namespace ge |
@@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser { | |||
/// @return SUCCESS parse successfully | |||
/// @return FAILED parse failed | |||
Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; | |||
Status ParseParams(const Operator &op_src, Operator &op_dest); | |||
}; | |||
} // namespace ge | |||
@@ -18,6 +18,7 @@ | |||
#include <algorithm> | |||
#include <iostream> | |||
#include <queue> | |||
#include "common/convert/message2operator.h" | |||
#include "common/convert/pb2json.h" | |||
#include "common/util.h" | |||
#include "common/util/error_manager/error_manager.h" | |||
@@ -36,6 +37,7 @@ | |||
#include "parser/common/model_saver.h" | |||
#include "parser/common/parser_utils.h" | |||
#include "parser/common/prototype_pass_manager.h" | |||
#include "parser/onnx/onnx_custom_parser_adapter.h" | |||
#include "parser/onnx/onnx_util.h" | |||
#include "register/op_registry.h" | |||
#include "register/register_fmk_types.h" | |||
@@ -555,6 +557,40 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { | |||
return SUCCESS; | |||
} | |||
Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, | |||
std::shared_ptr<OpParser> &op_parser) { | |||
GE_CHECK_NOTNULL(node_proto); | |||
GE_CHECK_NOTNULL(op_parser); | |||
std::string op_type = node_proto->op_type(); | |||
Status status = FAILED; | |||
domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); | |||
if (parse_param_func == nullptr) { | |||
status = op_parser->ParseParams(node_proto, op); | |||
} else { | |||
ge::Operator op_src(node_proto->name(), op_type); | |||
status = Message2Operator::ParseOperatorAttrs(node_proto, 1, op_src); | |||
if (status != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Auto mapping node:%s(%s) to operator failed", | |||
node_proto->name().c_str(), op_type.c_str()); | |||
GELOGE(status, "Node[%s] auto mapping failed.", node_proto->name().c_str()); | |||
return status; | |||
} | |||
std::shared_ptr<ge::OnnxCustomParserAdapter> onnx_custom_op_parser = | |||
std::dynamic_pointer_cast<ge::OnnxCustomParserAdapter>(op_parser); | |||
status = onnx_custom_op_parser->ParseParams(op_src, op); | |||
op_src.BreakConnect(); | |||
} | |||
if (status != SUCCESS) { | |||
ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type}); | |||
GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str()); | |||
return status; | |||
} | |||
return SUCCESS; | |||
} | |||
Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { | |||
for (int i = 0; i < onnx_graph.node_size(); i++) { | |||
ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); | |||
@@ -586,11 +622,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
GE_CHECK_NOTNULL(factory); | |||
std::shared_ptr<ge::OpParser> op_parser = factory->CreateOpParser(op_type); | |||
GE_CHECK_NOTNULL(op_parser); | |||
std::shared_ptr<ge::OnnxOpParser> onnx_op_parser = std::static_pointer_cast<ge::OnnxOpParser>(op_parser); | |||
GE_CHECK_NOTNULL(onnx_op_parser); | |||
status = onnx_op_parser->ParseParams(node_proto, op); | |||
status = ParseOpParam(node_proto, op, op_parser); | |||
if (status != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | |||
GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); | |||
return status; | |||
} | |||
@@ -598,7 +631,6 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: | |||
GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), | |||
op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); | |||
ge::graphStatus graph_status = graph.AddOp(op); | |||
if (graph_status != ge::GRAPH_SUCCESS) { | |||
GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); | |||
@@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { | |||
void ClearMembers(); | |||
Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr<OpParser> &op_parser); | |||
Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, | |||
std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||
@@ -221,6 +221,7 @@ set(PARSER_SRC_FILES | |||
"${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" | |||
"${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" | |||
"${PARSER_DIR}/parser/common/convert/pb2json.cc" | |||
"${PARSER_DIR}/parser/common/convert/message2operator.cc" | |||
"${PARSER_DIR}/parser/common/data_op_parser.cc" | |||
"${PARSER_DIR}/parser/common/model_saver.cc" | |||
"${PARSER_DIR}/parser/common/op_def/arg_op.cc" | |||
@@ -305,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||
set(PARSER_UT_FILES | |||
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | |||
"testcase/onnx_parser_testcase/message2operator_unittest.cc" | |||
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||
) | |||
@@ -0,0 +1,58 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "common/convert/message2operator.h" | |||
#include <gtest/gtest.h> | |||
#include "proto/onnx/ge_onnx.pb.h" | |||
namespace ge { | |||
class UtestMessage2Operator : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
TEST_F(UtestMessage2Operator, message_to_operator_success) { | |||
ge::onnx::NodeProto input_node; | |||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||
attribute->set_name("attribute"); | |||
attribute->set_type(onnx::AttributeProto::AttributeType(1)); | |||
attribute->set_f(1.0); | |||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||
attribute_tensor->set_data_type(1); | |||
attribute_tensor->add_dims(4); | |||
ge::Operator op_src("add", "Add"); | |||
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
TEST_F(UtestMessage2Operator, message_to_operator_fail) { | |||
ge::onnx::NodeProto input_node; | |||
ge::onnx::AttributeProto *attribute = input_node.add_attribute(); | |||
ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); | |||
attribute_tensor->add_double_data(1.00); | |||
ge::Operator op_src("add", "Add"); | |||
auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src); | |||
EXPECT_EQ(ret, FAILED); | |||
ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); | |||
EXPECT_EQ(ret, FAILED); | |||
} | |||
} // namespace ge |
@@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& | |||
return SUCCESS; | |||
} | |||
static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { | |||
return SUCCESS; | |||
} | |||
Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | |||
domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | |||
domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | |||
@@ -72,6 +76,7 @@ void UtestOnnxParser::RegisterCustomOp() { | |||
"ai.onnx::12::If", | |||
"ai.onnx::13::If"}) | |||
.ParseParamsFn(ParseParams) | |||
.ParseParamsByOperatorFn(ParseParamByOpFunc) | |||
.ParseSubgraphPostFn(ParseSubgraphPostFnIf); | |||
REGISTER_CUSTOM_OP("Add") | |||