From 94f2185b9b72fb0fab84d08c477e60927775aa3e Mon Sep 17 00:00:00 2001
From: wangxiaotian22
Date: Thu, 5 Nov 2020 20:11:13 +0800
Subject: [PATCH] fix

HcclMemcpyPass::Run now performs three per-node steps and stops at the
first one that fails:

  * ContinuousInputProcess (renamed from ProcessBroadcastMemcpy): keyed
    on ATTR_NAME_CONTINUOUS_INPUT with more than one input instead of
    on the HCOMBROADCAST / HVDCALLBACKBROADCAST node types.
  * MutableInputProcess: the "_input_mutable" handling that used to be
    inlined in Run.
  * P2pmemInputProcess: new; for every input whose entry in
    ATTR_NAME_INPUT_MEM_TYPE_LIST is RT_MEMORY_P2P_DDR and whose
    producer is a data node, goes through ModifyEdgeConnection.

The repeated CONSTANTOP / VARIABLE / DATA / CONSTANT type test is
factored into IsDataNode().  A node without an op_desc is now reported
as INTERNAL_ERROR by Run instead of being skipped.

block_mem_assigner.cc calls GetInputsSize() instead of GetInputSize().
---
Review notes (text below the '---' marker is commentary, not commit message):

  * P2pmemInputProcess: GetInDataAnchor() was called without a receiver;
    it is now node->GetInDataAnchor(...).
  * P2pmemInputProcess: the loop index is size_t so it is no longer a
    signed value compared against two unsigned sizes.
  * hccl_memcpy_pass.h: dropped the "HcclMemcpyPass::" qualification from
    the in-class IsDataNode declaration.
  * Run: the log text for a failing ContinuousInputProcess call no longer
    names the removed ProcessBroadcastMemcpy.
  * Comment typos fixed ("inset", "continous", "should be insert").
  * NOTE(review): context-line indentation, blank context lines and the
    template arguments on vector / std::unordered_map were reconstructed
    from a whitespace-collapsed copy of this patch; hunk line counts match
    the headers, but verify the context against the tree before applying.

 ge/graph/build/memory/block_mem_assigner.cc |   2 +-
 ge/graph/passes/hccl_memcpy_pass.cc         | 154 ++++++++++++++++++++--------
 ge/graph/passes/hccl_memcpy_pass.h          |   8 +-
 3 files changed, 117 insertions(+), 47 deletions(-)

diff --git a/ge/graph/build/memory/block_mem_assigner.cc b/ge/graph/build/memory/block_mem_assigner.cc
index a406d384..d59023f8 100755
--- a/ge/graph/build/memory/block_mem_assigner.cc
+++ b/ge/graph/build/memory/block_mem_assigner.cc
@@ -428,7 +428,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
     // if input size just one, no need to reassign continuous memory
     bool is_input_continuous = false;
     (void)ge::AttrUtils::GetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
-    if (is_input_continuous && (node_op_desc->GetInputSize() <= 1)) {
+    if (is_input_continuous && (node_op_desc->GetInputsSize() <= 1)) {
       (void)ge::AttrUtils::SetBool(node_op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true);
     }
 
diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc
index 2c34c38e..8471b1d8 100755
--- a/ge/graph/passes/hccl_memcpy_pass.cc
+++ b/ge/graph/passes/hccl_memcpy_pass.cc
@@ -32,75 +32,101 @@ const char *const kInputMutable = "_input_mutable";
 }  // namespace
 namespace ge {
 Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
+  Status ret = SUCCESS;
   GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(PARAM_INVALID, "param [graph] must not be null."); return PARAM_INVALID);
   for (const auto &node : graph->GetDirectNode()) {
     auto op_desc = node->GetOpDesc();
-    GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
+    if (op_desc == nullptr) {
+      GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
+      return INTERNAL_ERROR;
+    }
 
-    Status ret = ProcessBroadcastMemcpy(graph, node);
+    ret = ContinuousInputProcess(graph, node);
     if (ret != SUCCESS) {
-      GELOGE(INTERNAL_ERROR, "failed ProcessBroadcastMemcpy.");
+      GELOGE(INTERNAL_ERROR, "failed ContinuousInputProcess, node_name:%s.", node->GetName().c_str());
       return ret;
     }
 
-    bool node_input_mutable = false;
-    if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
-      continue;
+    ret = MutableInputProcess(graph, node);
+    if (ret != SUCCESS) {
+      GELOGE(INTERNAL_ERROR, "failed MutableInputProcess, node_name:%s.", node->GetName().c_str());
+      return ret;
+    }
+
+    ret = P2pmemInputProcess(graph, node);
+    if (ret != SUCCESS) {
+      GELOGE(INTERNAL_ERROR, "failed P2pmemInputProcess, node_name:%s.", node->GetName().c_str());
+      return ret;
     }
+
+  }
+  return ret;
+}
+
+// If node has _input_mutable attr, means input mem may be modified when op execute.
+// In order to avoid to affect another op execute with same input when data modified,
+// need to insert memcpy node between.
+// also works on situation that input is variable or const.
+Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
+  auto op_desc = node->GetOpDesc();
+
+  bool node_input_mutable = false;
+  if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
+    return SUCCESS;
+  }
+
+  if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) {
+    GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str());
+    return FAILED;
+  }
+  if (!node_input_mutable) {
+    return SUCCESS;
+  }
 
