|
@@ -37,6 +37,10 @@ namespace { |
|
|
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; |
|
|
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; |
|
|
const char *const kShapeNodeType = "Shape"; |
|
|
const char *const kShapeNodeType = "Shape"; |
|
|
const char *const kShapeNodeNamePrefix = "getnext_shape_"; |
|
|
const char *const kShapeNodeNamePrefix = "getnext_shape_"; |
|
|
|
|
|
const char *const kIteratorType = "Iterator"; |
|
|
|
|
|
const char *const kIteratorV2Type = "IteratorV2"; |
|
|
|
|
|
const char *const kGetNextType = "IteratorGetNext"; |
|
|
|
|
|
const char *const kDynGetNextType = "DynamicGetNext"; |
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
Status ParserGraphOptimizer::FusionFmkop() { |
|
|
Status ParserGraphOptimizer::FusionFmkop() { |
|
@@ -66,28 +70,33 @@ Status ParserGraphOptimizer::FusionFmkop() { |
|
|
Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) { |
|
|
Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) { |
|
|
GE_CHECK_NOTNULL(graph_); |
|
|
GE_CHECK_NOTNULL(graph_); |
|
|
bool has_get_next = false; |
|
|
bool has_get_next = false; |
|
|
|
|
|
bool has_dyn_get_next = false; |
|
|
for (auto node : graph_->GetDirectNode()) { |
|
|
for (auto node : graph_->GetDirectNode()) { |
|
|
GE_CHECK_NOTNULL(node); |
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
|
|
if (node->GetType() == kDynGetNextType) { |
|
|
|
|
|
has_dyn_get_next = true; |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); |
|
|
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); |
|
|
string type = ""; |
|
|
|
|
|
|
|
|
string type; |
|
|
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); |
|
|
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); |
|
|
if (type == "IteratorGetNext") { |
|
|
|
|
|
|
|
|
if (type == kGetNextType) { |
|
|
has_get_next = true; |
|
|
has_get_next = true; |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return GetFusionCluster(has_get_next, node_cluster_map); |
|
|
|
|
|
|
|
|
return GetFusionCluster(has_get_next, has_dyn_get_next, node_cluster_map); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, |
|
|
|
|
|
|
|
|
Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next, |
|
|
unordered_map<string, vector<NodePtr>> &node_cluster_map) { |
|
|
unordered_map<string, vector<NodePtr>> &node_cluster_map) { |
|
|
GE_CHECK_NOTNULL(graph_); |
|
|
GE_CHECK_NOTNULL(graph_); |
|
|
for (auto node : graph_->GetDirectNode()) { |
|
|
for (auto node : graph_->GetDirectNode()) { |
|
|
GE_CHECK_NOTNULL(node); |
|
|
GE_CHECK_NOTNULL(node); |
|
|
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) |
|
|
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) |
|
|
string type = ""; |
|
|
|
|
|
|
|
|
string type; |
|
|
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); |
|
|
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); |
|
|
if (type == "IteratorGetNext") { |
|
|
|
|
|
|
|
|
if (type == kGetNextType) { |
|
|
vector<NodePtr> temp_node_cluser; |
|
|
vector<NodePtr> temp_node_cluser; |
|
|
for (auto in_anchor : node->GetAllInDataAnchors()) { |
|
|
for (auto in_anchor : node->GetAllInDataAnchors()) { |
|
|
OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
|
OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); |
|
@@ -119,7 +128,9 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, |
|
|
GELOGI("MarkForFusion, IteratorGetNext graph mark success."); |
|
|
GELOGI("MarkForFusion, IteratorGetNext graph mark success."); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) { |
|
|
|
|
|
|
|
|
const bool dataset_init = (!has_get_next) && (!has_dyn_get_next) && |
|
|
|
|
|
((type == kIteratorType) || (type == kIteratorV2Type)); |
|
|
|
|
|
if (dataset_init) { |
|
|
GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); |
|
|
GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); |
|
|
GELOGI("MarkForFusion, Iterator init graph mark success."); |
|
|
GELOGI("MarkForFusion, Iterator init graph mark success."); |
|
|
} |
|
|
} |
|
|