Compare commits

...

12 Commits

Author SHA1 Message Date
  i-robot 25b8a1c4dc
!523 code review 3 years ago
  wangzhengjun c2839b8162 fixed cdb84d0 from https://gitee.com/wangzhengjun3/parser/pulls/516 3 years ago
  i-robot 6e91b5f24d
!507 delete graph valid check 3 years ago
  isaacxr 413f0b0ee5 delete graph valid check 3 years ago
  i-robot e7a658ea14
!502 cherry pick主线提高depth的阈值 3 years ago
  gengchao 6436b4a144 增加递归深度阈值 3 years ago
  i-robot ef7237b85c
!494 cherry pick主线安全检视代码 3 years ago
  gengchao 725918958a 安全代码检视 3 years ago
  王涛 bbf10ec26b
update OWNERS. 3 years ago
  i-robot 049ae100a3
!483 告警清除 3 years ago
  zhao-lupeng de7bdecd28 fix sc 3 years ago
  王涛 6a4107de6f
update .gitmodules. 3 years ago
11 changed files with 48 additions and 82 deletions
Split View
  1. +1
    -1
      .gitmodules
  2. +5
    -3
      OWNERS
  3. +14
    -8
      parser/common/convert/pb2json.cc
  4. +3
    -3
      parser/common/convert/pb2json.h
  5. +13
    -13
      parser/func_to_graph/func2graph.py
  6. +3
    -0
      parser/tensorflow/graph_optimizer.cc
  7. +2
    -2
      parser/tensorflow/tensorflow_data_parser.cc
  8. +1
    -39
      parser/tensorflow/tensorflow_parser.cc
  9. +0
    -9
      parser/tensorflow/tensorflow_parser.h
  10. +3
    -2
      tests/st/testcase/test_tensorflow_parser.cc
  11. +3
    -2
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
.gitmodules View File

@@ -1,4 +1,4 @@
[submodule "metadef"]
path = metadef
url = https://gitee.com/ascend/metadef.git
branch = master
branch = r1.8.0

+ 5
- 3
OWNERS View File

@@ -1,8 +1,10 @@
approvers:
- ji_chen
- wqtshg
- ljl0711
- liu-jisheng
- startzgf168
- andylhy
- liyihan123
reviewers:
- xchu42
- sheng-nan
- wqtshg
- liu-jisheng

+ 14
- 8
parser/common/convert/pb2json.cc View File

@@ -31,11 +31,17 @@ using std::string;
namespace ge {
namespace {
const int kSignificantDigits = 10;
const int kMaxParseDepth = 20;
}
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
if (depth > kMaxParseDepth) {
REPORT_INNER_ERROR("E19999", "Message depth:%d can not exceed %d.", depth, kMaxParseDepth);
GELOGE(FAILED, "[Check][Param]Message depth can not exceed %d.", kMaxParseDepth);
return;
}
auto descriptor = message.GetDescriptor();
auto reflection = message.GetReflection();
if (descriptor == nullptr || reflection == nullptr) {
@@ -57,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons

if (field->is_repeated()) {
if (reflection->FieldSize(message, field) > 0) {
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str);
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str, depth);
}
continue;
}
@@ -66,18 +72,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons
continue;
}

OneField2Json(message, field, reflection, black_fields, json, enum2str);
OneField2Json(message, field, reflection, black_fields, json, enum2str, depth);
}
}

void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
if (0UL != tmp_message.ByteSizeLong()) {
Message2Json(tmp_message, black_fields, json[field->name()], enum2str);
Message2Json(tmp_message, black_fields, json[field->name()], enum2str, depth + 1);
}
break;
}
@@ -163,9 +169,9 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {

void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
bool enum2str, int depth) {
if ((field == nullptr) || (reflection == nullptr)) {
Message2Json(message, black_fields, json, enum2str);
Message2Json(message, black_fields, json, enum2str, depth + 1);
return;
}

@@ -175,7 +181,7 @@ void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFie
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i);
if (0UL != tmp_message.ByteSizeLong()) {
Message2Json(tmp_message, black_fields, tmp_json, enum2str);
Message2Json(tmp_message, black_fields, tmp_json, enum2str, depth + 1);
}
} break;



+ 3
- 3
parser/common/convert/pb2json.h View File

@@ -45,11 +45,11 @@ class Pb2Json {
* @author
*/
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json,
bool enum2str = false);
bool enum2str = false, int depth = 0);

static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);
Json &json, bool enum2str, int depth = 0);

protected:
static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
@@ -59,7 +59,7 @@ class Pb2Json {

static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json,
bool enum2str);
bool enum2str, int depth);

static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes);
};


+ 13
- 13
parser/func_to_graph/func2graph.py View File

@@ -227,9 +227,9 @@ def convert_subgraphs(graph_def, filename):
print(graph_def_library.graph_def[i])