-    GE_IF_BOOL_EXEC(!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable),
-                    GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); return FAILED);
-    if (!node_input_mutable) {
+  GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str());
+  for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
+    if (hccl_in_anchor == nullptr) {
       continue;
     }
+    auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
+    GE_CHECK_NOTNULL(src_out_anchor);
 
-    GELOGI("hcom op is:%s.", op_desc->GetName().c_str());
-    for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
-      if (hccl_in_anchor == nullptr) {
-        continue;
-      }
-      auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
-      GE_CHECK_NOTNULL(src_out_anchor);
-
-      int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size();
-      if (src_out_anchor_size == kAnchorSize) {
-        // Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared.
-        NodePtr src_node = src_out_anchor->GetOwnerNode();
-        std::string src_type = src_node->GetType();
-        bool check_src_type = (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT);
-        if (check_src_type) {
-          Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
-          if (ret != SUCCESS) {
-            GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
-            return ret;
-          }
+    int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size();
+    if (src_out_anchor_size == kAnchorSize) {
+      // Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared.
+      if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
+        Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
+        if (ret != SUCCESS) {
+          GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
+          return ret;
         }
-        continue;
       }
+      continue;
+    }
 
-      Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
-      if (ret != SUCCESS) {
-        GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
-        return ret;
-      }
+    Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
+    if (ret != SUCCESS) {
+      GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
+      return ret;
     }
   }
   return SUCCESS;
 }
 
+
 // If broadcast input size is bigger than 1, and input from variable,
 // cause by broadcast input memory should be continuous,
 // another featuremap mem will be allocated for broadcast input.
 // In this condition, move data from variable mem to broadcast input featuremap mem will be executed each step.
 // In order to avoid move action out of model, use memcpy node instead of move action code.
-Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node) {
+Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
   auto op_desc = node->GetOpDesc();
-  if (op_desc == nullptr) {
-    GELOGE(INTERNAL_ERROR, "node has no op_desc, node_name : %s.", node->GetName().c_str());
-    return INTERNAL_ERROR;
-  }
 
-  if ((node->GetType() == HCOMBROADCAST || node->GetType() == HVDCALLBACKBROADCAST) && op_desc->GetInputSize() > 1) {
+  bool is_input_continuous = false;
+  (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
+
+  if (is_input_continuous && op_desc->GetInputsSize() > 1) {
+    // if input size bigger than one, insert memcpy between var data for support continuous mem alloc
     for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
       if (hccl_in_anchor == nullptr) {
         continue;
@@ -111,10 +137,7 @@ Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, cons
         return INTERNAL_ERROR;
       }
 
-      NodePtr src_node = src_out_anchor->GetOwnerNode();
-      std::string src_type = src_node->GetType();
-      bool check_src_type = (src_type == CONSTANTOP) || (src_type == VARIABLE) || (src_type == DATA) || (src_type == CONSTANT);
-      if (check_src_type) {
+      if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
         Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
         if (ret != SUCCESS) {
           GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
@@ -126,6 +149,47 @@ Status HcclMemcpyPass::ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, cons
   return SUCCESS;
 }
 
+// if input is var type, and node input need p2p mem, then memcpy should be inserted between the two
+Status HcclMemcpyPass::P2pmemInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
+  auto op_desc = node->GetOpDesc();
+
+  vector<int64_t> input_memory_types;
+  (void) ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, input_memory_types);
+
+  if (input_memory_types.empty()) {
+    return SUCCESS;
+  }
+
+  for (size_t index = 0; index < input_memory_types.size() && index < op_desc->GetInputsSize(); index++) {
+    if (input_memory_types[index] != RT_MEMORY_P2P_DDR) {
+      continue;
+    }
+
+    auto hccl_in_anchor = node->GetInDataAnchor(static_cast<int>(index));
+    if (hccl_in_anchor == nullptr) {
+      continue;
+    }
+    auto src_out_anchor = hccl_in_anchor->GetPeerOutAnchor();
+    if (src_out_anchor == nullptr) {
+      GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str());
+      return INTERNAL_ERROR;
+    }
+
+    if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) {
+      Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor);
+      if (ret != SUCCESS) {
+        GELOGE(INTERNAL_ERROR, "Failed to modify the connection.");
+        return ret;
+      }
+    }
+  }
+  return SUCCESS;
+}
+
+bool HcclMemcpyPass::IsDataNode(const std::string& node_type) {
+  return (node_type == CONSTANTOP) || (node_type == VARIABLE) || (node_type == DATA) || (node_type == CONSTANT);
+}
+
 ///
 /// @brief Add MemcpyAsync Node
 /// @param [in] ge::ComputeGraphPtr graph
diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h
index aaf00779..81de2e80 100755
--- a/ge/graph/passes/hccl_memcpy_pass.h
+++ b/ge/graph/passes/hccl_memcpy_pass.h
@@ -37,7 +37,13 @@ class HcclMemcpyPass : public GraphPass {
   Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
                               const InDataAnchorPtr &hccl_in_anchor);
 
-  Status ProcessBroadcastMemcpy(const ComputeGraphPtr &graph, const NodePtr node);
+  Status ContinuousInputProcess(const ComputeGraphPtr &graph, const NodePtr node);
+
+  Status MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node);
+
+  Status P2pmemInputProcess(const ComputeGraphPtr &graph, const NodePtr node);
+
+  bool IsDataNode(const std::string& node_type);
 
   std::unordered_map<std::string, uint32_t> node_num_map_;
 };