|
@@ -58,7 +58,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { |
|
|
GELOGE(INTERNAL_ERROR, "failed P2pmemInputProcess, node_name:%s.", node->GetName().c_str()); |
|
|
GELOGE(INTERNAL_ERROR, "failed P2pmemInputProcess, node_name:%s.", node->GetName().c_str()); |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
} |
|
|
return ret; |
|
|
return ret; |
|
|
} |
|
|
} |
|
@@ -66,7 +66,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { |
|
|
// If node has _input_mutable attr, means input mem may be modified when op execute. |
|
|
// If node has _input_mutable attr, means input mem may be modified when op execute. |
|
|
// In order to avoid to affect another op execute with same input when data modified, |
|
|
// In order to avoid to affect another op execute with same input when data modified, |
|
|
// need to insert memcpy node between. |
|
|
// need to insert memcpy node between. |
|
|
// also works on situation that input is variable or const. |
|
|
|
|
|
|
|
|
// also works on situation that input is variable or const. |
|
|
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { |
|
|
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { |
|
|
auto op_desc = node->GetOpDesc(); |
|
|
auto op_desc = node->GetOpDesc(); |
|
|
|
|
|
|
|
@@ -77,7 +77,7 @@ Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const N |
|
|
|
|
|
|
|
|
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { |
|
|
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { |
|
|
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); |
|
|
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); |
|
|
return FAILED; |
|
|
|
|
|
|
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
if (!node_input_mutable) { |
|
|
if (!node_input_mutable) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
@@ -126,6 +126,7 @@ Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, cons |
|
|
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); |
|
|
(void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); |
|
|
|
|
|
|
|
|
if (is_input_continuous && op_desc->GetInputsSize() > 1) { |
|
|
if (is_input_continuous && op_desc->GetInputsSize() > 1) { |
|
|
|
|
|
GELOGI("continuous input op is:%s.", op_desc->GetName().c_str()); |
|
|
// if input size bigger than one, insert memcpy between var data for support continuous mem alloc |
|
|
// if input size bigger than one, insert memcpy between var data for support continuous mem alloc |
|
|
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { |
|
|
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { |
|
|
if (hccl_in_anchor == nullptr) { |
|
|
if (hccl_in_anchor == nullptr) { |
|
@@ -136,7 +137,7 @@ Status HcclMemcpyPass::ContinuousInputProcess(const ComputeGraphPtr &graph, cons |
|
|
GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str()); |
|
|
GELOGE(INTERNAL_ERROR, "hcom op input has no peer anchor, node_name:%s", node->GetName().c_str()); |
|
|
return INTERNAL_ERROR; |
|
|
return INTERNAL_ERROR; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) { |
|
|
if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) { |
|
|
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); |
|
|
Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); |
|
|
if (ret != SUCCESS) { |
|
|
if (ret != SUCCESS) { |
|
@@ -165,6 +166,7 @@ Status HcclMemcpyPass::P2pmemInputProcess(const ComputeGraphPtr &graph, const No |
|
|
continue; |
|
|
continue; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GELOGI("p2p input op is:%s.", op_desc->GetName().c_str()); |
|
|
auto hccl_in_anchor = node->GetInDataAnchor(index); |
|
|
auto hccl_in_anchor = node->GetInDataAnchor(index); |
|
|
if (hccl_in_anchor == nullptr) { |
|
|
if (hccl_in_anchor == nullptr) { |
|
|
continue; |
|
|
continue; |
|
@@ -263,7 +265,8 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { |
|
|
/// |
|
|
/// |
|
|
Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, |
|
|
Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, |
|
|
const InDataAnchorPtr &hccl_in_anchor) { |
|
|
const InDataAnchorPtr &hccl_in_anchor) { |
|
|
GELOGI("The op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str()); |
|
|
|
|
|
|
|
|
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), |
|
|
|
|
|
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); |
|
|
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); |
|
|
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); |
|
|
GE_CHECK_NOTNULL(memcpy_node); |
|
|
GE_CHECK_NOTNULL(memcpy_node); |
|
|
|
|
|
|
|
|