Browse Source

parser bugfix

pull/425/head
huanruizhi 3 years ago
parent
commit
781fae2808
5 changed files with 54 additions and 3 deletions
  1. +6
    -2
      parser/tensorflow/tensorflow_parser.cc
  2. +1
    -1
      tests/depends/error_manager/src/error_manager_stub.cc
  3. BIN
      tests/st/testcase/origin_models/test_snapshot.pb
  4. +18
    -0
      tests/st/testcase/test_tensorflow_parser.cc
  5. +29
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 6
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -2563,6 +2563,7 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name];
GE_CHECK_NOTNULL(output_node_def);
auto inputs = output_node_def->mutable_input();
std::vector<std::string> added_inputs;
for (auto &input : *inputs) {
string node_name;
bool is_control = false;
@@ -2596,12 +2597,15 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m
}
}
if (!is_exist_input) {
output_node_def->add_input("^" + item);
GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), item.c_str());
added_inputs.push_back("^" + item);
}
}
}
}
for (std::string added_input : added_inputs) {
GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str());
output_node_def->add_input(added_input);
}
}
// Clear the input of snapshot and become an isolated node
curr_mode_def->clear_input();


+ 1
- 1
tests/depends/error_manager/src/error_manager_stub.cc View File

@@ -98,7 +98,7 @@ void ErrorManager::SetStage(const std::string &first_stage, const std::string &s
}

struct error_message::Context &ErrorManager::GetErrorManagerContext() {
struct error_message::Context error_context;
static struct error_message::Context error_context;
return error_context;
}



BIN
tests/st/testcase/origin_models/test_snapshot.pb View File


+ 18
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -2387,5 +2387,23 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test)
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize);
EXPECT_EQ(ret, ge::PARAM_INVALID);
}
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) {
std::string caseDir = __FILE__;
std::size_t idx = caseDir.find_last_of("/");
caseDir = caseDir.substr(0, idx);
const std::string root_proto = caseDir + "/origin_models/test_snapshot.pb";
domi::tensorflow::GraphDef graphDef;

bool protoRet =
parser::ReadProtoFromBinaryFile(root_proto.c_str(), &graphDef);
ASSERT_EQ(protoRet, true);

TensorFlowModelParser tensorflow_parser;
ge::ComputeGraphPtr root_graph =
ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
Status ret = tensorflow_parser.ParseProto(
reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph);
EXPECT_EQ(FAILED, ret);
}

} // namespace ge

+ 29
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -188,4 +188,33 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) {
ret = TensorFlowModelParser::AddExternalGraph(root_graph);
EXPECT_EQ(ret, INTERNAL_ERROR);
}

TEST_F(UtestTensorflowParser, optimize_snapshot) {
domi::tensorflow::GraphDef graph_def;

auto mul_node = graph_def.add_node();
mul_node->set_name("optimizer/Mul");
mul_node->set_op("Mul");
mul_node->add_input("Snapshot:0");

auto snapshot_node = graph_def.add_node();
snapshot_node->set_name("Snapshot");
snapshot_node->set_op("Snapshot");
snapshot_node->add_input("loss_scale/read:0");
snapshot_node->add_input("^ShuffleNet/AssignMovingAvg");

auto identity_node = graph_def.add_node();
identity_node->set_name("loss_scale/read");
identity_node->set_op("Identity");
identity_node->add_input("loss_scale/ref:0");

auto assign_node = graph_def.add_node();
assign_node->set_name("ShuffleNet/AssignMovingAvg");
assign_node->set_op("AssignSub");
assign_node->add_input("ShuffleNet/moving_mean:0");

Status ret = TensorFlowModelParser().GraphDefOptimize(&graph_def);
EXPECT_EQ(ret, ge::SUCCESS);
}

} // namespace ge

Loading…
Cancel
Save