Browse Source

!503 aclgrphparser support zero input

Merge pull request !503 from 徐睿/ge_dev
pull/504/head
计晨 Gitee 3 years ago
parent
commit
6035824cee
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 2 additions and 49 deletions
  1. +0
    -38
      parser/tensorflow/tensorflow_parser.cc
  2. +0
    -9
      parser/tensorflow/tensorflow_parser.h
  3. +1
    -1
      tests/st/testcase/test_tensorflow_parser.cc
  4. +1
    -1
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 0
- 38
parser/tensorflow/tensorflow_parser.cc View File

@@ -1237,9 +1237,6 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g
// This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));

// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));

// Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF ParseFromMemory] get op nodes context from graph success");
@@ -1472,10 +1469,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
// This function call affects the return value of prechecker::instance().Haserror()
GE_RETURN_IF_ERROR(ClearFusionOpError(op_node_name_list));

// Check the input validity of the node, the input attribute must have a corresponding node
GE_RETURN_IF_ERROR(CheckGraphDefValid(graph_def));
GELOGD("[TF Parse] check graph success");

// Building input and input relationships for all OP nodes
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(graph_def));
GELOGD("[TF Parse] get op nodes context from graph success");
@@ -1548,37 +1541,6 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro
return SUCCESS;
}

Status TensorFlowModelParser::CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const {
// Number of data nodes
uint32_t data_node_count = 0;
for (const domi::tensorflow::NodeDef &node_def : graph_def.node()) {
// Check that all input is valid
for (const string &node_name : node_def.input()) {
string tmp_node_name;
GE_RETURN_IF_ERROR(CheckInputNodeName(node_name, &tmp_node_name, nullptr, nullptr));

if (nodedef_map_.find(tmp_node_name) == nodedef_map_.end()) {
ErrorManager::GetInstance().ATCReportErrMessage("E12009", {"opname", "inputopname"},
{node_def.name(), node_name});
GELOGE(INTERNAL_ERROR, "Op[%s]'s input op[%s] is not exist in the graph_def.", node_def.name().c_str(),
node_name.c_str());
return INTERNAL_ERROR;
}
}

if (node_def.op() == TENSORFLOWF_NODE_OP_PLACEHOLDER || node_def.op() == ge::parser::ARG) {
data_node_count++;
}
}
if (data_node_count == 0) {
ErrorManager::GetInstance().ATCReportErrMessage("E12010");
GELOGE(INTERNAL_ERROR, "Model has no Placeholder node.");
return INTERNAL_ERROR;
}

return SUCCESS;
}

Status TensorFlowModelParser::GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def) {
// Build the input relationship first
for (auto &iter : op_node_context_map_) {


+ 0
- 9
parser/tensorflow/tensorflow_parser.h View File

@@ -241,15 +241,6 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {

/**
* @ingroup domi_omg
* @brief Verifying the validity of graphdef object parsed by pb
* @param [in] graph_def Parsed tensorflow:: graphdef object
* @return SUCCESS check successfully
* @return FAILED check failed
*/
Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def) const;

/**
* @ingroup domi_omg
* @brief whether const OP need to update context
* @param const op name
* @return true or false


+ 1
- 1
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -1259,7 +1259,7 @@ TEST_F(STestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
}

TEST_F(STestTensorflowParser, test_parse_acl_output_nodes)


+ 1
- 1
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -1419,7 +1419,7 @@ TEST_F(UtestTensorflowParser, tensorflow_parserAllGraph_failed)
ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph);
TensorFlowModelParser tensorflow_parser;
ret = tensorflow_parser.ParseAllGraph(reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(INTERNAL_ERROR, ret);
ASSERT_NE(ret, SUCCESS);
}

TEST_F(UtestTensorflowParser, test_parse_acl_output_nodes)


Loading…
Cancel
Save