|
@@ -198,7 +198,7 @@ bool HcclMemcpyPass::IsDataNode(const std::string& node_type) { |
|
|
/// @param [in] ge::OutDataAnchorPtr out_data_anchor |
|
|
/// @param [in] ge::OutDataAnchorPtr out_data_anchor |
|
|
/// @return ge::NodePtr |
|
|
/// @return ge::NodePtr |
|
|
/// |
|
|
/// |
|
|
NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { |
|
|
|
|
|
|
|
|
NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { |
|
|
GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); |
|
|
GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); |
|
|
NodePtr pre_node = out_data_anchor->GetOwnerNode(); |
|
|
NodePtr pre_node = out_data_anchor->GetOwnerNode(); |
|
|
OpDescPtr pre_op_desc = pre_node->GetOpDesc(); |
|
|
OpDescPtr pre_op_desc = pre_node->GetOpDesc(); |
|
@@ -207,24 +207,24 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::string node_name = pre_node->GetName() + "_" + IDENTITY; |
|
|
|
|
|
|
|
|
std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; |
|
|
node_name = CheckDuplicateName(node_name); |
|
|
node_name = CheckDuplicateName(node_name); |
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); |
|
|
|
|
|
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), MEMCPYASYNC); |
|
|
if (op_desc == nullptr) { |
|
|
if (op_desc == nullptr) { |
|
|
GELOGE(INTERNAL_ERROR, "Create identity op: MakeShared op_desc fail."); |
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
GELOGI("Create identity op:%s.", op_desc->GetName().c_str()); |
|
|
|
|
|
|
|
|
GELOGI("Create MemcpyAsync op:%s.", op_desc->GetName().c_str()); |
|
|
|
|
|
|
|
|
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); |
|
|
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); |
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
GELOGE(INTERNAL_ERROR, "Create identity op: add input desc fail."); |
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add input desc fail."); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); |
|
|
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); |
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
if (ret != GRAPH_SUCCESS) { |
|
|
GELOGE(INTERNAL_ERROR, "Create identity op: add output desc fail."); |
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add output desc fail."); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
// For historical reasons, this pass cannot run after constant folding, so mark it. |
|
|
// For historical reasons, this pass cannot run after constant folding, so mark it. |
|
@@ -232,7 +232,7 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O |
|
|
|
|
|
|
|
|
NodePtr memcpy_node = graph->AddNode(op_desc); |
|
|
NodePtr memcpy_node = graph->AddNode(op_desc); |
|
|
if (memcpy_node == nullptr) { |
|
|
if (memcpy_node == nullptr) { |
|
|
GELOGE(INTERNAL_ERROR, "Insert identity node fail."); |
|
|
|
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "Insert MemcpyAsync node fail."); |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -267,7 +267,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const |
|
|
const InDataAnchorPtr &hccl_in_anchor) { |
|
|
const InDataAnchorPtr &hccl_in_anchor) { |
|
|
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), |
|
|
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), |
|
|
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); |
|
|
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); |
|
|
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); |
|
|
|
|
|
|
|
|
NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, src_out_anchor); |
|
|
GE_CHECK_NOTNULL(memcpy_node); |
|
|
GE_CHECK_NOTNULL(memcpy_node); |
|
|
|
|
|
|
|
|
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); |
|
|
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); |
|
|