diff --git a/ge/graph/load/model_manager/davinci_model.cc b/ge/graph/load/model_manager/davinci_model.cc index 580394ec..c6bfe8eb 100755 --- a/ge/graph/load/model_manager/davinci_model.cc +++ b/ge/graph/load/model_manager/davinci_model.cc @@ -859,6 +859,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } + // for dynamic shape with control flow + SetLabelForDynamic(node); auto it = op_desc_handle.find(op_desc->GetType()); if (it != op_desc_handle.end()) { if ((this->*it->second)(op_desc) != SUCCESS) { @@ -867,8 +869,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { } continue; } - // for dynamic shape with control flow - SetLabelForDynamic(node); if (IsNoTaskAndDumpNeeded(op_desc)) { GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str()); const RuntimeParam &rts_param = GetRuntimeParam(); @@ -910,14 +910,14 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { } void DavinciModel::SetLabelForDynamic(const NodePtr &node) { - if (known_node_ && node->GetOpDesc()->GetType() == LABELSWITCHBYINDEX) { + if (known_node_ && (node->GetType() == LABELSWITCHBYINDEX || node->GetType() == STREAMSWITCH)) { for (auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_data_anchor != nullptr) { - string tensor_name = node->GetName(); + // name+index as the label of switch input + string tensor_name = node->GetName() + std::to_string(in_data_anchor->GetIdx()); auto peer_node = peer_out_data_anchor->GetOwnerNode(); (void)AttrUtils::SetStr(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); - (void)AttrUtils::SetInt(peer_node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, 0); tensor_name_to_peer_output_index_[tensor_name] = 0; } } diff --git a/ge/graph/load/model_manager/task_info/stream_switch_task_info.cc b/ge/graph/load/model_manager/task_info/stream_switch_task_info.cc index f129950a..f51c154c 100644 --- a/ge/graph/load/model_manager/task_info/stream_switch_task_info.cc +++ b/ge/graph/load/model_manager/task_info/stream_switch_task_info.cc @@ -123,7 +123,7 @@ Status StreamSwitchTaskInfo::CalculateArgs(const domi::TaskDef &task_def, Davinc return FAILED; } for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) { - string input_tensor_name = op_desc->GetInputNameByIndex(i); + string input_tensor_name = op_desc->GetName() + std::to_string(i); int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name); fixed_addr_offset_.emplace_back(fixed_addr_offset); auto tensor_desc = op_desc->GetInputDesc(i); diff --git a/ge/graph/passes/memcpy_addr_async_pass.cc b/ge/graph/passes/memcpy_addr_async_pass.cc index b930f7cb..1f6ed4bb 100755 --- a/ge/graph/passes/memcpy_addr_async_pass.cc +++ b/ge/graph/passes/memcpy_addr_async_pass.cc @@ -25,7 +25,17 @@ namespace ge { Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); + if (graph->GetGraphUnknownFlag()) { + // insert memcpyasync node when parent is unknown graph + for (const auto &node : graph->GetAllNodes()) { + if (node->GetType() == STREAMSWITCH) { + auto sub_graph = node->GetOwnerComputeGraph(); + if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); + } + } + } GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); return SUCCESS; } @@ -63,6 +73,26 @@ Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { return SUCCESS; } +Status MemcpyAddrAsyncPass::AddMemcpyAsyncNode(const NodePtr &node) { + GELOGI("Start add memcpyasync node in front of node %s", node->GetName().c_str()); + auto sub_graph = node->GetOwnerComputeGraph(); + for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + auto memcpy_async_node = CreateMemcpyAsyncNode(sub_graph, peer_out_anchor, node); + if (memcpy_async_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Create memcpyasync node failed."); + return INTERNAL_ERROR; + } + Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_async_node); + if (ret != SUCCESS) { + GELOGE(ret, "Insert memcpyasync node failed."); + return ret; + } + } + return SUCCESS; +} + Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) { GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str()); for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { @@ -256,6 +286,35 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr return memcpy_addr_async_node; } +// create memcpy async node for known sub graph +NodePtr MemcpyAddrAsyncPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, + const OutDataAnchorPtr &out_data_anchor, + const NodePtr &out_of_user_data) { + GELOGD("Start CreateMemcpyAsyncNode."); + static uint32_t new_node_index = 0; + OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); + + string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(new_node_index++); + OpDescPtr op_desc = MakeShared(node_name, MEMCPYASYNC); + GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + + if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add memcpyasync input desc failed."); + return nullptr; + } + + if (op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add memcpyasync output desc failed."); + return nullptr; + } + + NodePtr memcpy_async_node = graph->AddNode(op_desc); + GE_CHECK_NOTNULL_EXEC(memcpy_async_node, return nullptr); + + return memcpy_async_node; +} + Status MemcpyAddrAsyncPass::InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor, const NodePtr &node) { // insert memcpy_addr of each user_data and out_of_user_data diff --git a/ge/graph/passes/memcpy_addr_async_pass.h b/ge/graph/passes/memcpy_addr_async_pass.h index 0f22d10b..1c27a1d5 100755 --- a/ge/graph/passes/memcpy_addr_async_pass.h +++ b/ge/graph/passes/memcpy_addr_async_pass.h @@ -27,6 +27,7 @@ class MemcpyAddrAsyncPass : public GraphPass { private: Status AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node); + Status AddMemcpyAsyncNode(const NodePtr &node); void FindUserData(const NodePtr &node, uint32_t &parent_index); void FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index); void FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index); @@ -34,6 +35,8 @@ class MemcpyAddrAsyncPass : public GraphPass { NodePtr CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, const NodePtr &out_of_user_data); + NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + const NodePtr &out_of_user_data); Status InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor, const NodePtr &node); Status InsertMemAddrAsyncNodeBeforeNetoutput(const ComputeGraphPtr &graph, const NodePtr &node);