diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 23300b7..5afce4c 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -2563,6 +2563,7 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; GE_CHECK_NOTNULL(output_node_def); auto inputs = output_node_def->mutable_input(); + std::vector added_inputs; for (auto &input : *inputs) { string node_name; bool is_control = false; @@ -2596,12 +2597,15 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m } } if (!is_exist_input) { - output_node_def->add_input("^" + item); - GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), item.c_str()); + added_inputs.push_back("^" + item); } } } } + for (std::string added_input : added_inputs) { + GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str()); + output_node_def->add_input(added_input); + } } // Clear the input of snapshot and become an isolated node curr_mode_def->clear_input(); diff --git a/tests/depends/error_manager/src/error_manager_stub.cc b/tests/depends/error_manager/src/error_manager_stub.cc index 1fba95a..900d595 100644 --- a/tests/depends/error_manager/src/error_manager_stub.cc +++ b/tests/depends/error_manager/src/error_manager_stub.cc @@ -98,7 +98,7 @@ void ErrorManager::SetStage(const std::string &first_stage, const std::string &s } struct error_message::Context &ErrorManager::GetErrorManagerContext() { - struct error_message::Context error_context; + static struct error_message::Context error_context; return error_context; } diff --git a/tests/st/testcase/origin_models/test_snapshot.pb b/tests/st/testcase/origin_models/test_snapshot.pb new file mode 100644 index 0000000..03c4f22 Binary files /dev/null and b/tests/st/testcase/origin_models/test_snapshot.pb differ diff --git a/tests/st/testcase/test_tensorflow_parser.cc b/tests/st/testcase/test_tensorflow_parser.cc index 09567ff..ae0ca2b 100644 --- a/tests/st/testcase/test_tensorflow_parser.cc +++ b/tests/st/testcase/test_tensorflow_parser.cc @@ -2387,5 +2387,23 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); EXPECT_EQ(ret, ge::PARAM_INVALID); } +TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { + std::string caseDir = __FILE__; + std::size_t idx = caseDir.find_last_of("/"); + caseDir = caseDir.substr(0, idx); + const std::string root_proto = caseDir + "/origin_models/test_snapshot.pb"; + domi::tensorflow::GraphDef graphDef; + + bool protoRet = + parser::ReadProtoFromBinaryFile(root_proto.c_str(), &graphDef); + ASSERT_EQ(protoRet, true); + + TensorFlowModelParser tensorflow_parser; + ge::ComputeGraphPtr root_graph = + ge::parser::MakeShared("tmp_graph"); + Status ret = tensorflow_parser.ParseProto( + reinterpret_cast(&graphDef), root_graph); + EXPECT_EQ(FAILED, ret); +} } // namespace ge diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc index cd2a59b..fb45409 100644 --- a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -188,4 +188,33 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { ret = TensorFlowModelParser::AddExternalGraph(root_graph); EXPECT_EQ(ret, INTERNAL_ERROR); } + +TEST_F(UtestTensorflowParser, optimize_snapshot) { + domi::tensorflow::GraphDef graph_def; + + auto mul_node = graph_def.add_node(); + mul_node->set_name("optimizer/Mul"); + mul_node->set_op("Mul"); + mul_node->add_input("Snapshot:0"); + + auto snapshot_node = graph_def.add_node(); + snapshot_node->set_name("Snapshot"); + snapshot_node->set_op("Snapshot"); + snapshot_node->add_input("loss_scale/read:0"); + snapshot_node->add_input("^ShuffleNet/AssignMovingAvg"); + + auto identity_node = graph_def.add_node(); + identity_node->set_name("loss_scale/read"); + identity_node->set_op("Identity"); + identity_node->add_input("loss_scale/ref:0"); + + auto assign_node = graph_def.add_node(); + assign_node->set_name("ShuffleNet/AssignMovingAvg"); + assign_node->set_op("AssignSub"); + assign_node->add_input("ShuffleNet/moving_mean:0"); + + Status ret = TensorFlowModelParser().GraphDefOptimize(&graph_def); + EXPECT_EQ(ret, ge::SUCCESS); +} + } // namespace ge