@@ -2563,6 +2563,7 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m | |||||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | ||||
GE_CHECK_NOTNULL(output_node_def); | GE_CHECK_NOTNULL(output_node_def); | ||||
auto inputs = output_node_def->mutable_input(); | auto inputs = output_node_def->mutable_input(); | ||||
std::vector<std::string> added_inputs; | |||||
for (auto &input : *inputs) { | for (auto &input : *inputs) { | ||||
string node_name; | string node_name; | ||||
bool is_control = false; | bool is_control = false; | ||||
@@ -2596,12 +2597,15 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m | |||||
} | } | ||||
} | } | ||||
if (!is_exist_input) { | if (!is_exist_input) { | ||||
output_node_def->add_input("^" + item); | |||||
GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), item.c_str()); | |||||
added_inputs.push_back("^" + item); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
for (std::string added_input : added_inputs) { | |||||
GELOGD("Optimize Snapshot node, dest:%s, set control input:%s.", output_node_name.c_str(), added_input.c_str()); | |||||
output_node_def->add_input(added_input); | |||||
} | |||||
} | } | ||||
// Clear the input of snapshot and become an isolated node | // Clear the input of snapshot and become an isolated node | ||||
curr_mode_def->clear_input(); | curr_mode_def->clear_input(); | ||||
@@ -98,7 +98,7 @@ void ErrorManager::SetStage(const std::string &first_stage, const std::string &s | |||||
} | } | ||||
struct error_message::Context &ErrorManager::GetErrorManagerContext() { | struct error_message::Context &ErrorManager::GetErrorManagerContext() { | ||||
struct error_message::Context error_context; | |||||
static struct error_message::Context error_context; | |||||
return error_context; | return error_context; | ||||
} | } | ||||
@@ -2387,5 +2387,23 @@ TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | ||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | EXPECT_EQ(ret, ge::PARAM_INVALID); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | |||||
std::string caseDir = __FILE__; | |||||
std::size_t idx = caseDir.find_last_of("/"); | |||||
caseDir = caseDir.substr(0, idx); | |||||
const std::string root_proto = caseDir + "/origin_models/test_snapshot.pb"; | |||||
domi::tensorflow::GraphDef graphDef; | |||||
bool protoRet = | |||||
parser::ReadProtoFromBinaryFile(root_proto.c_str(), &graphDef); | |||||
ASSERT_EQ(protoRet, true); | |||||
TensorFlowModelParser tensorflow_parser; | |||||
ge::ComputeGraphPtr root_graph = | |||||
ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph"); | |||||
Status ret = tensorflow_parser.ParseProto( | |||||
reinterpret_cast<google::protobuf::Message *>(&graphDef), root_graph); | |||||
EXPECT_EQ(FAILED, ret); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -188,4 +188,33 @@ TEST_F(UtestTensorflowParser, tensorflow_parser_with_external_graph) { | |||||
ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ret = TensorFlowModelParser::AddExternalGraph(root_graph); | ||||
EXPECT_EQ(ret, INTERNAL_ERROR); | EXPECT_EQ(ret, INTERNAL_ERROR); | ||||
} | } | ||||
TEST_F(UtestTensorflowParser, optimize_snapshot) { | |||||
domi::tensorflow::GraphDef graph_def; | |||||
auto mul_node = graph_def.add_node(); | |||||
mul_node->set_name("optimizer/Mul"); | |||||
mul_node->set_op("Mul"); | |||||
mul_node->add_input("Snapshot:0"); | |||||
auto snapshot_node = graph_def.add_node(); | |||||
snapshot_node->set_name("Snapshot"); | |||||
snapshot_node->set_op("Snapshot"); | |||||
snapshot_node->add_input("loss_scale/read:0"); | |||||
snapshot_node->add_input("^ShuffleNet/AssignMovingAvg"); | |||||
auto identity_node = graph_def.add_node(); | |||||
identity_node->set_name("loss_scale/read"); | |||||
identity_node->set_op("Identity"); | |||||
identity_node->add_input("loss_scale/ref:0"); | |||||
auto assign_node = graph_def.add_node(); | |||||
assign_node->set_name("ShuffleNet/AssignMovingAvg"); | |||||
assign_node->set_op("AssignSub"); | |||||
assign_node->add_input("ShuffleNet/moving_mean:0"); | |||||
Status ret = TensorFlowModelParser().GraphDefOptimize(&graph_def); | |||||
EXPECT_EQ(ret, ge::SUCCESS); | |||||
} | |||||
} // namespace ge | } // namespace ge |