Browse Source

!202 add control edge after remove input

Merge pull request !202 from yangyongqiang/dev_tensorflow_parser
pull/202/MERGE
i-robot Gitee 4 years ago
parent
commit
2c18ba6530
2 changed files with 49 additions and 3 deletions
  1. +40
    -2
      parser/tensorflow/tensorflow_parser.cc
  2. +9
    -1
      parser/tensorflow/tensorflow_parser.h

+ 40
- 2
parser/tensorflow/tensorflow_parser.cc View File

@@ -3221,7 +3221,7 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
}

// 2.4 remove the input const nodes
Status ret = RemoveInputs(current_node, unused_inputs);
Status ret = RemoveInputs(graph_def, current_node, unused_inputs, all_nodedef_map);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E12006", {"opname"}, {current_op_name});
GELOGE(INTERNAL_ERROR, "Op[%s] remove input failed.", current_op_name.c_str());
@@ -3232,6 +3232,34 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
return SUCCESS;
}

Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def,
const map<string, NodeDef *> &all_node_map,
const vector<string> &removed_inputs_vec) {
GE_CHECK_NOTNULL(graph_def);
GE_CHECK_NOTNULL(node_def);
for (const auto &remove_input : removed_inputs_vec) {
string input_node_name = NodeNameFromInput(remove_input);
auto it = all_node_map.find(input_node_name);
if (it == all_node_map.end()) {
GELOGE(FAILED, "Can not find node name:%s in all node map.", input_node_name.c_str());
return FAILED;
}
NodeDef *input_node_def = it->second;
if (input_node_def->op() == SWITCH || input_node_def->op() == REFSWITCH) {
NodeDef *identity_node_def = graph_def->add_node();
GE_CHECK_NOTNULL(identity_node_def);
input_node_name = input_node_name + "identity";
identity_node_def->set_name(input_node_name);
identity_node_def->set_op(IDENTITY);
identity_node_def->add_input(remove_input);
}
string control_input = "^" + input_node_name;
node_def->add_input(control_input);
GELOGD("Add control input:%s for node:%s", control_input.c_str(), node_def->name().c_str());
}
return SUCCESS;
}
/**
* @ingroup domi_omg
* @brief Delete input from nodedef
@@ -3241,7 +3269,10 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap
* @return false remove failed
*
*/
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set) {
Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def,
const set<uint32_t> &remove_index_set,
const map<string, NodeDef *> &all_node_map) {
GE_CHECK_NOTNULL(node_def);
if (remove_index_set.empty()) {
GELOGI("The size of remove_index_set is zero.");
@@ -3258,6 +3289,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def,
RemoveInputAttr(node_def, remove_inputs_map);

int index = 0;
vector<string> removed_inputs_vec;
auto *inputs = node_def->mutable_input();
for (auto input_it = inputs->begin(); input_it != inputs->end(); ++index) {
// 1.decide whether to remove the input
@@ -3269,6 +3301,7 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def,
std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index);
flag = true;
removed_inputs_vec.emplace_back(remove_input_name);
break;
}
}
@@ -3281,6 +3314,11 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::NodeDef *node_def,
}
}

Status ret = AddControlEdgeAfterRemoveInputs(graph_def, node_def, all_node_map, removed_inputs_vec);
if (ret != SUCCESS) {
GELOGE(FAILED, "Add control edges for node:%s failed.", node_def->name().c_str());
return FAILED;
}
return SUCCESS;
}



+ 9
- 1
parser/tensorflow/tensorflow_parser.h View File

@@ -537,7 +537,15 @@ class TensorFlowModelParser : public domi::ModelParser {
* @return false remove failed
*
*/
Status RemoveInputs(domi::tensorflow::NodeDef *node_def, const set<uint32_t> &remove_index_set);
Status RemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def,
const set<uint32_t> &remove_index_set,
const map<string, NodeDef *> &all_node_map);

Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
domi::tensorflow::NodeDef *node_def,
const map<string, NodeDef *> &all_node_map,
const vector<string> &removed_inputs_vec);

void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map);



Loading…
Cancel
Save