From 229e9159d9f52566705b52cfadcea40766b0d4cc Mon Sep 17 00:00:00 2001 From: wangzhengjun Date: Tue, 11 Jan 2022 14:30:17 +0800 Subject: [PATCH] DynamicGetNext not fusion --- parser/tensorflow/graph_optimizer.cc | 25 ++++++++++++++++++------- parser/tensorflow/graph_optimizer.h | 3 ++- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/parser/tensorflow/graph_optimizer.cc b/parser/tensorflow/graph_optimizer.cc index 829b576..9769b2a 100644 --- a/parser/tensorflow/graph_optimizer.cc +++ b/parser/tensorflow/graph_optimizer.cc @@ -37,6 +37,10 @@ namespace { const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; const char *const kShapeNodeType = "Shape"; const char *const kShapeNodeNamePrefix = "getnext_shape_"; +const char *const kIteratorType = "Iterator"; +const char *const kIteratorV2Type = "IteratorV2"; +const char *const kGetNextType = "IteratorGetNext"; +const char *const kDynGetNextType = "DynamicGetNext"; } // namespace Status ParserGraphOptimizer::FusionFmkop() { @@ -66,28 +70,33 @@ Status ParserGraphOptimizer::FusionFmkop() { Status ParserGraphOptimizer::MarkForFusion(unordered_map> &node_cluster_map) { GE_CHECK_NOTNULL(graph_); bool has_get_next = false; + bool has_dyn_get_next = false; for (auto node : graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node); + if (node->GetType() == kDynGetNextType) { + has_dyn_get_next = true; + break; + } GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); - string type = ""; + string type; GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); - if (type == "IteratorGetNext") { + if (type == kGetNextType) { has_get_next = true; break; } } - return GetFusionCluster(has_get_next, node_cluster_map); + return GetFusionCluster(has_get_next, has_dyn_get_next, node_cluster_map); } -Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, +Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next, unordered_map> &node_cluster_map) { GE_CHECK_NOTNULL(graph_); for (auto node : graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node); GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) - string type = ""; + string type; GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); - if (type == "IteratorGetNext") { + if (type == kGetNextType) { vector temp_node_cluser; for (auto in_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); @@ -119,7 +128,9 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, GELOGI("MarkForFusion, IteratorGetNext graph mark success."); } - if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) { + const bool dataset_init = (!has_get_next) && (!has_dyn_get_next) && + ((type == kIteratorType) || (type == kIteratorV2Type)); + if (dataset_init) { GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); GELOGI("MarkForFusion, Iterator init graph mark success."); } diff --git a/parser/tensorflow/graph_optimizer.h b/parser/tensorflow/graph_optimizer.h index 420c2b5..728230e 100644 --- a/parser/tensorflow/graph_optimizer.h +++ b/parser/tensorflow/graph_optimizer.h @@ -43,7 +43,8 @@ class ParserGraphOptimizer { domi::Status MarkForFusion(std::unordered_map> &node_cluster_map); - domi::Status GetFusionCluster(const bool has_get_next, unordered_map> &node_cluster_map); + domi::Status GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next, + unordered_map> &node_cluster_map); domi::Status UpdateGraph(std::vector &nodes);