diff --git a/parser/common/convert/pb2json.cc b/parser/common/convert/pb2json.cc index 3e82ca6..b411252 100644 --- a/parser/common/convert/pb2json.cc +++ b/parser/common/convert/pb2json.cc @@ -31,11 +31,17 @@ using std::string; namespace ge { namespace { const int kSignificantDigits = 10; +const int kMaxParseDepth = 5; } // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, const set &black_fields, Json &json, - bool enum2str) { + bool enum2str, int depth) { + if (depth > kMaxParseDepth) { + REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth); + GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth); + return; + } auto descriptor = message.GetDescriptor(); auto reflection = message.GetReflection(); if (descriptor == nullptr || reflection == nullptr) { @@ -57,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons if (field->is_repeated()) { if (reflection->FieldSize(message, field) > 0) { - RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); + RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str, depth); } continue; } @@ -66,18 +72,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons continue; } - OneField2Json(message, field, reflection, black_fields, json, enum2str); + OneField2Json(message, field, reflection, black_fields, json, enum2str, depth); } } void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const set &black_fields, Json &json, - bool enum2str) { + bool enum2str, int depth) { switch (field->type()) { case ProtobufFieldDescriptor::TYPE_MESSAGE: { const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); if (0UL != tmp_message.ByteSizeLong()) { - Message2Json(tmp_message, black_fields, json[field->name()], enum2str); + Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1); } break; } @@ -163,9 +169,9 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const set &black_fields, Json &json, - bool enum2str) { + bool enum2str, int depth) { if ((field == nullptr) || (reflection == nullptr)) { - Message2Json(message, black_fields, json, enum2str); + Message2Json(message, black_fields, json, enum2str, depth + 1); return; } @@ -175,7 +181,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie case ProtobufFieldDescriptor::TYPE_MESSAGE: { const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); if (0UL != tmp_message.ByteSizeLong()) { - Message2Json(tmp_message, black_fields, tmp_json, enum2str); + Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1); } } break; diff --git a/parser/common/convert/pb2json.h b/parser/common/convert/pb2json.h index 4f8e406..28e796d 100644 --- a/parser/common/convert/pb2json.h +++ b/parser/common/convert/pb2json.h @@ -45,11 +45,11 @@ class Pb2Json { * @author */ static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, - bool enum2str = false); + bool enum2str = false, int depth = 0); static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const std::set &black_fields, - Json &json, bool enum2str); + Json &json, bool enum2str, int depth = 0); protected: static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, @@ -59,7 +59,7 @@ class Pb2Json { static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const std::set &black_fields, Json &json, - bool enum2str); + bool enum2str, int depth); static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); }; diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 07346f9..03b0dd7 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -3534,7 +3534,8 @@ TEST_F(STestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); mess2Op.ParseField(reflection, node_def, field, depth, ops); - toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); + toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); + toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); delete field; } diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index 0afb636..75450f6 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -3696,7 +3696,8 @@ TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test) ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM); mess2Op.ParseField(reflection, node_def, field, depth, ops); - toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str); + toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1); + toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5); delete field; }