# Write to prototxt
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
print("graph_def_file: ", graph_def_file)
try:
graph_def_file = '{}/graph_def_library.pbtxt'.format(os.path.dirname(os.path.abspath(filename)))
print("graph_def_file: ", graph_def_file)
with open(graph_def_file, "w") as f:
print(graph_def_library, file=f)
except IOError:
@@ -261,18 +261,18 @@ if __name__ == '__main__':
model = ''
try:
opts, args = getopt.getopt(sys.argv[1:], '-v-h-m:', ['version', 'help', 'model='])
for opt_name, opt_value in opts:
if opt_name in ('-m', '--model'):
model = opt_value
print("INFO: Input model file is", model)
convert_graphs(model)
elif opt_name in ('-h', '--help'):
usage()
break
elif opt_name in ('-v', '--version'):
print("version 1.0.0")
break
except getopt.GetoptError:
print("ERROR: Input parameters is invalid, use '--help' to view the help.")
for opt_name, opt_value in opts:
if opt_name in ('-m', '--model'):
model = opt_value
print("INFO: Input model file is", model)
convert_graphs(model)
elif opt_name in ('-h', '--help'):
usage()
break
elif opt_name in ('-v', '--version'):
print("version 1.0.0")
break
if len(sys.argv) == 1:
print("INFO: Please specify the input parameters, and use '--help' to view the help.")

+ 3
- 0
parser/tensorflow/graph_optimizer.cc View File

@@ -93,6 +93,7 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const boo
GE_CHECK_NOTNULL(graph_);
for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
string type;
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
@@ -284,6 +285,7 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge
GE_CHECK_NOTNULL(node);
OpDescPtr op_def = node->GetOpDesc();
NodePtr new_node = sub_graph->AddNode(op_def);
GE_CHECK_NOTNULL(new_node);
node_map[node->GetName()] = new_node;

// Input
@@ -440,6 +442,7 @@ Status ParserGraphOptimizer::RebuildFusionNode(vector<ge::InDataAnchorPtr> &inpu
vector<ge::InControlAnchorPtr> &input_control_anchors,
vector<ge::OutControlAnchorPtr> &output_control_anchors,
ge::NodePtr fusion_node) {
GE_CHECK_NOTNULL(fusion_node);
int32_t src_index = 0;

for (auto out_anchor : output_anchors) {


+ 2
- 2
parser/tensorflow/tensorflow_data_parser.cc View File

@@ -110,7 +110,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge:
std::string name = op_def->GetName();
if (input_dims.count(name) == 0) {
GELOGI("input shape of node %s is not designated ,need parse from model", name.c_str());
for (uint32_t i = 0; i < model_input_dims_v.size(); i++) {
for (size_t i = 0; i < model_input_dims_v.size(); ++i) {
user_input_dims_v.push_back(model_input_dims_v[i]);
}

@@ -138,7 +138,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge:
}

Status TensorFlowDataParser::CheckInputShape(const std::string &name) {
for (uint32_t i = 0; i < user_input_dims_v.size(); i++) {
for (size_t i = 0; i < user_input_dims_v.size(); ++i) {
// if input_shape has some placeholders, user should designate them.
// dim i = 0, means empty tensor.
// dim i = -1 or -2, means unknown shape.


+ 1
- 39
parser/tensorflow/tensorflow_parser.cc View File

@@ -1237,9 +1237,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
// This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));

// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));

// Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF ParseFromMemory] get op nodes context from graph success");
@@ -1472,10 +1469,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
// This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));

// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));
GELOGD("[TF Parse] check graph success");

// Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF Parse] get op nodes context from graph success");
@@ -1548,37 +1541,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
return SUCCESS;
}

Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const {
// Number of data nodes
uint32_t data_node_count = 0;
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
// Check that all input is valid
for (const string &node_name : node_def.input()) {
string tmp_node_name;
GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr));

if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"},
{node_def.name(), node_name});
GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(),
node_name.c_str());
return INTERNAL_ERROR;
}
}

if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) {
data_node_count++;
}
}
if (data_node_count == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E12010");
GELOGE(INTERNAL_ERROR, "Model has no Placeholder node.");
return INTERNAL_ERROR;
}

return SUCCESS;
}

Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) {
// Build the input relationship first
for (auto &iter : op_node_context_map_) {
@@ -2828,7 +2790,7 @@ Status EraseTransposeNode(std::map<std::string, std::string> &softmaxInfo,
itTranspose->second.node_def->input(0).c_str());
itTranspose = transposeInfo.erase(itTranspose);
} else {
itTranspose++;
++itTranspose;
}
}



+ 0
- 9
parser/tensorflow/tensorflow_parser.h View File

@@ -241,15 +241,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {

/**
* @ingroup domi_omg
* @brief Verifying the validity of graphdef object parsed by pb
* @param [in] graph_def Parsed tensorflow:: graphdef object
* @return SUCCESS check successfully
* @return FAILED check failed
*/
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const;

/**
* @ingroup domi_omg
* @brief whether const OP need to update context
* @param const op name
* @return true or false


+ 3
- 2
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1259,7 +1259,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
}

TEST_F(STestTensorflowParser, test_parse_acl_output_nodes)
@@ -3534,7 +3534,8 @@ TEST_F(STestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test)
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM);
mess2Op.ParseField(reflection, node_def, field, depth, ops);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5);
delete field;
}



+ 3
- 2
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -1419,7 +1419,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
}

TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes)
@@ -3696,7 +3696,8 @@ TEST_F(UtestTensorflowParser, tensorflow_Pb2Json_OneField2Json_test)
ge::Operator ops = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc);
field->CppTypeName(google::protobuf::FieldDescriptor::CPPTYPE_ENUM);
mess2Op.ParseField(reflection, node_def, field, depth, ops);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 1);
toJson.OneField2Json((*node_def), field, reflection, black_fields, json, enum2str, 5);
delete field;
}



Loading…
Cancel
Save