From c7c651f94889ae6d434d30a7b60ee9738b33f47e Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 25 May 2021 12:01:16 +0800 Subject: [PATCH 1/2] handle mutable input --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 41 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 077ed110..e48cb86e 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -33,6 +33,8 @@ const int kCaseScopeWriteable = 2; const int kCaseWriteable = 3; const int kCaseInvalidRWType = 5; +const char *const kInputMutable = "_input_mutable"; + // rw type of input. enum class InputRWType { kReadOnly, // Normal op input only read @@ -634,24 +636,29 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { } Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetType() == HCOMALLREDUCE) { - std::set pre_out_anchor_set; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(pre_out_anchor); - if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { - pre_out_anchor_set.emplace(pre_out_anchor); - continue; - } - // need insert identity - auto pre_node = pre_out_anchor->GetOwnerNode(); - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); - GE_CHK_STATUS_RET(ret, "Fail to insert identity."); - GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); + // op_desc of node should not be null + const auto &op_desc = node->GetOpDesc(); + bool mutable_input_flag = false; + if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { + GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); + continue; + } + std::set pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { + pre_out_anchor_set.emplace(pre_out_anchor); + continue; } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + GE_CHK_STATUS_RET(ret, "Fail to insert identity."); + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); } } return SUCCESS; From ee79746b08c630f5f92da77b137c36cede691226 Mon Sep 17 00:00:00 2001 From: chenyemeng Date: Tue, 25 May 2021 19:20:28 +0800 Subject: [PATCH 2/2] handle mutable_input node with same input --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 57 +++++++++---------- ge/graph/passes/hccl_memcpy_pass.cc | 18 ++---- tests/ut/ge/CMakeLists.txt | 1 + .../optimize/mem_rw_conflict_optimize_unittest.cc | 64 ++++++++++++++++++++++ 4 files changed, 95 insertions(+), 45 deletions(-) create mode 100644 tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index e48cb86e..10df00ed 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -32,8 +32,8 @@ const int kCaseReadOnly = 0; const int kCaseScopeWriteable = 2; const int kCaseWriteable = 3; const int kCaseInvalidRWType = 5; - -const char *const kInputMutable = "_input_mutable"; +// attr _input_mutable = true means node will modify its input in runtime +const char *const kModifyInput = "_input_mutable"; // rw type of input. enum class InputRWType { @@ -276,8 +276,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { // single node without sub graph return GetSingleNodeInputRWTypeByIndex(node, index); } else { - // node with sub graph - std::set node_rw_type_set; auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); // get all input data node in subgraph std::set anchor_rw_type_set; @@ -635,33 +633,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { return SUCCESS; } Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { - for (const auto &node : compute_graph->GetDirectNode()) { - // op_desc of node should not be null - const auto &op_desc = node->GetOpDesc(); - bool mutable_input_flag = false; - if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { - GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); - continue; - } - std::set pre_out_anchor_set; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(pre_out_anchor); - if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { - pre_out_anchor_set.emplace(pre_out_anchor); - continue; - } - // need insert identity - auto pre_node = pre_out_anchor->GetOwnerNode(); - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); - GE_CHK_STATUS_RET(ret, "Fail to insert identity."); - GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); - } - } - return SUCCESS; + for (const auto &node : compute_graph->GetDirectNode()) { + bool mutable_input_flag = false; + (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); + if (!mutable_input_flag) { + continue; + } + std::set pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.insert(pre_out_anchor).second) { + continue; + } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + GE_CHK_STATUS_RET(ret, "Fail to insert identity."); + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); + } + } + return SUCCESS; } } // namespace diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc index 2d2f8220..ae798c4d 100755 --- a/ge/graph/passes/hccl_memcpy_pass.cc +++ b/ge/graph/passes/hccl_memcpy_pass.cc @@ -30,7 +30,8 @@ const int32_t kAnchorSize = 1; const int kAnchorNum = 0; const int32_t kAnchorAssignRefIndex = 0; const int32_t kAnchorAssignValueIndex = 1; -const char *const kInputMutable = "_input_mutable"; +// attr _input_mutable = true means hccl node will modify its input in runtime +const char *const kModifyInput = "_input_mutable"; } // namespace namespace ge { Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -58,24 +59,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { // need to inset memcpy node between. // also works on situation that input is variable or const. Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { - auto op_desc = node->GetOpDesc(); - bool node_input_mutable = false; - if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { - return SUCCESS; - } - - if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); - return FAILED; - } + (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); if (!node_input_mutable) { return SUCCESS; } - GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); + GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { if (hccl_in_anchor == nullptr) { continue; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index e957d119..7fda9c86 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -716,6 +716,7 @@ set(PASS_TEST_FILES "graph/passes/reshape_recovery_pass_unittest.cc" "graph/passes/cast_remove_pass_unittest.cc" "graph/passes/memcpy_addr_async_unittest.cc" + "graph/optimize/mem_rw_conflict_optimize_unittest.cc" ) set(KERNEL_TEST_FILES diff --git a/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc b/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file mode 100644 index 00000000..5385349c --- /dev/null +++ b/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/optimize/graph_optimize.h" + +#include +#include "graph/passes/graph_builder_utils.h" +#include "graph/utils/attr_utils.h" + +namespace ge { +class MemRwConflictOptimizeTest : public testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +namespace { +/// +/// HcomAllReduce +/// \ / +/// add +/// / \ +/// var +/// +ComputeGraphPtr build_all_reduce_repeat_input_graph() { + auto builder = ut::GraphBuilder("build_all_reduce_repeat_input_graph"); + auto var = builder.AddNode("var", VARIABLEV2, 0, 1); + auto add = builder.AddNode("add", ADD, 2, 1); + auto hcom_all_reduce = builder.AddNode("HcomAllReduce", HCOMALLREDUCE, 2, 1); + AttrUtils::SetBool(hcom_all_reduce->GetOpDesc(), "_input_mutable", true); + + builder.AddDataEdge(var, 1, add, 0); + builder.AddDataEdge(var, 1, add, 1); + builder.AddDataEdge(add, 0, hcom_all_reduce, 0); + builder.AddDataEdge(add, 0, hcom_all_reduce, 1); + return builder.GetGraph(); +} +} // namespace + +TEST_F(MemRwConflictOptimizeTest, test_handle_allreduce_duplicate_input) { + auto graph = build_all_reduce_repeat_input_graph(); + EXPECT_NE(graph, nullptr); + GraphOptimize optimize; + EXPECT_EQ(optimize.HandleMemoryRWConflict(graph), SUCCESS); + auto all_reduce = graph->FindNode("HcomAllReduce"); + EXPECT_NE(all_reduce, nullptr); + EXPECT_EQ(all_reduce->GetInDataNodes().size(), 2); + EXPECT_EQ(all_reduce->GetInDataNodes().at(0)->GetType(), ADD); + EXPECT_EQ(all_reduce->GetInDataNodes().at(1)->GetType(), IDENTITY); +} +} // namespace ge