modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: 
tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc 
new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: 
ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.ccpull/1702/head
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
* You may obtain a copy of the License at | * You may obtain a copy of the License at | ||||
@@ -22,6 +22,7 @@ | |||||
#include "graph/optimize/graph_optimize.h" | #include "graph/optimize/graph_optimize.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace { | namespace { | ||||
using namespace ge; | using namespace ge; | ||||
@@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; | |||||
const int kCaseScopeWriteable = 2; | const int kCaseScopeWriteable = 2; | ||||
const int kCaseWriteable = 3; | const int kCaseWriteable = 3; | ||||
const int kCaseInvalidRWType = 5; | const int kCaseInvalidRWType = 5; | ||||
// attr _input_mutable = true means the node will modify its input at runtime | |||||
const char *const kModifyInput = "_input_mutable"; | |||||
// rw type of input. | // rw type of input. | ||||
enum class InputRWType { | enum class InputRWType { | ||||
kReadOnly, // Normal op input only read | kReadOnly, // Normal op input only read | ||||
kWriteable, // Op like Assign/ApplyMomentum | kWriteable, // Op like Assign/ApplyMomentum | ||||
kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput | |||||
kScopeWriteable, // Op like hcom_allreduce/while, it will modify its input, but the change is not expected to take effect on the pre output | |||||
kInvalidRWType | kInvalidRWType | ||||
}; | }; | ||||
// rw type of output | // rw type of output | ||||
@@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { | |||||
return true; | return true; | ||||
} | } | ||||
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||||
NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { | |||||
if (src_node.GetOpDesc() == nullptr) { | if (src_node.GetOpDesc() == nullptr) { | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||||
auto next_num = identity_num.fetch_add(1); | auto next_num = identity_num.fetch_add(1); | ||||
// 1. create new identity op desc | // 1. create new identity op desc | ||||
string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | ||||
auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY); | |||||
if (identity_opdesc == nullptr) { | |||||
GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); | |||||
return nullptr; | |||||
} | |||||
OpDescBuilder op_desc_builder(identity_name, IDENTITY); | |||||
auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | ||||
// 2. add input_desc & output_desc for new identity | |||||
Status ret = identity_opdesc->AddInputDesc("x", data_desc); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); | |||||
return nullptr; | |||||
} | |||||
ret = identity_opdesc->AddOutputDesc("y", data_desc); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); | |||||
return nullptr; | |||||
} | |||||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) | |||||
.AddOutput("y", data_desc) | |||||
.Build(); | |||||
GELOGI("Insert new Identity node %s.", identity_name.c_str()); | GELOGI("Insert new Identity node %s.", identity_name.c_str()); | ||||
auto graph = src_node.GetOwnerComputeGraph(); | auto graph = src_node.GetOwnerComputeGraph(); | ||||
if (graph == nullptr) { | if (graph == nullptr) { | ||||
GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return graph->AddNode(identity_opdesc); | |||||
return graph->AddNode(identity_op_desc); | |||||
} | } | ||||
OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | ||||
@@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||||
// single node without sub graph | // single node without sub graph | ||||
return GetSingleNodeInputRWTypeByIndex(node, index); | return GetSingleNodeInputRWTypeByIndex(node, index); | ||||
} else { | } else { | ||||
// node with sub graph | |||||
std::set<int> node_rw_type_set; | |||||
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | ||||
// get all input data node in subgraph | // get all input data node in subgraph | ||||
std::set<int> anchor_rw_type_set; | std::set<int> anchor_rw_type_set; | ||||
@@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { | |||||
auto parent_node = sub_graph->GetParentNode(); | auto parent_node = sub_graph->GetParentNode(); | ||||
if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | ||||
// insert identity | // insert identity | ||||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | GE_CHECK_NOTNULL(identity_node); | ||||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Fail to insert identity"); | |||||
return ret; | |||||
if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||||
identity_node->GetName().c_str(), | |||||
identity_node->GetType().c_str(), | |||||
pre_node->GetName().c_str(), | |||||
pre_node->GetType().c_str(), | |||||
node->GetName().c_str(), | |||||
node->GetType().c_str()); | |||||
GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||||
identity_node->GetName().c_str(), | |||||
identity_node->GetType().c_str(), | |||||
pre_node->GetName().c_str(), | |||||
pre_node->GetType().c_str(), | |||||
node->GetName().c_str(), | |||||
node->GetType().c_str()); | |||||
return FAILED; | |||||
} | } | ||||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | ||||
pre_node->GetName().c_str(), node->GetName().c_str()); | pre_node->GetName().c_str(), node->GetName().c_str()); | ||||
@@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I | |||||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | ||||
GE_CHECK_NOTNULL(peer_in_data_node); | GE_CHECK_NOTNULL(peer_in_data_node); | ||||
auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | ||||
auto ret = out_data_anchor->Unlink(peer_in_data_anchor); | |||||
auto old_identity = out_data_anchor->GetOwnerNode(); | auto old_identity = out_data_anchor->GetOwnerNode(); | ||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), | |||||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | ||||
auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); | |||||
auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(new_identity); | GE_CHECK_NOTNULL(new_identity); | ||||
if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS | |||||
|| GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", | |||||
pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
// 2. copy in-control-edge from dst to Identity | |||||
if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), | |||||
new_identity->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, | |||||
kIdentityAnchorIndex); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", | |||||
new_identity->GetName().c_str(), | |||||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
peer_in_data_anchor->GetIdx()); | |||||
return ret; | |||||
} | } | ||||
GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | ||||
InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | ||||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | ||||
} else { | } else { | ||||
(void) out_data_anchor->Unlink(peer_in_data_anchor); | |||||
// copy control edge to pre and peer node | // copy control edge to pre and peer node | ||||
if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | ||||
|| GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | ||||
@@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
GELOGD("No need insert Identity."); | GELOGD("No need insert Identity."); | ||||
continue; | continue; | ||||
case INSERT_IDENTITY: | case INSERT_IDENTITY: | ||||
auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); | |||||
if (identity_node == nullptr) { | |||||
GELOGE(FAILED, "Create identity node failed."); | |||||
return FAILED; | |||||
} | |||||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), | |||||
peer_in_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | |||||
auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, | |||||
kIdentityAnchorIndex); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), | |||||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); | |||||
return ret; | |||||
} | } | ||||
GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | ||||
peer_in_node->GetName().c_str()); | peer_in_node->GetName().c_str()); | ||||
@@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | ||||
for (const auto &node : compute_graph->GetDirectNode()) { | |||||
if (node->GetType() == HCOMALLREDUCE) { | |||||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(pre_out_anchor); | |||||
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
pre_out_anchor_set.emplace(pre_out_anchor); | |||||
continue; | |||||
} | |||||
// need insert identity | |||||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
for (const auto &node : compute_graph->GetDirectNode()) { | |||||
bool mutable_input_flag = false; | |||||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||||
if (!mutable_input_flag) { | |||||
continue; | |||||
} | |||||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
GE_CHECK_NOTNULL(pre_out_anchor); | |||||
if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||||
continue; | |||||
} | |||||
// need insert identity | |||||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
auto ret = | |||||
GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||||
node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
return ret; | |||||
} | |||||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | } | ||||
} // namespace | } // namespace | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -24,11 +24,12 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace { | namespace { | ||||
const int kAnchorNum = 0; | |||||
const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
const int32_t kAnchorIdentityIndex = 0; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
@@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
if (op_desc == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
return nullptr; | |||||
} | |||||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
return nullptr; | |||||
} | |||||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
if (identity_op_desc == nullptr) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
// for historical reasons, this pass cannot work after constant folding, so mark it | // for historical reasons, this pass cannot work after constant folding, so mark it | ||||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
if (memcpy_node == nullptr) { | |||||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
if (identity_node == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return memcpy_node; | |||||
return identity_node; | |||||
} | } | ||||
/// | /// | ||||
@@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||||
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | ||||
const OutDataAnchorPtr &src_out_anchor, | const OutDataAnchorPtr &src_out_anchor, | ||||
const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(memcpy_node); | |||||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
kAnchorNum); | |||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
memcpy_node->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -24,13 +24,15 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace { | namespace { | ||||
const int32_t kAnchorSize = 1; | const int32_t kAnchorSize = 1; | ||||
const int kAnchorNum = 0; | |||||
const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
const char *const kInputMutable = "_input_mutable"; | |||||
const int32_t kAnchorIdentityIndex = 0; | |||||
// attr _input_mutable = true means hccl node will modify its input in runtime | |||||
const char *const kModifyInput = "_input_mutable"; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
@@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||||
// need to insert memcpy node between. | // need to insert memcpy node between. | ||||
// also works on situation that input is variable or const. | // also works on situation that input is variable or const. | ||||
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | ||||
auto op_desc = node->GetOpDesc(); | |||||
bool node_input_mutable = false; | bool node_input_mutable = false; | ||||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||||
return SUCCESS; | |||||
} | |||||
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||||
if (!node_input_mutable) { | if (!node_input_mutable) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||||
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||||
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | ||||
if (hccl_in_anchor == nullptr) { | if (hccl_in_anchor == nullptr) { | ||||
continue; | continue; | ||||
@@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
if (op_desc == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
return nullptr; | |||||
} | |||||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
return nullptr; | |||||
} | |||||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
if (identity_op_desc == nullptr) { | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
// For historical reasons, this pass cannot run after constant folding, so mark the node. | // For historical reasons, this pass cannot run after constant folding, so mark the node. | ||||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
if (memcpy_node == nullptr) { | |||||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
if (identity_node == nullptr) { | |||||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return memcpy_node; | |||||
return identity_node; | |||||
} | } | ||||
/// | /// | ||||
@@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||||
/// | /// | ||||
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | ||||
const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(memcpy_node); | |||||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
GE_CHECK_NOTNULL(identity_node); | |||||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
if (ret1 != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
if (ret != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", | |||||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
kAnchorNum); | |||||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
memcpy_node->GetName().c_str()); | |||||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
hccl_in_anchor->GetIdx()); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -1,5 +1,5 @@ | |||||
/** | /** | ||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | * | ||||
* Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
* you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
@@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||||
set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | ||||
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | ||||
"${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||||
) | ) | ||||
@@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
"graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
"graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
"graph/passes/memcpy_addr_async_unittest.cc" | |||||
"graph/passes/memcpy_addr_async_unittest.cc" | |||||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||||
) | ) | ||||
set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
@@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"graph/manager/run_graph_unittest.cc" | "graph/manager/run_graph_unittest.cc" | ||||
"graph/partition/dynamic_shape_partition_unittest.cc" | "graph/partition/dynamic_shape_partition_unittest.cc" | ||||
"graph/manager/graph_manager_unittest.cc" | "graph/manager/graph_manager_unittest.cc" | ||||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||||
"session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
"session/ge_api_unittest.cc" | "session/ge_api_unittest.cc" | ||||
"session/inner_session_unittest.cc" | "session/inner_session_unittest.cc" | ||||
@@ -0,0 +1,150 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/optimize/graph_optimize.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "../passes/graph_builder_utils.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
namespace ge { | |||||
class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* Data -cast - netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){ | |||||
auto sub_builder = ut::GraphBuilder(subraph_name); | |||||
auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||||
auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||||
auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||||
AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||||
sub_builder.AddDataEdge(data1,0,cast,0); | |||||
sub_builder.AddDataEdge(cast,0,netoutput,0); | |||||
return sub_builder.GetGraph(); | |||||
} | |||||
/* | |||||
* const - allreduce | |||||
* \ if | |||||
* insert identity | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto if_node = builder.AddNode("if", IF, 1,0); | |||||
builder.AddDataEdge(const1, 0, allreduce, 0); | |||||
builder.AddDataEdge(const1, 0, if_node, 0); | |||||
builder.AddControlEdge(ctrl_const, allreduce); | |||||
auto root_graph = builder.GetGraph(); | |||||
string subgraph_name = "then_branch"; | |||||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
then_branch_graph->SetParentNode(if_node); | |||||
then_branch_graph->SetParentGraph(root_graph); | |||||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
return root_graph; | |||||
} | |||||
/* const1---allreduce const1--identity - allreduce | |||||
* / / | |||||
* var-identity--cast1 ==> var-----cast1 | |||||
* \ \ | |||||
* if if | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Identiyt_Split(){ | |||||
auto builder = ut::GraphBuilder("g1"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||||
auto if_node = builder.AddNode("if", IF, 1,0); | |||||
builder.AddDataEdge(var, 0 , identity, 0); | |||||
builder.AddDataEdge(identity, 0 , allreduce, 0); | |||||
builder.AddDataEdge(identity, 0 , cast1, 0); | |||||
builder.AddDataEdge(identity, 0 , if_node, 0); | |||||
builder.AddControlEdge(const1, allreduce); | |||||
auto root_graph = builder.GetGraph(); | |||||
string subgraph_name = "then_branch"; | |||||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
then_branch_graph->SetParentNode(if_node); | |||||
then_branch_graph->SetParentGraph(root_graph); | |||||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
return root_graph; | |||||
} | |||||
/* | |||||
* mul == allreduce | |||||
* need insert identity | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto mul = builder.AddNode("mul", MUL, 2,1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||||
AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||||
builder.AddDataEdge(mul,0,allreduce,0); | |||||
builder.AddDataEdge(mul,0,allreduce,1); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// Before: const -> allreduce
// After:  const -> Identity -> allreduce
//
// Every lookup is asserted before it is dereferenced, so a rewrite that did not
// happen fails this test instead of crashing the whole test binary.
TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) {
  ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite();
  ASSERT_NE(graph, nullptr);

  GraphOptimize graph_optimizer;
  auto ret = graph_optimizer.HandleMemoryRWConflict(graph);
  EXPECT_EQ(ret, SUCCESS);

  auto allreduce = graph->FindNode("allreduce");
  ASSERT_NE(allreduce, nullptr);
  auto in_data_nodes = allreduce->GetInDataNodes();
  ASSERT_GE(in_data_nodes.size(), 1U);
  // The node feeding allreduce must now be the inserted identity.
  EXPECT_EQ(in_data_nodes.at(0)->GetType(), IDENTITY);
}
// Runs HandleMemoryRWConflict on the identity fan-out fixture and checks that the
// allreduce is fed by an Identity whose first control input is a Constant.
//
// Node and list lookups are asserted before use so that an unexpected graph shape
// fails this test instead of dereferencing a null pointer.
TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) {
  ComputeGraphPtr graph = BuildGraph_Identiyt_Split();
  ASSERT_NE(graph, nullptr);

  GraphOptimize graph_optimizer;
  auto ret = graph_optimizer.HandleMemoryRWConflict(graph);
  EXPECT_EQ(ret, SUCCESS);

  auto allreduce = graph->FindNode("allreduce");
  ASSERT_NE(allreduce, nullptr);
  auto in_data_nodes = allreduce->GetInDataNodes();
  ASSERT_GE(in_data_nodes.size(), 1U);
  auto allreduce_in_node = in_data_nodes.at(0);
  ASSERT_NE(allreduce_in_node, nullptr);
  EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY);

  auto in_ctrl_nodes = allreduce_in_node->GetInControlNodes();
  ASSERT_GE(in_ctrl_nodes.size(), 1U);
  EXPECT_EQ(in_ctrl_nodes.at(0)->GetType(), CONSTANT);
}
// The mul/allreduce fixture starts with two direct nodes; after
// HandleMemoryRWConflict the graph is expected to hold three.
TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) {
  ComputeGraphPtr compute_graph = BuildGraph_mul_1To2_ScopeWrite();
  EXPECT_EQ(compute_graph->GetDirectNodesSize(), 2);

  GraphOptimize optimizer;
  const auto status = optimizer.HandleMemoryRWConflict(compute_graph);
  EXPECT_EQ(status, SUCCESS);

  EXPECT_EQ(compute_graph->GetDirectNodesSize(), 3);
}
} // namespace ge |
@@ -0,0 +1,79 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "graph_builder_utils.h" | |||||
namespace ge { | |||||
class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* var var | |||||
* | \ | \ | |||||
* | assign | assign | |||||
* | // =======> | // | |||||
* allreduce identity | |||||
* | | | |||||
* netoutput allreduce | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
builder.AddDataEdge(var, 0, assign, 0); | |||||
builder.AddDataEdge(var,0,allreduce,0); | |||||
builder.AddControlEdge(assign, allreduce); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// Before: var -> allreduce, with a control edge assign -> allreduce
// After:  var -> Identity -> allreduce, with assign as the Identity's control input
//
// The status returned by InsertIdentityBeforeHccl is checked, and every node and
// anchor lookup is asserted before it is dereferenced, so a failed insertion is
// reported as a test failure rather than a crash.
TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) {
  ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign();
  ASSERT_NE(graph, nullptr);
  auto src_node = graph->FindNode("var");
  auto dst_node = graph->FindNode("allreduce");
  ASSERT_NE(src_node, nullptr);
  ASSERT_NE(dst_node, nullptr);

  // test InsertIdentityBeforeHccl
  HcclContinuousMemcpyPass hccl_continuous_memcpy_pass;
  auto ret = hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0),
                                                                  dst_node->GetInDataAnchor(0));
  ASSERT_EQ(ret, SUCCESS);

  // check
  dst_node = graph->FindNode("allreduce");
  ASSERT_NE(dst_node, nullptr);
  auto dst_in_anchor = dst_node->GetInDataAnchor(0);
  ASSERT_NE(dst_in_anchor, nullptr);
  auto peer_out_anchor = dst_in_anchor->GetPeerOutAnchor();
  ASSERT_NE(peer_out_anchor, nullptr);
  auto in_node_before_dst_node = peer_out_anchor->GetOwnerNode();
  ASSERT_NE(in_node_before_dst_node, nullptr);
  EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY);

  auto in_ctrl_nodes = in_node_before_dst_node->GetInControlNodes();
  ASSERT_EQ(in_ctrl_nodes.size(), 1U);
  EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "assign");
}
} // namespace ge |
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <cstdint> | |||||
#include <string> | |||||
#include <gtest/gtest.h> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/passes/hccl_memcpy_pass.h" | |||||
#undef protected | |||||
#undef private | |||||
#include "graph_builder_utils.h" | |||||
namespace ge { | |||||
class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() {} | |||||
}; | |||||
namespace { | |||||
/* | |||||
* var var | |||||
* | \ | \ | |||||
* | assign | assign | |||||
* | // =======> | // | |||||
* allreduce identity | |||||
* | | | |||||
* netoutput allreduce | |||||
* | | |||||
* netoutput | |||||
*/ | |||||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
auto builder = ut::GraphBuilder("test"); | |||||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
builder.AddDataEdge(var, 0, assign, 0); | |||||
builder.AddDataEdge(var,0,allreduce,0); | |||||
builder.AddControlEdge(assign, allreduce); | |||||
return builder.GetGraph(); | |||||
} | |||||
} // namespace | |||||
// Before: var -> allreduce, with a control edge assign -> allreduce
// After:  var -> Identity -> allreduce, with assign as the Identity's control input
//
// The status returned by InsertIdentityBeforeHccl is checked, and every node and
// anchor lookup is asserted before it is dereferenced, so a failed insertion is
// reported as a test failure rather than a crash.
TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) {
  ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign();
  ASSERT_NE(graph, nullptr);
  auto src_node = graph->FindNode("var");
  auto dst_node = graph->FindNode("allreduce");
  ASSERT_NE(src_node, nullptr);
  ASSERT_NE(dst_node, nullptr);

  // test InsertIdentityBeforeHccl
  HcclMemcpyPass hccl_memcpy_pass;
  auto ret = hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0),
                                                       dst_node->GetInDataAnchor(0));
  ASSERT_EQ(ret, SUCCESS);

  // check
  dst_node = graph->FindNode("allreduce");
  ASSERT_NE(dst_node, nullptr);
  auto dst_in_anchor = dst_node->GetInDataAnchor(0);
  ASSERT_NE(dst_in_anchor, nullptr);
  auto peer_out_anchor = dst_in_anchor->GetPeerOutAnchor();
  ASSERT_NE(peer_out_anchor, nullptr);
  auto in_node_before_dst_node = peer_out_anchor->GetOwnerNode();
  ASSERT_NE(in_node_before_dst_node, nullptr);
  EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY);

  auto in_ctrl_nodes = in_node_before_dst_node->GetInControlNodes();
  ASSERT_EQ(in_ctrl_nodes.size(), 1U);
  EXPECT_EQ(in_ctrl_nodes.at(0)->GetName(), "assign");
}
} // namespace ge |