Browse Source

handle mutable_input node with same input

pull/1690/head
chenyemeng 4 years ago
parent
commit
ee79746b08
4 changed files with 95 additions and 45 deletions
  1. +26
    -31
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  2. +4
    -14
      ge/graph/passes/hccl_memcpy_pass.cc
  3. +1
    -0
      tests/ut/ge/CMakeLists.txt
  4. +64
    -0
      tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc

+ 26
- 31
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -32,8 +32,8 @@ const int kCaseReadOnly = 0;
const int kCaseScopeWriteable = 2;
const int kCaseWriteable = 3;
const int kCaseInvalidRWType = 5;
const char *const kInputMutable = "_input_mutable";
// attr _input_mutable = true means node will modify its input in runtime
const char *const kModifyInput = "_input_mutable";

// rw type of input.
enum class InputRWType {
@@ -276,8 +276,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
// single node without sub graph
return GetSingleNodeInputRWTypeByIndex(node, index);
} else {
// node with sub graph
std::set<int> node_rw_type_set;
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index);
// get all input data node in subgraph
std::set<int> anchor_rw_type_set;
@@ -635,33 +633,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
return SUCCESS;
}
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
for (const auto &node : compute_graph->GetDirectNode()) {
// op_desc of node should not be null
const auto &op_desc = node->GetOpDesc();
bool mutable_input_flag = false;
if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) {
GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str());
continue;
}
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) {
pre_out_anchor_set.emplace(pre_out_anchor);
continue;
}
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
}
}
return SUCCESS;
for (const auto &node : compute_graph->GetDirectNode()) {
bool mutable_input_flag = false;
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag);
if (!mutable_input_flag) {
continue;
}
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.insert(pre_out_anchor).second) {
continue;
}
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
}
}
return SUCCESS;
}
} // namespace



+ 4
- 14
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -30,7 +30,8 @@ const int32_t kAnchorSize = 1;
const int kAnchorNum = 0;
const int32_t kAnchorAssignRefIndex = 0;
const int32_t kAnchorAssignValueIndex = 1;
const char *const kInputMutable = "_input_mutable";
// attr _input_mutable = true means hccl node will modify its input in runtime
const char *const kModifyInput = "_input_mutable";
} // namespace
namespace ge {
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
@@ -58,24 +59,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
// need to inset memcpy node between.
// also works on situation that input is variable or const.
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();

bool node_input_mutable = false;
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
return SUCCESS;
}

if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) {
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable,
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str());
return FAILED;
}
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable);
if (!node_input_mutable) {
return SUCCESS;
}

GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str());
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str());
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -716,6 +716,7 @@ set(PASS_TEST_FILES
"graph/passes/reshape_recovery_pass_unittest.cc"
"graph/passes/cast_remove_pass_unittest.cc"
"graph/passes/memcpy_addr_async_unittest.cc"
"graph/optimize/mem_rw_conflict_optimize_unittest.cc"
)

set(KERNEL_TEST_FILES


+ 64
- 0
tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc View File

@@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/optimize/graph_optimize.h"

#include <gtest/gtest.h>
#include "graph/passes/graph_builder_utils.h"
#include "graph/utils/attr_utils.h"

namespace ge {
class MemRwConflictOptimizeTest : public testing::Test {
protected:
void SetUp() override {}
void TearDown() override {}
};

namespace {
///
/// HcomAllReduce
/// \ /
/// add
/// / \
/// var
///
ComputeGraphPtr build_all_reduce_repeat_input_graph() {
auto builder = ut::GraphBuilder("build_all_reduce_repeat_input_graph");
auto var = builder.AddNode("var", VARIABLEV2, 0, 1);
auto add = builder.AddNode("add", ADD, 2, 1);
auto hcom_all_reduce = builder.AddNode("HcomAllReduce", HCOMALLREDUCE, 2, 1);
AttrUtils::SetBool(hcom_all_reduce->GetOpDesc(), "_input_mutable", true);

builder.AddDataEdge(var, 1, add, 0);
builder.AddDataEdge(var, 1, add, 1);
builder.AddDataEdge(add, 0, hcom_all_reduce, 0);
builder.AddDataEdge(add, 0, hcom_all_reduce, 1);
return builder.GetGraph();
}
} // namespace

TEST_F(MemRwConflictOptimizeTest, test_handle_allreduce_duplicate_input) {
auto graph = build_all_reduce_repeat_input_graph();
EXPECT_NE(graph, nullptr);
GraphOptimize optimize;
EXPECT_EQ(optimize.HandleMemoryRWConflict(graph), SUCCESS);
auto all_reduce = graph->FindNode("HcomAllReduce");
EXPECT_NE(all_reduce, nullptr);
EXPECT_EQ(all_reduce->GetInDataNodes().size(), 2);
EXPECT_EQ(all_reduce->GetInDataNodes().at(0)->GetType(), ADD);
EXPECT_EQ(all_reduce->GetInDataNodes().at(1)->GetType(), IDENTITY);
}
} // namespace ge

Loading…
Cancel
Save