Browse Source

move hccl memcpy after trans op fusion pass

pull/277/head
wangxiaotian22 4 years ago
parent
commit
7610599a3a
3 changed files with 12 additions and 12 deletions
  1. +2
    -2
      ge/graph/manager/graph_manager.cc
  2. +9
    -9
      ge/graph/passes/hccl_memcpy_pass.cc
  3. +1
    -1
      ge/graph/passes/hccl_memcpy_pass.h

+ 2
- 2
ge/graph/manager/graph_manager.cc View File

@@ -1962,8 +1962,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
}
PassManager after_merge_passes;
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::HcclMemcpyPass", new (std::nothrow) HcclMemcpyPass));
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::MergeInputMemcpyPass", new (std::nothrow) MergeInputMemcpyPass));
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass));
@@ -1996,6 +1994,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
new (std::nothrow) TransOpWithoutReshapeFusionPass))
GE_CHK_STATUS_RET(after_merge_passes.AddPass("OptimizeStage1_1::TransOpBreadthFusionPass",
new (std::nothrow) TransOpBreadthFusionPass))
GE_CHK_STATUS_RET(
after_merge_passes.AddPass("OptimizeStage1_1::HcclMemcpyPass", new (std::nothrow) HcclMemcpyPass));

GE_TIMESTAMP_START(after_merge_passes);
auto ret = after_merge_passes.Run(compute_graph);


+ 9
- 9
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -198,7 +198,7 @@ bool HcclMemcpyPass::IsDataNode(const std::string& node_type) {
/// @param [in] ge::OutDataAnchorPtr in_node
/// @return ge::NodePtr
///
NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) {
NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) {
GE_IF_BOOL_EXEC(graph == nullptr, return nullptr);
NodePtr pre_node = out_data_anchor->GetOwnerNode();
OpDescPtr pre_op_desc = pre_node->GetOpDesc();
@@ -207,24 +207,24 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O
return nullptr;
}

std::string node_name = pre_node->GetName() + "_" + IDENTITY;
std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC;
node_name = CheckDuplicateName(node_name);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), MEMCPYASYNC);
if (op_desc == nullptr) {
GELOGE(INTERNAL_ERROR, "Create identity op: MakeShared op_desc fail.");
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail.");
return nullptr;
}
GELOGI("Create identity op:%s.", op_desc->GetName().c_str());
GELOGI("Create MemcpyAsync op:%s.", op_desc->GetName().c_str());

graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Create identity op: add input desc fail.");
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add input desc fail.");
return nullptr;
}

ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Create identity op: add output desc fail.");
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add output desc fail.");
return nullptr;
}
// because history reason ,this pass can not do work after constant fold so mark it
@@ -232,7 +232,7 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O

NodePtr memcpy_node = graph->AddNode(op_desc);
if (memcpy_node == nullptr) {
GELOGE(INTERNAL_ERROR, "Insert identity node fail.");
GELOGE(INTERNAL_ERROR, "Insert MemcpyAsync node fail.");
return nullptr;
}

@@ -267,7 +267,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor);
NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(memcpy_node);

Status ret1 = src_out_anchor->Unlink(hccl_in_anchor);


+ 1
- 1
ge/graph/passes/hccl_memcpy_pass.h View File

@@ -30,7 +30,7 @@ class HcclMemcpyPass : public GraphPass {
Status ClearStatus() override;

private:
NodePtr CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor);
NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor);

std::string CheckDuplicateName(const std::string &node_name);



Loading…
Cancel
Save