@@ -2503,6 +2503,90 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// For the identity operator whose output is "_retval", optimize it. | |||||
Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | |||||
const string &curr_node_name, bool &clear_input_flag) { | |||||
auto context_iter = op_node_context_map_.find(curr_node_name); | |||||
if (context_iter == op_node_context_map_.end()) { | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
GELOGE(FAILED, "Can't find op node context."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
OpNodeContext op_node_context = context_iter->second; | |||||
const std::map<std::string, NodeDef *>::const_iterator node_def_iter = nodedef_map.find(curr_node_name); | |||||
if (node_def_iter == nodedef_map.cend()) { | |||||
REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str()); | |||||
GELOGE(FAILED, "Can't find nodedef"); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; | |||||
GE_CHECK_NOTNULL(curr_node_def); | |||||
bool has_out_retval = false; | |||||
// For the identity operator whose output is "_retval", optimize it | |||||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | |||||
for (auto output_iter = output_map.cbegin(); output_iter != output_map.cend(); ++output_iter) { | |||||
const string &output_node_name = output_iter->first; | |||||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||||
GE_CHECK_NOTNULL(output_node_def); | |||||
if (output_node_def->op() == "_Retval") { | |||||
GELOGW("_Retval Identity need optimize. node:%s", curr_node_name.c_str()); | |||||
output_node_def->set_input(0, curr_node_def->input(0).c_str()); | |||||
has_out_retval = true; | |||||
GELOGW("op %s set input(0):%s.", output_node_def->name().c_str(), curr_node_def->input(0).c_str()); | |||||
} | |||||
} | |||||
// Deal with non _Retval output operator of Identity. | |||||
if (has_out_retval) { | |||||
std::map<std::string, std::vector<std::pair<int32_t, int32_t>>>::const_iterator output_iter = output_map.begin(); | |||||
for (; output_iter != output_map.end(); ++output_iter) { | |||||
const string &output_node_name = output_iter->first; | |||||
GELOGW("[test]node name:%s.", output_node_name.c_str()); | |||||
domi::tensorflow::NodeDef *output_node_def = nodedef_map[output_node_name]; | |||||
GE_CHECK_NOTNULL(output_node_def); | |||||
GELOGW("[test]op name:%s, input size:%u.", output_node_def->op().c_str(), output_node_def->input_size()); | |||||
GE_IF_BOOL_EXEC(output_node_def->op() == "_Retval", continue); | |||||
for (int k = 0; k < output_node_def->input_size(); ++k) { | |||||
GELOGW("[test]input name:%s, curr_node_name:%s.", output_node_def->input(k).c_str(), curr_node_name.c_str()); | |||||
bool is_control = false; | |||||
string node_name; | |||||
GE_RETURN_IF_ERROR(CheckInputNodeName(output_node_def->input(k), &node_name, nullptr, &is_control)); | |||||
GE_IF_BOOL_EXEC( | |||||
node_name == curr_node_name, output_node_def->set_input(k, is_control ? ("^" + curr_node_def->input(0)).c_str() : curr_node_def->input(0).c_str()); | |||||
GELOGW("%s op set input(%d):%s, is_control:%d.", output_node_def->name().c_str(), k, curr_node_def->input(0).c_str(), is_control);) | |||||
} | |||||
} | |||||
clear_input_flag = true; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TensorFlowModelParser::GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, | |||||
map<string, NodeDef *> &nodedef_map, | |||||
const vector<NodeDef *> &nodedef_to_optimize) { | |||||
GE_CHECK_NOTNULL(graph_def); | |||||
if (!nodedef_to_optimize.empty()) { | |||||
// Building input and input relationships for all OP nodes | |||||
GE_RETURN_IF_ERROR(GetOpNodesContextFromGraph(*graph_def)); | |||||
} else { | |||||
return SUCCESS; | |||||
} | |||||
for (auto &curr_node_def : nodedef_to_optimize) { | |||||
GE_CHECK_NOTNULL(curr_node_def); | |||||
bool clear_input_flag = false; | |||||
const string &curr_node_name = curr_node_def->name(); | |||||
GE_RETURN_IF_ERROR(OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag)); | |||||
if (clear_input_flag) { | |||||
GELOGW("[test]node name:%s.", curr_node_name.c_str()); | |||||
curr_node_def->clear_input(); | |||||
} | |||||
} | |||||
GELOGI("GraphDefOptimizeIdentity success."); | |||||
return SUCCESS; | |||||
} | |||||
Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, | Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, | ||||
map<string, NodeDef *> &nodedef_map, | map<string, NodeDef *> &nodedef_map, | ||||
const std::pair<string, int> &input_data, | const std::pair<string, int> &input_data, | ||||
@@ -2818,6 +2902,8 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||||
GE_CHECK_NOTNULL(graph_def); | GE_CHECK_NOTNULL(graph_def); | ||||
map<string, NodeDef *> nodedef_map; | map<string, NodeDef *> nodedef_map; | ||||
vector<string> op_node_name_list; | vector<string> op_node_name_list; | ||||
// Save Identity and ReadVariableOp | |||||
vector<NodeDef *> identity_to_optimize; | |||||
// Save Snapshot | // Save Snapshot | ||||
vector<NodeDef *> snapshot_to_optimize; | vector<NodeDef *> snapshot_to_optimize; | ||||
@@ -2827,12 +2913,16 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||||
const string &node_name = node_def->name(); | const string &node_name = node_def->name(); | ||||
Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); | Status ret = AddFmkNodeDefToMap(node_def, op_node_name_list); | ||||
GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); | GE_CHK_STATUS_EXEC(ret, return PARAM_INVALID, "add node_def to map failed"); | ||||
if (node_def->op() == ge::parser::SNAPSHOT) { | |||||
if (node_def->op() == ge::parser::IDENTITY || node_def->op() == ge::parser::READVARIABLEOP) { | |||||
identity_to_optimize.push_back(node_def); | |||||
} else if (node_def->op() == ge::parser::SNAPSHOT) { | |||||
snapshot_to_optimize.push_back(node_def); | snapshot_to_optimize.push_back(node_def); | ||||
} | } | ||||
nodedef_map[node_name] = node_def; | nodedef_map[node_name] = node_def; | ||||
} | } | ||||
// Optimize for Identity/ReadVariableOp | |||||
GE_RETURN_IF_ERROR(GraphDefOptimizeIdentity(graph_def, nodedef_map, identity_to_optimize)); | |||||
// Optimize for Snapshot | // Optimize for Snapshot | ||||
GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); | GE_RETURN_IF_ERROR(GraphDefOptimizeSnapShot(graph_def, nodedef_map, snapshot_to_optimize)); | ||||
@@ -428,7 +428,28 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
* @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef | * @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef | ||||
*/ | */ | ||||
Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); | Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def); | ||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief Optimize for Identity/ReadVariableOp operator | |||||
* @param [in] graph_def GraphDef to be optimized | |||||
* @param [in] nodedef_map Map of all nodes in graph | |||||
* @param [in] nodedef_to_optimize vector of NodeDef to be optimized | |||||
* @return SUCCESS optimize successfully | |||||
* @return others failed | |||||
*/ | |||||
Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | |||||
const vector<NodeDef *> &nodedef_to_optimize); | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief For the identity operator whose output is "_retval", optimize it. | |||||
* @param [in] nodedef_map Map of all nodes in graph | |||||
* @param [in] curr_node_name Name of node to be optimized | |||||
* @param [in] clear_input_flag Flag of whether to clear the input of the current node | |||||
* @return SUCCESS optimize successfully | |||||
* @return others failed | |||||
*/ | |||||
Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name, | |||||
bool &clear_input_flag); | |||||
Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | ||||
const vector<NodeDef *> &nodedef_to_optimize); | const vector<NodeDef *> &nodedef_to_optimize); | ||||
Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | ||||
@@ -2741,6 +2741,29 @@ TEST_F(STestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||||
model_parser.UpdateEdgesControlInfo(info); | model_parser.UpdateEdgesControlInfo(info); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||||
{ | |||||
TensorFlowModelParser model_parser; | |||||
NodeDef *node_def = new NodeDef(); | |||||
node_def->set_name("Placeholder"); | |||||
node_def->set_op("Placeholder_0"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("Placeholder", node_def); | |||||
std::string curr_node_name = "Placeholder"; | |||||
bool clear_input_flag = true; | |||||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
GraphDef graph; | |||||
curr_node_name = "pre_node_a"; | |||||
nodedef_map.emplace("pre_node_a", node_def); | |||||
node_def->set_op("pre_node_a"); | |||||
GenOriginContext(&model_parser, curr_node_name); | |||||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
delete node_def; | |||||
} | |||||
TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) | TEST_F(STestTensorflowParser, tensorflow_OptimizeSnapShot_test) | ||||
{ | { | ||||
TensorFlowModelParser model_parser; | TensorFlowModelParser model_parser; | ||||
@@ -2912,6 +2935,25 @@ TEST_F(STestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
TEST_F(STestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
{ | |||||
tensorflow::GraphDef graph_def; | |||||
TensorFlowModelParser tensorflow_parser; | |||||
tensorflow::NodeDef *node_def = initNodeDef(); | |||||
node_def->set_name("post_node_d"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("post_node_d", node_def); | |||||
nodedef_map.emplace("post_node_a", node_def); | |||||
nodedef_map.emplace("post_node_b", node_def); | |||||
std::vector<NodeDef *> nodedef_to_optimize; | |||||
nodedef_to_optimize.emplace_back(node_def); | |||||
std::string curr_node_name = "post_node_b"; | |||||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} | |||||
TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | TEST_F(STestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | ||||
std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||
std::size_t idx = caseDir.find_last_of("/"); | std::size_t idx = caseDir.find_last_of("/"); | ||||
@@ -2853,6 +2853,29 @@ TEST_F(UtestTensorflowParser, tensorflow_UpdateEdgesControlInfo_test) | |||||
model_parser.UpdateEdgesControlInfo(info); | model_parser.UpdateEdgesControlInfo(info); | ||||
} | } | ||||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeIdentityByOutput_test) | |||||
{ | |||||
TensorFlowModelParser model_parser; | |||||
NodeDef *node_def = new NodeDef(); | |||||
node_def->set_name("Placeholder"); | |||||
node_def->set_op("Placeholder_0"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("Placeholder", node_def); | |||||
std::string curr_node_name = "Placeholder"; | |||||
bool clear_input_flag = true; | |||||
Status ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, INTERNAL_ERROR); | |||||
GraphDef graph; | |||||
curr_node_name = "pre_node_a"; | |||||
nodedef_map.emplace("pre_node_a", node_def); | |||||
node_def->set_op("pre_node_a"); | |||||
GenOriginContext(&model_parser, curr_node_name); | |||||
ret = model_parser.OptimizeIdentityByOutput(nodedef_map, curr_node_name, clear_input_flag); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
delete node_def; | |||||
} | |||||
TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) | TEST_F(UtestTensorflowParser, tensorflow_OptimizeSnapShot_test) | ||||
{ | { | ||||
TensorFlowModelParser model_parser; | TensorFlowModelParser model_parser; | ||||
@@ -3024,7 +3047,25 @@ TEST_F(UtestTensorflowParser, tensorflow_AddControlEdgeAfterRemoveInputs_test) | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
} | } | ||||
TEST_F(UtestTensorflowParser, tensorflow_GraphDefOptimizeIdentity_test) | |||||
{ | |||||
tensorflow::GraphDef graph_def; | |||||
TensorFlowModelParser tensorflow_parser; | |||||
tensorflow::NodeDef *node_def = initNodeDef(); | |||||
node_def->set_name("post_node_d"); | |||||
std::map<string, NodeDef *> nodedef_map; | |||||
nodedef_map.emplace("post_node_d", node_def); | |||||
nodedef_map.emplace("post_node_a", node_def); | |||||
nodedef_map.emplace("post_node_b", node_def); | |||||
std::vector<NodeDef *> nodedef_to_optimize; | |||||
nodedef_to_optimize.emplace_back(node_def); | |||||
std::string curr_node_name = "post_node_b"; | |||||
GenOriginContext(&tensorflow_parser, curr_node_name); | |||||
Status ret = tensorflow_parser.GraphDefOptimizeIdentity(&graph_def, nodedef_map, nodedef_to_optimize); | |||||
EXPECT_EQ(ret, ge::PARAM_INVALID); | |||||
} | |||||
TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | TEST_F(UtestTensorflowParser, tensorflow_optimizer_snapshot_no_retval_test) { | ||||
std::string caseDir = __FILE__; | std::string caseDir = __FILE__; | ||||
std::size_t idx = caseDir.find_last_of("/"); | std::size_t idx = caseDir.find_last_of("/"); | ||||