Browse Source

Do not fuse the Iterator init graph when a DynamicGetNext node is present

pull/456/head
wangzhengjun 3 years ago
parent
commit
229e9159d9
2 changed files with 20 additions and 8 deletions
  1. +18
    -7
      parser/tensorflow/graph_optimizer.cc
  2. +2
    -1
      parser/tensorflow/graph_optimizer.h

+ 18
- 7
parser/tensorflow/graph_optimizer.cc View File

@@ -37,6 +37,10 @@ namespace {
const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
const char *const kShapeNodeType = "Shape"; const char *const kShapeNodeType = "Shape";
const char *const kShapeNodeNamePrefix = "getnext_shape_"; const char *const kShapeNodeNamePrefix = "getnext_shape_";
const char *const kIteratorType = "Iterator";
const char *const kIteratorV2Type = "IteratorV2";
const char *const kGetNextType = "IteratorGetNext";
const char *const kDynGetNextType = "DynamicGetNext";
} // namespace } // namespace


Status ParserGraphOptimizer::FusionFmkop() { Status ParserGraphOptimizer::FusionFmkop() {
@@ -66,28 +70,33 @@ Status ParserGraphOptimizer::FusionFmkop() {
Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) { Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) {
GE_CHECK_NOTNULL(graph_); GE_CHECK_NOTNULL(graph_);
bool has_get_next = false; bool has_get_next = false;
bool has_dyn_get_next = false;
for (auto node : graph_->GetDirectNode()) { for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
if (node->GetType() == kDynGetNextType) {
has_dyn_get_next = true;
break;
}
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue);
string type = "";
string type;
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
if (type == "IteratorGetNext") {
if (type == kGetNextType) {
has_get_next = true; has_get_next = true;
break; break;
} }
} }
return GetFusionCluster(has_get_next, node_cluster_map);
return GetFusionCluster(has_get_next, has_dyn_get_next, node_cluster_map);
} }


Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next,
Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next,
unordered_map<string, vector<NodePtr>> &node_cluster_map) { unordered_map<string, vector<NodePtr>> &node_cluster_map) {
GE_CHECK_NOTNULL(graph_); GE_CHECK_NOTNULL(graph_);
for (auto node : graph_->GetDirectNode()) { for (auto node : graph_->GetDirectNode()) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
string type = "";
string type;
GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
if (type == "IteratorGetNext") {
if (type == kGetNextType) {
vector<NodePtr> temp_node_cluser; vector<NodePtr> temp_node_cluser;
for (auto in_anchor : node->GetAllInDataAnchors()) { for (auto in_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
@@ -119,7 +128,9 @@ Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next,
GELOGI("MarkForFusion, IteratorGetNext graph mark success."); GELOGI("MarkForFusion, IteratorGetNext graph mark success.");
} }


if (!has_get_next && (type == "Iterator" || type == "IteratorV2")) {
const bool dataset_init = (!has_get_next) && (!has_dyn_get_next) &&
((type == kIteratorType) || (type == kIteratorV2Type));
if (dataset_init) {
GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail."); GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail.");
GELOGI("MarkForFusion, Iterator init graph mark success."); GELOGI("MarkForFusion, Iterator init graph mark success.");
} }


+ 2
- 1
parser/tensorflow/graph_optimizer.h View File

@@ -43,7 +43,8 @@ class ParserGraphOptimizer {


domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map); domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluster_map);


domi::Status GetFusionCluster(const bool has_get_next, unordered_map<string, vector<NodePtr>> &node_cluster_map);
domi::Status GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next,
unordered_map<string, vector<NodePtr>> &node_cluster_map);


domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes); domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes);




Loading…
Cancel
Save