|
|
@@ -25,7 +25,17 @@ |
|
|
|
namespace ge { |
|
|
|
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { |
|
|
|
GE_CHECK_NOTNULL(graph); |
|
|
|
|
|
|
|
if (graph->GetGraphUnknownFlag()) { |
|
|
|
// insert memcpyasync node when parent is unknown graph |
|
|
|
for (const auto &node : graph->GetAllNodes()) { |
|
|
|
if (node->GetType() == STREAMSWITCH) { |
|
|
|
auto sub_graph = node->GetOwnerComputeGraph(); |
|
|
|
if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { |
|
|
|
GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -63,6 +73,26 @@ Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// Inserts a memcpyasync node in front of every connected data input of `node`.
//
// For each input data anchor of `node` that has a peer output anchor, a new node is
// created through CreateMemcpyAsyncNode in the graph that owns `node`, and is then
// handed to InsertMemcpyAddrAsyncNode together with the peer output anchor and the
// input anchor. Input anchors without a peer output anchor are skipped.
//
// node: the node whose data inputs are processed; must not be null.
//
// Returns SUCCESS when every connected input was handled, INTERNAL_ERROR when a
// memcpyasync node could not be created, the status reported by
// InsertMemcpyAddrAsyncNode when insertion fails, or the status produced by
// GE_CHECK_NOTNULL when `node` or its owner graph is null.
Status MemcpyAddrAsyncPass::AddMemcpyAsyncNode(const NodePtr &node) {
  // Validate first: the log statement below already dereferences the node.
  GE_CHECK_NOTNULL(node);
  GELOGI("Start add memcpyasync node in front of node %s", node->GetName().c_str());

  auto sub_graph = node->GetOwnerComputeGraph();
  // CreateMemcpyAsyncNode adds the new node to this graph, so it has to exist.
  GE_CHECK_NOTNULL(sub_graph);

  for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
    OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
    // An input without a producer has no edge to put a memcpyasync node on.
    GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);

    auto memcpy_async_node = CreateMemcpyAsyncNode(sub_graph, peer_out_anchor, node);
    if (memcpy_async_node == nullptr) {
      GELOGE(INTERNAL_ERROR, "Create memcpyasync node failed.");
      return INTERNAL_ERROR;
    }

    Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_async_node);
    if (ret != SUCCESS) {
      GELOGE(ret, "Insert memcpyasync node failed.");
      return ret;
    }
  }
  return SUCCESS;
}
|
|
|
|
|
|
|
Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) { |
|
|
|
GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str()); |
|
|
|
for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { |
|
|
@@ -256,6 +286,35 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr |
|
|
|
return memcpy_addr_async_node; |
|
|
|
} |
|
|
|
|
|
|
|
// create memcpy async node for known sub graph |
|
|
|
NodePtr MemcpyAddrAsyncPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, |
|
|
|
const OutDataAnchorPtr &out_data_anchor, |
|
|
|
const NodePtr &out_of_user_data) { |
|
|
|
GELOGD("Start CreateMemcpyAsyncNode."); |
|
|
|
static uint32_t new_node_index = 0; |
|
|
|
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); |
|
|
|
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); |
|
|
|
|
|
|
|
string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(new_node_index++); |
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC); |
|
|
|
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); |
|
|
|
|
|
|
|
if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Add memcpyasync input desc failed."); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
if (op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Add memcpyasync output desc failed."); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
NodePtr memcpy_async_node = graph->AddNode(op_desc); |
|
|
|
GE_CHECK_NOTNULL_EXEC(memcpy_async_node, return nullptr); |
|
|
|
|
|
|
|
return memcpy_async_node; |
|
|
|
} |
|
|
|
|
|
|
|
Status MemcpyAddrAsyncPass::InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, |
|
|
|
const InDataAnchorPtr &in_anchor, const NodePtr &node) { |
|
|
|
// insert memcpy_addr of each user_data and out_of_user_data |
|
|
|