From 7610599a3a1d2ca828468e8323bb09b38f6b840d Mon Sep 17 00:00:00 2001 From: wangxiaotian22 Date: Wed, 11 Nov 2020 10:22:58 +0800 Subject: [PATCH] move hccl memcpy after trans op fusion pass --- ge/graph/manager/graph_manager.cc | 4 ++-- ge/graph/passes/hccl_memcpy_pass.cc | 18 +++++++++--------- ge/graph/passes/hccl_memcpy_pass.h | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index ce0ddcc3..3057e8ad 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1962,8 +1962,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { } PassManager after_merge_passes; GE_CHK_STATUS_RET( - after_merge_passes.AddPass("OptimizeStage1_1::HcclMemcpyPass", new (std::nothrow) HcclMemcpyPass)); - GE_CHK_STATUS_RET( after_merge_passes.AddPass("OptimizeStage1_1::MergeInputMemcpyPass", new (std::nothrow) MergeInputMemcpyPass)); GE_CHK_STATUS_RET( after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); @@ -1996,6 +1994,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { new (std::nothrow) TransOpWithoutReshapeFusionPass)) GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::TransOpBreadthFusionPass", new (std::nothrow) TransOpBreadthFusionPass)) + GE_CHK_STATUS_RET( + after_merge_passes.AddPass("OptimizeStage1_1::HcclMemcpyPass", new (std::nothrow) HcclMemcpyPass)); GE_TIMESTAMP_START(after_merge_passes); auto ret = after_merge_passes.Run(compute_graph); diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc index 0635a1a3..553dbf20 100755 --- a/ge/graph/passes/hccl_memcpy_pass.cc +++ b/ge/graph/passes/hccl_memcpy_pass.cc @@ -198,7 +198,7 @@ bool HcclMemcpyPass::IsDataNode(const std::string& node_type) { /// @param [in] ge::OutDataAnchorPtr in_node /// @return ge::NodePtr /// -NodePtr 
HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { +NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); NodePtr pre_node = out_data_anchor->GetOwnerNode(); OpDescPtr pre_op_desc = pre_node->GetOpDesc(); @@ -207,24 +207,24 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O return nullptr; } - std::string node_name = pre_node->GetName() + "_" + IDENTITY; + std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); + OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), MEMCPYASYNC); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "Create identity op: MakeShared op_desc fail."); + GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); return nullptr; } - GELOGI("Create identity op:%s.", op_desc->GetName().c_str()); + GELOGI("Create MemcpyAsync op:%s.", op_desc->GetName().c_str()); graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Create identity op: add input desc fail."); + GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add input desc fail."); return nullptr; } ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Create identity op: add output desc fail."); + GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add output desc fail."); return nullptr; } // because history reason ,this pass can not do work after constant fold so mark it @@ -232,7 +232,7 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O NodePtr memcpy_node = graph->AddNode(op_desc); if (memcpy_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Insert 
identity node fail."); + GELOGE(INTERNAL_ERROR, "Insert MemcpyAsync node fail."); return nullptr; } @@ -267,7 +267,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const const InDataAnchorPtr &hccl_in_anchor) { GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, src_out_anchor); GE_CHECK_NOTNULL(memcpy_node); Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h index 1e946fa7..26df2de0 100755 --- a/ge/graph/passes/hccl_memcpy_pass.h +++ b/ge/graph/passes/hccl_memcpy_pass.h @@ -30,7 +30,7 @@ class HcclMemcpyPass : public GraphPass { Status ClearStatus() override; private: - NodePtr CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); + NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); std::string CheckDuplicateName(const std::string &node_name);