Browse Source

Pre Merge pull request !613 from 梁昊/lh

pull/613/MERGE
梁昊 Gitee 2 years ago
parent
commit
db754d9bcc
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 6 additions and 5 deletions
  1. +6
    -5
      parser/tensorflow/tensorflow_parser.cc

+ 6
- 5
parser/tensorflow/tensorflow_parser.cc View File

@@ -3373,7 +3373,7 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow::
return FAILED; return FAILED;
} }
NodeDef *input_node_def = it->second; NodeDef *input_node_def = it->second;
if (input_node_def->op() == parser::SWITCH || input_node_def->op() == parser::REFSWITCH) {
if ((input_node_def->op() == parser::SWITCH) || (input_node_def->op() == parser::REFSWITCH)) {
NodeDef *identity_node_def = graph_def->add_node(); NodeDef *identity_node_def = graph_def->add_node();
GE_CHECK_NOTNULL(identity_node_def); GE_CHECK_NOTNULL(identity_node_def);
std::string remove_input_name = remove_input; std::string remove_input_name = remove_input;
@@ -3426,8 +3426,8 @@ Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def
for (auto &remove_input : remove_inputs_map) { for (auto &remove_input : remove_inputs_map) {
string remove_input_name = remove_input.first; string remove_input_name = remove_input.first;
vector<int> remove_input_indexs = remove_input.second; vector<int> remove_input_indexs = remove_input.second;
if ((*input_it) == remove_input_name &&
std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end()) {
if (((*input_it) == remove_input_name) &&
(std::find(remove_input_indexs.begin(), remove_input_indexs.end(), index) != remove_input_indexs.end())) {
GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index); GELOGD("Remove input:%s, index:%d", remove_input_name.c_str(), index);
flag = true; flag = true;
removed_inputs_vec.emplace_back(remove_input_name); removed_inputs_vec.emplace_back(remove_input_name);
@@ -3481,7 +3481,7 @@ void TensorFlowModelParser::RemoveInputAttr(domi::tensorflow::NodeDef *node_def,


if (flag) { if (flag) {
// 2.1 remove the input attr // 2.1 remove the input attr
if (!tmp_attr->empty() && attr_it != tmp_attr->end()) {
if (!tmp_attr->empty() && (attr_it != tmp_attr->end())) {
attr_it = tmp_attr->erase(attr_it); attr_it = tmp_attr->erase(attr_it);
} else { } else {
++attr_it; ++attr_it;
@@ -3990,7 +3990,7 @@ Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping
Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph) { Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph) {
GE_CHECK_NOTNULL(root_graph); GE_CHECK_NOTNULL(root_graph);
for (const NodePtr &node : root_graph->GetAllNodes()) { for (const NodePtr &node : root_graph->GetAllNodes()) {
if (node == nullptr || node->GetOpDesc() == nullptr) {
if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
continue; continue;
} }
std::string model_data; std::string model_data;
@@ -4010,6 +4010,7 @@ Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph
REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str()); REPORT_CALL_ERROR("E19999", "Failed to map and add sub graph, node:%s.", node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
(void)node->GetOpDesc()->DelAttr(kExternalModel);
} }
} }
return SUCCESS; return SUCCESS;


Loading…
Cancel
Save