From 2534c1b474e0d43df51ed655dd0d0058cfb9b37c Mon Sep 17 00:00:00 2001 From: y00500818 Date: Thu, 7 Jan 2021 16:38:29 +0800 Subject: [PATCH] add control edge after remove input --- parser/tensorflow/tensorflow_parser.cc | 42 ++++++++++++++++++++++++++++++++-- parser/tensorflow/tensorflow_parser.h | 10 +++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 3791ace..4304723 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -3221,7 +3221,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap } // 2.4 remove the input const nodes - Status ret = RemoveInputs(current_node, unused_inputs); + Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map); if (ret != SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name}); GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str()); @@ -3232,6 +3232,34 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap return SUCCESS; } +Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *node_def, + const map &all_node_map, + const vector &removed_inputs_vec) { + GE_CHECK_NOTNULL(graph_def); + GE_CHECK_NOTNULL(node_def); + for (const auto &remove_input : removed_inputs_vec) { + string input_node_name = NodeNameFromInput(remove_input); + auto it = all_node_map.find(input_node_name); + if (it == all_node_map.end()) { + GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str()); + return FAILED; + } + NodeDef *input_node_def = it->second; + if (input_node_def->op() == SWITCH || input_node_def->op() == REFSWITCH) { + NodeDef *identity_node_def = graph_def->add_node(); + GE_CHECK_NOTNULL(identity_node_def); + input_node_name = input_node_name + "identity"; + identity_node_def->set_name(input_node_name); + identity_node_def->set_op(IDENTITY); + identity_node_def->add_input(remove_input); + } + string control_input = "^" + input_node_name; + node_def->add_input(control_input); + GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str()); + } + return SUCCESS; +} /** * @ingroup domi_omg * @brief Delete input from nodedef @@ -3241,7 +3269,10 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap * @return false remove failed * */ -Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, const set &remove_index_set) { +Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *node_def, + const set &remove_index_set, + const map &all_node_map) { GE_CHECK_NOTNULL(node_def); if (remove_index_set.empty()) { GELOGI("The size of remove_index_set is zero."); @@ -3258,6 +3289,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, RemoveInputAttr(node_def, remove_inputs_map); int index = 0; + vector removed_inputs_vec; auto *inputs = node_def->mutable_input(); for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) { // 1.decide whether to remove the input @@ -3269,6 +3301,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) { GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); flag = true; + removed_inputs_vec.emplace_back(remove_input_name); break; } } @@ -3281,6 +3314,11 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, } } + Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec); + if (ret != SUCCESS) { + GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str()); + return FAILED; + } return SUCCESS; } diff --git a/parser/tensorflow/tensorflow_parser.h b/parser/tensorflow/tensorflow_parser.h index 4a65659..d281590 100644 --- a/parser/tensorflow/tensorflow_parser.h +++ b/parser/tensorflow/tensorflow_parser.h @@ -537,7 +537,15 @@ class TensorFlowModelParser : public domi::ModelParser { * @return false remove failed * */ - Status RemoveInputs(domi::tensorflow::NodeDef *node_def, const set &remove_index_set); + Status RemoveInputs(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *node_def, + const set &remove_index_set, + const map &all_node_map); + + Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def, + domi::tensorflow::NodeDef *node_def, + const map &all_node_map, + const vector &removed_inputs_vec); void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map> &remove_inputs_map);