diff --git a/parser/tensorflow/tensorflow_parser.cc b/parser/tensorflow/tensorflow_parser.cc index 568d8a1..8fde8fc 100644 --- a/parser/tensorflow/tensorflow_parser.cc +++ b/parser/tensorflow/tensorflow_parser.cc @@ -2109,9 +2109,12 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap(const string &node_name, std::set compare_set; for (auto &pair : pairs) { + bool is_fusion_child = (fusion_op_children_.find(node_name) != fusion_op_children_.end()) || + (fusion_op_children_.find(iter->first) != fusion_op_children_.end()); + bool is_fusion_op = (fusion_op_type_map_.find(node_name) != fusion_op_type_map_.end()) || + (fusion_op_type_map_.find(iter->first) != fusion_op_type_map_.end()); if (((pair.first == ge::kFusionDisableIndex) || (pair.second == ge::kFusionDisableIndex)) && - ((fusion_op_children_.find(node_name) != fusion_op_children_.end()) || - (fusion_op_children_.find(iter->first) != fusion_op_children_.end()))) { + (is_fusion_child || is_fusion_op)) { // The edge will be cut off at the back, ignoring continue; }