diff --git a/parser/caffe/caffe_parser.cc b/parser/caffe/caffe_parser.cc index 9ea40e5..157f33f 100644 --- a/parser/caffe/caffe_parser.cc +++ b/parser/caffe/caffe_parser.cc @@ -21,6 +21,7 @@ #include #include #include +#include "common/convert/message2operator.h" #include "parser/common/convert/pb2json.h" #include "parser/common/acl_graph_parser_util.h" #include "common/op_map.h" @@ -578,7 +579,7 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co return FAILED; } - if (ParseOperatorAttrs(message, 1, ops) != SUCCESS) { + if (Message2Operator::ParseOperatorAttrs(message, 1, ops) != SUCCESS) { GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", op_name.c_str()); return FAILED; } @@ -589,146 +590,6 @@ Status CaffeModelParser::CreateCustomOperator(string op_name, string op_type, co return SUCCESS; } -Status CaffeModelParser::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { - if (depth > kMaxParseDepth) { - REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); - GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); - return FAILED; - } - - const google::protobuf::Reflection *reflection = message->GetReflection(); - GE_CHECK_NOTNULL(reflection); - vector field_desc; - reflection->ListFields(*message, &field_desc); - - for (auto &field : field_desc) { - GE_CHECK_NOTNULL(field); - if (field->is_repeated()) { - if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); - return FAILED; - } - } else { - if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); - return FAILED; - } - } - } - return SUCCESS; -} - -Status CaffeModelParser::ParseField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, - int depth, ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - valuetype value = reflection->Get##method(*message, field); \ - GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ - (void)ops.SetAttr(field->name(), value); \ - break; \ - } - CASE_FIELD_TYPE(INT32, Int32, int32_t, d); - CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); - CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); - CASE_FIELD_TYPE(FLOAT, Float, float, f); - CASE_FIELD_TYPE(BOOL, Bool, bool, d); -#undef CASE_FIELD_TYPE - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); - int value = reflection->GetEnum(*message, field)->number(); - GELOGD("Parse result(%s : %d)", field->name().c_str(), value); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { - string value = reflection->GetString(*message, field); - GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); - (void)ops.SetAttr(field->name(), value); - break; - } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); - if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { - GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); - return FAILED; - } - break; - } - default: { - REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), - std::vector({"model", field->name(), "Unsupported field type"})); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse field: %s success.", field->name().c_str()); - return SUCCESS; -} - -Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection *reflection, - const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, - ge::Operator &ops) { - GELOGD("Start to parse field: %s.", field->name().c_str()); - int field_size = reflection->FieldSize(*message, field); - if (field_size <= 0) { - REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); - GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); - return FAILED; - } - - switch (field->cpp_type()) { -#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ - case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ - vector attr_value; \ - for (int i = 0; i < field_size; i++) { \ - valuetype value = reflection->GetRepeated##method(*message, field, i); \ - attr_value.push_back(value); \ - } \ - (void)ops.SetAttr(field->name(), attr_value); \ - break; \ - } - CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); - CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); - CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); - CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); - CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); - CASE_FIELD_TYPE_REPEATED(STRING, String, string); -#undef CASE_FIELD_TYPE_REPEATED - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { - nlohmann::json message_json; - Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), - message_json[field->name()], false); - std::string repeated_message_str; - try { - repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); - } catch (std::exception &e) { - REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); - return FAILED; - } catch (...) { - REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); - GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); - return FAILED; - } - (void)ops.SetAttr(field->name(), repeated_message_str); - break; - } - default: { - REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), - std::vector({"model", field->name(), "Unsupported field type"})); - GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); - return FAILED; - } - } - GELOGD("Parse repeated field: %s success.", field->name().c_str()); - return SUCCESS; -} - void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { diff --git a/parser/caffe/caffe_parser.h b/parser/caffe/caffe_parser.h index 9a3af8b..354f23e 100644 --- a/parser/caffe/caffe_parser.h +++ b/parser/caffe/caffe_parser.h @@ -209,46 +209,6 @@ class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser { */ Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message, int index, std::vector &operators); - - /* - * @ingroup domi_omg - * @brief Parse message and set operator attrs - * @param [in] message, message of model - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info - * @return SUCCESS parse message successfully - * @return FAILED parse message failed - */ - Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); - - /* - * @ingroup domi_omg - * @brief Parse field and set operator attrs - * @param [in] reflection, reflection of message - * @param [in] message, message of model - * @param [in] field, field of message - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info - * @return SUCCESS parse field successfully - * @return FAILED parse field failed - */ - Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); - - /* - * @ingroup domi_omg - * @brief Parse repeated field and set operator attrs - * @param [in] reflection, reflection of message - * @param [in] message, message of model - * @param [in] field, field of message - * @param [in/out] depth, depth of recursion - * @param [out] ops, operator saving custom info by vector - * @return SUCCESS parse field successfully - * @return FAILED parse field failed - */ - Status ParseRepeatedField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, - const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); - /** * @ingroup domi_omg * @brief Add blob information to the bottom_blobs_map and top_blobs_map_ diff --git a/parser/common/CMakeLists.txt b/parser/common/CMakeLists.txt index 379d4d8..d5b87e8 100644 --- a/parser/common/CMakeLists.txt +++ b/parser/common/CMakeLists.txt @@ -15,6 +15,7 @@ set(SRC_LIST "../tensorflow/tensorflow_fusion_op_parser.cc" "../tensorflow/tensorflow_util.cc" "convert/pb2json.cc" + "convert/message2operator.cc" "op_def/ir_pb_converter.cc" "op_def/defs.cc" "op_def/op_schema.cc" diff --git a/parser/common/convert/message2operator.cc b/parser/common/convert/message2operator.cc new file mode 100644 index 0000000..c0ef702 --- /dev/null +++ b/parser/common/convert/message2operator.cc @@ -0,0 +1,170 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "message2operator.h" + +#include + +#include "common/convert/pb2json.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace { +const int kMaxParseDepth = 5; +const uint32_t kInteval = 2; +} // namespace + +Status Message2Operator::ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops) { + GE_CHECK_NOTNULL(message); + if (depth > kMaxParseDepth) { + REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); + GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); + return FAILED; + } + + const google::protobuf::Reflection *reflection = message->GetReflection(); + GE_CHECK_NOTNULL(reflection); + std::vector field_desc; + reflection->ListFields(*message, &field_desc); + + for (auto &field : field_desc) { + GE_CHECK_NOTNULL(field); + if (field->is_repeated()) { + if (ParseRepeatedField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][RepeatedField] %s failed.", field->name().c_str()); + return FAILED; + } + } else { + if (ParseField(reflection, message, field, depth, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][Field] %s failed.", field->name().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + +Status Message2Operator::ParseField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE(cpptype, method, valuetype, logtype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + valuetype value = reflection->Get##method(*message, field); \ + GELOGD("Parse result(%s : %" #logtype ")", field->name().c_str(), value); \ + (void)ops.SetAttr(field->name(), value); \ + break; \ + } + CASE_FIELD_TYPE(INT32, Int32, int32_t, d); + CASE_FIELD_TYPE(UINT32, UInt32, uint32_t, u); + CASE_FIELD_TYPE(INT64, Int64, int64_t, ld); + CASE_FIELD_TYPE(FLOAT, Float, float, f); + CASE_FIELD_TYPE(BOOL, Bool, bool, d); +#undef CASE_FIELD_TYPE + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + GE_CHECK_NOTNULL(reflection->GetEnum(*message, field)); + int value = reflection->GetEnum(*message, field)->number(); + GELOGD("Parse result(%s : %d)", field->name().c_str(), value); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + string value = reflection->GetString(*message, field); + GELOGD("Parse result(%s : %s)", field->name().c_str(), value.c_str()); + (void)ops.SetAttr(field->name(), value); + break; + } + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); + if (ParseOperatorAttrs(&sub_message, depth + 1, ops) != SUCCESS) { + GELOGE(FAILED, "[Parse][OperatorAttrs] of %s failed.", field->name().c_str()); + return FAILED; + } + break; + } + default: { + REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), + std::vector({"model", field->name(), "Unsupported field type"})); + GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse field: %s success.", field->name().c_str()); + return SUCCESS; +} + +Status Message2Operator::ParseRepeatedField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, + ge::Operator &ops) { + GELOGD("Start to parse field: %s.", field->name().c_str()); + int field_size = reflection->FieldSize(*message, field); + if (field_size <= 0) { + REPORT_INNER_ERROR("E19999", "Size of repeated field %s must bigger than 0", field->name().c_str()); + GELOGE(FAILED, "[Check][Size]Size of repeated field %s must bigger than 0", field->name().c_str()); + return FAILED; + } + + switch (field->cpp_type()) { +#define CASE_FIELD_TYPE_REPEATED(cpptype, method, valuetype) \ + case google::protobuf::FieldDescriptor::CPPTYPE_##cpptype: { \ + std::vector attr_value; \ + for (int i = 0; i < field_size; i++) { \ + valuetype value = reflection->GetRepeated##method(*message, field, i); \ + attr_value.push_back(value); \ + } \ + (void)ops.SetAttr(field->name(), attr_value); \ + break; \ + } + CASE_FIELD_TYPE_REPEATED(INT32, Int32, int32_t); + CASE_FIELD_TYPE_REPEATED(UINT32, UInt32, uint32_t); + CASE_FIELD_TYPE_REPEATED(INT64, Int64, int64_t); + CASE_FIELD_TYPE_REPEATED(FLOAT, Float, float); + CASE_FIELD_TYPE_REPEATED(BOOL, Bool, bool); + CASE_FIELD_TYPE_REPEATED(STRING, String, string); +#undef CASE_FIELD_TYPE_REPEATED + case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + nlohmann::json message_json; + Pb2Json::RepeatedMessage2Json(*message, field, reflection, std::set(), message_json[field->name()], + false); + std::string repeated_message_str; + try { + repeated_message_str = message_json.dump(kInteval, ' ', false, Json::error_handler_t::ignore); + } catch (std::exception &e) { + REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string, reason: %s.", e.what()); + GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string, reason: %s.", e.what()); + return FAILED; + } catch (...) { + REPORT_INNER_ERROR("E19999", "Failed to convert JSON to string."); + GELOGE(FAILED, "[Parse][JSON]Failed to convert JSON to string."); + return FAILED; + } + (void)ops.SetAttr(field->name(), repeated_message_str); + break; + } + default: { + REPORT_INPUT_ERROR("E11032", std::vector({"message_type", "name", "reason"}), + std::vector({"model", field->name(), "Unsupported field type"})); + GELOGE(FAILED, "[Check][FieldType]Unsupported field type, name: %s.", field->name().c_str()); + return FAILED; + } + } + GELOGD("Parse repeated field: %s success.", field->name().c_str()); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/parser/common/convert/message2operator.h b/parser/common/convert/message2operator.h new file mode 100644 index 0000000..f33d4f3 --- /dev/null +++ b/parser/common/convert/message2operator.h @@ -0,0 +1,38 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef PARSER_MESSAGE2OPERATOR_H +#define PARSER_MESSAGE2OPERATOR_H + +#include "external/ge/ge_api_error_codes.h" +#include "external/graph/operator.h" +#include "google/protobuf/message.h" + +namespace ge { +class Message2Operator { + public: + static Status ParseOperatorAttrs(const google::protobuf::Message *message, int depth, ge::Operator &ops); + + private: + static Status ParseField(const google::protobuf::Reflection *reflection, const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); + + static Status ParseRepeatedField(const google::protobuf::Reflection *reflection, + const google::protobuf::Message *message, + const google::protobuf::FieldDescriptor *field, int depth, ge::Operator &ops); +}; +} // namespace ge +#endif // PARSER_MESSAGE2OPERATOR_H diff --git a/parser/onnx/onnx_custom_parser_adapter.cc b/parser/onnx/onnx_custom_parser_adapter.cc index 3b7e7b0..1534dff 100644 --- a/parser/onnx/onnx_custom_parser_adapter.cc +++ b/parser/onnx/onnx_custom_parser_adapter.cc @@ -15,23 +15,25 @@ */ #include "parser/onnx/onnx_custom_parser_adapter.h" + #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "parser/common/op_parser_factory.h" #include "register/op_registry.h" -using domi::ParseParamFunc; using domi::ONNX; +using domi::ParseParamByOpFunc; +using domi::ParseParamFunc; -namespace ge{ +namespace ge { Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator &op_dest) { GE_CHECK_NOTNULL(op_src); const ge::onnx::NodeProto *node_src = reinterpret_cast(op_src); GE_CHECK_NOTNULL(node_src); GELOGI("Onnx op node name = %s, op type= %s, parse params.", node_src->name().c_str(), node_src->op_type().c_str()); - ParseParamFunc - custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); + ParseParamFunc custom_op_parser = + domi::OpRegistry::Instance()->GetParseParamFunc(op_dest.GetOpType(), node_src->op_type()); GE_CHECK_NOTNULL(custom_op_parser); if (custom_op_parser(op_src, op_dest) != SUCCESS) { GELOGE(FAILED, "[Invoke][Custom_Op_Parser] Custom parser params failed."); @@ -40,5 +42,18 @@ Status OnnxCustomParserAdapter::ParseParams(const Message *op_src, ge::Operator return SUCCESS; } +Status OnnxCustomParserAdapter::ParseParams(const Operator &op_src, Operator &op_dest) { + ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); + GE_CHECK_NOTNULL(custom_op_parser); + + if (custom_op_parser(op_src, op_dest) != SUCCESS) { + GELOGE(FAILED, "[Invoke][Custom_Op_Parser] failed, node name:%s, type:%s", op_src.GetName().c_str(), + op_src.GetOpType().c_str()); + return FAILED; + } + + return SUCCESS; +} + REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(ONNX, OnnxCustomParserAdapter); } // namespace ge diff --git a/parser/onnx/onnx_custom_parser_adapter.h b/parser/onnx/onnx_custom_parser_adapter.h index 1e5f147..7e0fb06 100644 --- a/parser/onnx/onnx_custom_parser_adapter.h +++ b/parser/onnx/onnx_custom_parser_adapter.h @@ -28,6 +28,8 @@ class PARSER_FUNC_VISIBILITY OnnxCustomParserAdapter : public OnnxOpParser { /// @return SUCCESS parse successfully /// @return FAILED parse failed Status ParseParams(const Message *op_src, ge::Operator &op_dest) override; + + Status ParseParams(const Operator &op_src, Operator &op_dest); }; } // namespace ge diff --git a/parser/onnx/onnx_parser.cc b/parser/onnx/onnx_parser.cc index cc69799..3c79a71 100644 --- a/parser/onnx/onnx_parser.cc +++ b/parser/onnx/onnx_parser.cc @@ -18,6 +18,7 @@ #include #include #include +#include "common/convert/message2operator.h" #include "common/convert/pb2json.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" @@ -36,6 +37,7 @@ #include "parser/common/model_saver.h" #include "parser/common/parser_utils.h" #include "parser/common/prototype_pass_manager.h" +#include "parser/onnx/onnx_custom_parser_adapter.h" #include "parser/onnx/onnx_util.h" #include "register/op_registry.h" #include "register/register_fmk_types.h" @@ -555,6 +557,40 @@ Status OnnxModelParser::Prechecker(ge::onnx::GraphProto &onnx_graph) { return SUCCESS; } +Status OnnxModelParser::ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, + std::shared_ptr &op_parser) { + GE_CHECK_NOTNULL(node_proto); + GE_CHECK_NOTNULL(op_parser); + std::string op_type = node_proto->op_type(); + + Status status = FAILED; + domi::ParseParamByOpFunc parse_param_func = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_type); + if (parse_param_func == nullptr) { + status = op_parser->ParseParams(node_proto, op); + } else { + ge::Operator op_src(node_proto->name(), op_type); + status = Message2Operator::ParseOperatorAttrs(node_proto, 1, op_src); + if (status != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Auto mapping node:%s(%s) to operator failed", + node_proto->name().c_str(), op_type.c_str()); + GELOGE(status, "Node[%s] auto mapping failed.", node_proto->name().c_str()); + return status; + } + std::shared_ptr onnx_custom_op_parser = + std::dynamic_pointer_cast(op_parser); + status = onnx_custom_op_parser->ParseParams(op_src, op); + op_src.BreakConnect(); + } + + if (status != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E11010", {"opname", "optype"}, {node_proto->name(), op_type}); + GELOGE(status, "[Parse][Params] for op [%s] fail, optype [%s]", node_proto->name().c_str(), op_type.c_str()); + return status; + } + + return SUCCESS; +} + Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge::Graph &graph) { for (int i = 0; i < onnx_graph.node_size(); i++) { ge::onnx::NodeProto *node_proto = onnx_graph.mutable_node(i); @@ -586,11 +622,8 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: GE_CHECK_NOTNULL(factory); std::shared_ptr op_parser = factory->CreateOpParser(op_type); GE_CHECK_NOTNULL(op_parser); - std::shared_ptr onnx_op_parser = std::static_pointer_cast(op_parser); - GE_CHECK_NOTNULL(onnx_op_parser); - status = onnx_op_parser->ParseParams(node_proto, op); + status = ParseOpParam(node_proto, op, op_parser); if (status != SUCCESS) { - REPORT_CALL_ERROR("E19999", "ParseParams for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); GELOGE(status, "[Parse][Params] for %s:%s failed ret:%d.", node_name.c_str(), op_type.c_str(), status); return status; } @@ -598,7 +631,6 @@ Status OnnxModelParser::ParseAllNodeProto(ge::onnx::GraphProto &onnx_graph, ge:: GELOGI("After ParseParams, op[%s]: type[%s] have input size: %zu, output size: %zu", op.GetName().c_str(), op.GetOpType().c_str(), op.GetInputsSize(), op.GetOutputsSize()); - ge::graphStatus graph_status = graph.AddOp(op); if (graph_status != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "[Add][Op] Add op:%s to graph failed.", op.GetName().c_str()); diff --git a/parser/onnx/onnx_parser.h b/parser/onnx/onnx_parser.h index b28494b..b90c1a3 100644 --- a/parser/onnx/onnx_parser.h +++ b/parser/onnx/onnx_parser.h @@ -110,6 +110,8 @@ class PARSER_FUNC_VISIBILITY OnnxModelParser : public domi::ModelParser { void ClearMembers(); + Status ParseOpParam(const ge::onnx::NodeProto *node_proto, ge::Operator &op, std::shared_ptr &op_parser); + Status AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx_graph, std::map &name_to_onnx_graph); diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 7130e20..ebc9885 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -221,6 +221,7 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" "${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" "${PARSER_DIR}/parser/common/convert/pb2json.cc" + "${PARSER_DIR}/parser/common/convert/message2operator.cc" "${PARSER_DIR}/parser/common/data_op_parser.cc" "${PARSER_DIR}/parser/common/model_saver.cc" "${PARSER_DIR}/parser/common/op_def/arg_op.cc" @@ -305,6 +306,7 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" + "testcase/onnx_parser_testcase/message2operator_unittest.cc" "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" ) diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc new file mode 100644 index 0000000..39e1480 --- /dev/null +++ b/tests/ut/parser/testcase/onnx_parser_testcase/message2operator_unittest.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/convert/message2operator.h" + +#include + +#include "proto/onnx/ge_onnx.pb.h" + +namespace ge { +class UtestMessage2Operator : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} +}; + +TEST_F(UtestMessage2Operator, message_to_operator_success) { + ge::onnx::NodeProto input_node; + ge::onnx::AttributeProto *attribute = input_node.add_attribute(); + attribute->set_name("attribute"); + attribute->set_type(onnx::AttributeProto::AttributeType(1)); + attribute->set_f(1.0); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + attribute_tensor->set_data_type(1); + attribute_tensor->add_dims(4); + ge::Operator op_src("add", "Add"); + auto ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); + EXPECT_EQ(ret, SUCCESS); +} + +TEST_F(UtestMessage2Operator, message_to_operator_fail) { + ge::onnx::NodeProto input_node; + ge::onnx::AttributeProto *attribute = input_node.add_attribute(); + ge::onnx::TensorProto *attribute_tensor = attribute->mutable_t(); + attribute_tensor->add_double_data(1.00); + + ge::Operator op_src("add", "Add"); + auto ret = Message2Operator::ParseOperatorAttrs(attribute, 6, op_src); + EXPECT_EQ(ret, FAILED); + + ret = Message2Operator::ParseOperatorAttrs(attribute, 1, op_src); + EXPECT_EQ(ret, FAILED); +} +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc index a3c54d5..678b8a6 100644 --- a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -39,6 +39,10 @@ static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& return SUCCESS; } +static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { + return SUCCESS; +} + Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); @@ -72,6 +76,7 @@ void UtestOnnxParser::RegisterCustomOp() { "ai.onnx::12::If", "ai.onnx::13::If"}) .ParseParamsFn(ParseParams) + .ParseParamsByOperatorFn(ParseParamByOpFunc) .ParseSubgraphPostFn(ParseSubgraphPostFnIf); REGISTER_CUSTOM_OP("Add")