From c7c651f94889ae6d434d30a7b60ee9738b33f47e Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 25 May 2021 12:01:16 +0800 Subject: [PATCH] handle mutable input --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 41 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 077ed110..e48cb86e 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -33,6 +33,8 @@ const int kCaseScopeWriteable = 2; const int kCaseWriteable = 3; const int kCaseInvalidRWType = 5; +const char *const kInputMutable = "_input_mutable"; + // rw type of input. enum class InputRWType { kReadOnly, // Normal op input only read @@ -634,24 +636,29 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { } Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetType() == HCOMALLREDUCE) { - std::set pre_out_anchor_set; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(pre_out_anchor); - if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { - pre_out_anchor_set.emplace(pre_out_anchor); - continue; - } - // need insert identity - auto pre_node = pre_out_anchor->GetOwnerNode(); - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); - GE_CHK_STATUS_RET(ret, "Fail to insert identity."); - GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); + // op_desc of node should not be null + const auto &op_desc = node->GetOpDesc(); + bool mutable_input_flag = false; + if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { + GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); + continue; + } + std::set pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { + pre_out_anchor_set.emplace(pre_out_anchor); + continue; } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + GE_CHK_STATUS_RET(ret, "Fail to insert identity."); + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); } } return SUCCESS;