@@ -32,8 +32,8 @@ const int kCaseReadOnly = 0; | |||
const int kCaseScopeWriteable = 2; | |||
const int kCaseWriteable = 3; | |||
const int kCaseInvalidRWType = 5; | |||
const char *const kInputMutable = "_input_mutable"; | |||
// attr _input_mutable = true means node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
// rw type of input. | |||
enum class InputRWType { | |||
@@ -276,8 +276,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||
// single node without sub graph | |||
return GetSingleNodeInputRWTypeByIndex(node, index); | |||
} else { | |||
// node with sub graph | |||
std::set<int> node_rw_type_set; | |||
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | |||
// get all input data node in subgraph | |||
std::set<int> anchor_rw_type_set; | |||
@@ -635,33 +633,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
// op_desc of node should not be null | |||
const auto &op_desc = node->GetOpDesc(); | |||
bool mutable_input_flag = false; | |||
if (!AttrUtils::GetBool(op_desc, kInputMutable, mutable_input_flag) || !mutable_input_flag) { | |||
GELOGD("[Node:%s] Input is not mutable, ignore memory conflict handle", op_desc->GetName().c_str()); | |||
continue; | |||
} | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
pre_out_anchor_set.emplace(pre_out_anchor); | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
bool mutable_input_flag = false; | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||
if (!mutable_input_flag) { | |||
continue; | |||
} | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace | |||
@@ -30,7 +30,8 @@ const int32_t kAnchorSize = 1; | |||
const int kAnchorNum = 0; | |||
const int32_t kAnchorAssignRefIndex = 0; | |||
const int32_t kAnchorAssignValueIndex = 1; | |||
const char *const kInputMutable = "_input_mutable"; | |||
// attr _input_mutable = true means hccl node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
} // namespace | |||
namespace ge { | |||
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
@@ -58,24 +59,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
// need to inset memcpy node between. | |||
// also works on situation that input is variable or const. | |||
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | |||
auto op_desc = node->GetOpDesc(); | |||
bool node_input_mutable = false; | |||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||
return SUCCESS; | |||
} | |||
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||
if (!node_input_mutable) { | |||
return SUCCESS; | |||
} | |||
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | |||
if (hccl_in_anchor == nullptr) { | |||
continue; | |||
@@ -716,6 +716,7 @@ set(PASS_TEST_FILES | |||
"graph/passes/reshape_recovery_pass_unittest.cc" | |||
"graph/passes/cast_remove_pass_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -0,0 +1,64 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph/optimize/graph_optimize.h" | |||
#include <gtest/gtest.h> | |||
#include "graph/passes/graph_builder_utils.h" | |||
#include "graph/utils/attr_utils.h" | |||
namespace ge { | |||
class MemRwConflictOptimizeTest : public testing::Test { | |||
protected: | |||
void SetUp() override {} | |||
void TearDown() override {} | |||
}; | |||
namespace { | |||
/// | |||
/// HcomAllReduce | |||
/// \ / | |||
/// add | |||
/// / \ | |||
/// var | |||
/// | |||
ComputeGraphPtr build_all_reduce_repeat_input_graph() { | |||
auto builder = ut::GraphBuilder("build_all_reduce_repeat_input_graph"); | |||
auto var = builder.AddNode("var", VARIABLEV2, 0, 1); | |||
auto add = builder.AddNode("add", ADD, 2, 1); | |||
auto hcom_all_reduce = builder.AddNode("HcomAllReduce", HCOMALLREDUCE, 2, 1); | |||
AttrUtils::SetBool(hcom_all_reduce->GetOpDesc(), "_input_mutable", true); | |||
builder.AddDataEdge(var, 1, add, 0); | |||
builder.AddDataEdge(var, 1, add, 1); | |||
builder.AddDataEdge(add, 0, hcom_all_reduce, 0); | |||
builder.AddDataEdge(add, 0, hcom_all_reduce, 1); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
TEST_F(MemRwConflictOptimizeTest, test_handle_allreduce_duplicate_input) { | |||
auto graph = build_all_reduce_repeat_input_graph(); | |||
EXPECT_NE(graph, nullptr); | |||
GraphOptimize optimize; | |||
EXPECT_EQ(optimize.HandleMemoryRWConflict(graph), SUCCESS); | |||
auto all_reduce = graph->FindNode("HcomAllReduce"); | |||
EXPECT_NE(all_reduce, nullptr); | |||
EXPECT_EQ(all_reduce->GetInDataNodes().size(), 2); | |||
EXPECT_EQ(all_reduce->GetInDataNodes().at(0)->GetType(), ADD); | |||
EXPECT_EQ(all_reduce->GetInDataNodes().at(1)->GetType(), IDENTITY); | |||
} | |||
} // namespace ge |