diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index 15fa3c47..670355b8 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -193,23 +193,29 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node) { /// /// @brief set op next_iteration name -/// @param [in] node -/// @param [in] next +/// @param [in] Merge Node +/// @param [in] NextIteration Node /// @return Status /// -Status SetNextIteration(const ge::NodePtr &node, const std::string &next) { +Status SetNextIteration(const NodePtr &node, const NodePtr &next) { GE_CHECK_NOTNULL(node); - OpDescPtr tmp_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(tmp_desc); + GE_CHECK_NOTNULL(next); + GE_CHECK_NOTNULL(node->GetOpDesc()); + GE_CHECK_NOTNULL(next->GetOpDesc()); - if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) { - REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - return FAILED; - } + const auto SetIterationName = [](const OpDescPtr &op_desc, const std::string &name) { + if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) { + REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), + op_desc->GetName().c_str(), op_desc->GetType().c_str()); + return FAILED; + } + return SUCCESS; + }; + GE_CHK_STATUS_RET_NOLOG(SetIterationName(node->GetOpDesc(), next->GetName())); + GE_CHK_STATUS_RET_NOLOG(SetIterationName(next->GetOpDesc(), node->GetName())); return SUCCESS; } diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index fdb0e138..91fcd29e 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -96,11 +96,11 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node); /// /// @brief set op next_iteration name -/// @param [in] node -/// @param [in] next +/// @param [in] Merge Node +/// @param [in] NextIteration Node /// @return Status /// -Status SetNextIteration(const ge::NodePtr &node, const std::string &next); +Status SetNextIteration(const NodePtr &node, const NodePtr &next); /// /// @brief Align the memory diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index 7128b3dc..71b9e621 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -354,7 +354,7 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & merge_node->GetName().c_str()); return INTERNAL_ERROR; } - if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { + if (SetNextIteration(merge_node, next_node) != SUCCESS) { REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 9ec5431a..617adaaf 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -306,28 +306,15 @@ std::shared_ptr NodeState::GetTaskContext() { return task_context_; } -void NodeState::ResetContext(int group) { - SetGroup(group); - if (loop_count_ == 0) { - ++loop_count_; - return; - } - - ++loop_count_; - if (loop_count_ == UINT64_MAX) { - loop_count_ = 1; - } +void NodeState::ResetContext(uint64_t loop_count) { + loop_count_ = loop_count; switch_index_ = -1; subgraph_context_->ResetContext(node_item_->node); - GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); -} - -void NodeState::ResetSchedule() { - std::lock_guard lk(mu_); data_scheduled_ = static_cast(node_item_->root_data_.size()); ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); - GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_); + GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", + GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); } Status NodeState::NodeScheduled(const std::function &ready) const { @@ -335,14 +322,14 @@ Status NodeState::NodeScheduled(const std::function &rea for (const auto &node : node_item_->data_send_) { const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SetDataSchedule(node_item_, ready); + dst_node_state->SetDataSchedule(*this, ready); } // Schedule ctrl output. for (const auto &node : node_item_->ctrl_send_) { const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SetCtrlSchedule(node_item_, ready); + dst_node_state->SetCtrlSchedule(*this, ready); } // Schedule switch group. @@ -351,7 +338,7 @@ Status NodeState::NodeScheduled(const std::function &rea for (const auto &node : node_item_->switch_groups_[switch_index_]) { const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SetCtrlSchedule(node_item_, ready); + dst_node_state->SetCtrlSchedule(*this, ready); } } @@ -359,36 +346,44 @@ Status NodeState::NodeScheduled(const std::function &rea } bool NodeState::IsScheduleReady() const { - GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(), - node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_); - if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { - return false; - } - + GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", + GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, + node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); if (node_item_->IsMergeOp()) { + if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { + return false; + } + return data_scheduled_ > 0; } + if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { + return false; + } + // Exit may feed loop times... return data_scheduled_ >= node_item_->data_recv_.size(); } -void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function &ready) { - GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", - node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, - node_item_->ctrl_recv_.size(), ctrl_scheduled_); +void NodeState::SetDataSchedule(const NodeState &node_state, const std::function &ready) { + GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", + node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, + node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); + if (loop_count_ != node_state.loop_count_) { + ResetContext(node_state.loop_count_); + } ++data_scheduled_; if (node_item_->IsMergeOp()) { - const auto it = node_item_->data_recv_.find(node_item); + const auto it = node_item_->data_recv_.find(node_state.node_item_); if (it != node_item_->data_recv_.end()) { merge_index_ = it->second; (void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); - GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second); + GELOGD("[%s] scheduled, [%s] set merge index: %d", node_state.GetName().c_str(), GetName().c_str(), it->second); } else { - GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str()); + GELOGW("[%s] scheduled, [%s] not followed", node_state.GetName().c_str(), GetName().c_str()); } } @@ -397,12 +392,15 @@ void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function &ready) { - GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", - node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, - node_item_->ctrl_recv_.size(), ctrl_scheduled_); +void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function &ready) { + GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", + node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, + node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); + if (loop_count_ != node_state.loop_count_) { + ResetContext(node_state.loop_count_); + } ++ctrl_scheduled_; if (IsScheduleReady()) { @@ -410,6 +408,21 @@ void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function lk(mu_); + ++loop_count_; + if (loop_count_ == UINT64_MAX) { + loop_count_ = 1; + } +} + +void NodeState::RunLoopExit() { + GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); + std::lock_guard lk(mu_); + loop_count_ = 0; +} + void NodeState::SetScheduleFuture(std::future &&future) { schedule_future_ = std::move(future); } diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index d3f176ce..e4afdb9f 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -112,9 +112,8 @@ struct NodeState { return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; } - void ResetContext(int group); - - void ResetSchedule(); + void RunLoopNext(); + void RunLoopExit(); Status NodeScheduled(const std::function &ready) const; @@ -166,8 +165,9 @@ struct NodeState { private: bool IsScheduleReady() const; - void SetDataSchedule(const NodeItem *node_item, const std::function &ready); - void SetCtrlSchedule(const NodeItem *node_item, const std::function &ready); + void SetDataSchedule(const NodeState &node_state, const std::function &ready); + void SetCtrlSchedule(const NodeState &node_state, const std::function &ready); + void ResetContext(uint64_t loop_count); const NodeItem *node_item_ = nullptr; std::shared_ptr kernel_task_ = nullptr; diff --git a/ge/hybrid/executor/subgraph_context.cc b/ge/hybrid/executor/subgraph_context.cc index 5de0828f..08d8e30b 100644 --- a/ge/hybrid/executor/subgraph_context.cc +++ b/ge/hybrid/executor/subgraph_context.cc @@ -46,6 +46,10 @@ Status SubgraphContext::Init() { return SUCCESS; } +void SubgraphContext::SetGroup(int group) { + group_ = group; +} + void SubgraphContext::ResetContext(const NodePtr &node) { node_done_manager_.Reset(node); } @@ -85,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { if (node_state == nullptr) { const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); node_state = std::move(std::unique_ptr(new(std::nothrow)NodeState(*node_item, this))); + node_state->SetGroup(group_); (void)guard; } GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); diff --git a/ge/hybrid/executor/subgraph_context.h b/ge/hybrid/executor/subgraph_context.h index 7a99e324..303382c1 100644 --- a/ge/hybrid/executor/subgraph_context.h +++ b/ge/hybrid/executor/subgraph_context.h @@ -34,6 +34,7 @@ class SubgraphContext { ~SubgraphContext(); Status Init(); + void SetGroup(int group); void ResetContext(const NodePtr &node); void Reset(); NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); @@ -58,6 +59,7 @@ class SubgraphContext { std::vector all_outputs_; NodeDoneManager node_done_manager_; std::unordered_map node_states_; + int group_ = -1; }; } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 3536f295..9f5c3600 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -242,7 +242,6 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); GE_CHECK_NOTNULL(node_state); - node_state->ResetContext(group); auto p_node_state = node_state.get(); if (node_item.node_type == NETOUTPUT) { @@ -367,7 +366,6 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { }; GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); - node_state->ResetSchedule(); RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); return SUCCESS; }); diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 91188326..bb3c8dc8 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -21,6 +21,7 @@ #include "graph/ge_context.h" #include "graph/build/memory/var_mem_assign_util.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/omg_util.h" #include "graph/load/model_manager/model_utils.h" #include "graph/load/model_manager/model_manager.h" #include "graph/manager/graph_var_manager.h" @@ -43,8 +44,9 @@ const uint64_t kProfilingBpEndLogid = 2U; const uint64_t kProfilingIterEndLogid = 65535U; const int kBytes = 8; const int kDecimal = 10; -const uint8_t kStreamActiveIdx = 0; -const uint8_t kStreamActiveNum = 1; +const uint8_t kLoopEnterIdx = 0; +const uint8_t kLoopIterationIdx = 1; +const uint8_t kLoopMergeSize = 2; const uint8_t kStreamSwitchIdx = 1; const uint8_t kStreamSwitchNum = 2; const uint32_t kStringHeadElems = 2; @@ -57,6 +59,10 @@ const char *const kProfilingArNode = "ProfilingAllReduceNode"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; const char *const kForceInfershape = "_force_infershape_when_running"; +const std::set kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; +const std::set kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; +const std::set kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; + Status SetOutputNameAttr(ComputeGraph &graph) { vector output_names; for (const auto &node : graph.GetDirectNode()) { @@ -389,7 +395,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } // cond or branch need to be prepared before the execution of IF or CASE - if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { + if (kExecutionDependentTypes.count(node_item.node_type) > 0) { auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input GE_CHECK_NOTNULL(src_node); auto src_node_item = MutableNodeItem(src_node); @@ -575,7 +581,7 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { auto in_nodes = root_node->GetInAllNodes(); std::set in_node_set(in_nodes.begin(), in_nodes.end()); for (auto &in_control_node : wrapped_node->GetInControlNodes()) { - if (in_node_set.count(in_control_node) == 0) { + if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); @@ -2282,8 +2288,6 @@ Status HybridModelBuilder::RelinkNextIteration() { } } - stream_merge_op_nodes_.clear(); - next_iteration_op_nodes_.clear(); return SUCCESS; } @@ -2371,10 +2375,12 @@ Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const No } Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { - const auto out_ctrl_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); + for (const auto &dst_node : node->GetOutControlNodes()) { GE_CHECK_NOTNULL(dst_node); + if ((dst_node->GetType() == STREAMACTIVE) && (kStreamActiveTypes.count(node->GetType()) == 0)) { + GELOGI("[%s] ignore control to [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); + continue; + } NodeItem *dst_node_item = nullptr; GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), @@ -2384,27 +2390,80 @@ Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem * return SUCCESS; } +Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item) { + // Enter --> StreamActive --> StreamMerge + for (const auto &dst_node : node->GetOutControlNodes()) { + GE_CHECK_NOTNULL(dst_node); + NodeItem *dst_node_item = nullptr; + GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), + "[%s] failed to get or create node item", dst_node->GetName().c_str()); + // Set Enter Control to StreamMerge as Group 0. + dst_node_item->switch_groups_.resize(kLoopMergeSize); + dst_node_item->SetMergeCtrl(node_item, kLoopEnterIdx); + } + return SUCCESS; +} + +Status HybridModelBuilder::CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item) { + // NextIteration --> StreamActive {-->} StreamMerge + std::string node_name; + for (const auto &src_node : node->GetInControlNodes()) { + GE_CHECK_NOTNULL(src_node); + if (kNextIterationOpTypes.count(src_node->GetType()) == 0) { + GELOGI("[%s] Skip Not NextIteration node [%s]", node->GetName().c_str(), src_node->GetName().c_str()); + continue; + } + + if (!AttrUtils::GetStr(src_node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) { + GELOGE(INTERNAL_ERROR, "[%s] input node [%s] expect attribute[%s] not found", + node->GetName().c_str(), src_node->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str()); + return INTERNAL_ERROR; + } + + const auto it = stream_merge_op_nodes_.find(node_name); + if (it == stream_merge_op_nodes_.end()) { + GELOGE(INTERNAL_ERROR, "[%s] expect StreamMerge[%s] not found", node->GetName().c_str(), node_name.c_str()); + return INTERNAL_ERROR; + } + + const auto &dst_node = it->second; + GE_CHECK_NOTNULL(dst_node); + NodeItem *dst_node_item = nullptr; + GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", + dst_node->GetName().c_str()); + // Set NextIteration Control to StreamMerge as Group 1. + dst_node_item->SetMergeCtrl(node_item, kLoopIterationIdx); + } + return SUCCESS; +} + Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { if (node_item->node_type != STREAMACTIVE) { GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); return INTERNAL_ERROR; } - node_item->switch_groups_.resize(kStreamActiveNum); - const auto &out_ctrl_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == STREAMMERGE) { - GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str()); - continue; - } + const auto ctrl_nodes = node->GetInControlNodes(); + if (ctrl_nodes.empty()) { + GELOGW("Skip no in control node: %s", node->GetName().c_str()); + return SUCCESS; + } - NodeItem *dst_node_item = nullptr; - GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), - "[%s] failed to get or create node item", dst_node->GetName().c_str()); - node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx); + const auto IsEnterNode = [](const NodePtr &n) { + return kEnterOpTypes.count(n->GetType()) > 0; + }; + const auto IsIterationNode = [](const NodePtr &n) { + return kNextIterationOpTypes.count(n->GetType()) > 0; + }; + + if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { + // Enter --> StreamActive --> StreamMerge + return CreateMergeEnterGroup(node, node_item); + } else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { + // NextIteration --> StreamActive {-->} StreamMerge + return CreateMergeIterationGroup(node, node_item); } + return SUCCESS; } @@ -2416,11 +2475,8 @@ Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem // Consider as two groups, group[0] set empty for false, group[1] for true. node_item->switch_groups_.resize(kStreamSwitchNum); - const auto &out_ctrl_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); + for (const auto &dst_node : node->GetOutControlNodes()) { GE_CHECK_NOTNULL(dst_node); - NodeItem *dst_node_item = nullptr; GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", dst_node->GetName().c_str()); @@ -2447,20 +2503,17 @@ Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeIte } node_item->switch_groups_.resize(batch_num); - const auto &out_ctrl_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); + for (const auto &dst_node : node->GetOutControlNodes()) { GE_CHECK_NOTNULL(dst_node); - std::string batch_label; - if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { - GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str()); + if (!AttrUtils::GetStr(dst_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { + GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", dst_node->GetName().c_str()); return INTERNAL_ERROR; } std::string::size_type pos = batch_label.rfind("_"); if (pos == std::string::npos) { - GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str()); + GELOGW("[%s] Separator not found in batch label: %s.", dst_node->GetName().c_str(), batch_label.c_str()); continue; } @@ -2486,7 +2539,7 @@ Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeIte return INTERNAL_ERROR; } - return SUCCESS; + return CreateNormalNodeGroup(node, node_item); } Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { @@ -2495,11 +2548,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node return INTERNAL_ERROR; } - const auto &out_ctrl_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); + for (const auto &dst_node : node->GetOutControlNodes()) { GE_CHECK_NOTNULL(dst_node); - NodeItem *dst_node_item = nullptr; GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", dst_node->GetName().c_str()); @@ -2509,11 +2559,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node // Group switch flow by out put data. node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { - const auto &out_anchor = node->GetOutDataAnchor(i); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - const auto &dst_node = peer_in_anchor->GetOwnerNode(); + for (const auto &dst_node : node->GetOutDataNodes()) { GE_CHECK_NOTNULL(dst_node); - NodeItem *dst_node_item = nullptr; GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item", dst_node->GetName().c_str()); diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index ad288317..d0ee54ed 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -99,6 +99,8 @@ class HybridModelBuilder { Status BuildProfilingControl(GraphItem &graph_item, const std::map> &nodes); Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); + Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); + Status CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index c6adce00..07c8038b 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -34,8 +34,8 @@ const std::set kControlOpTypes{ }; const std::set kControlFlowOpTypes{ - STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX, - NEXTITERATION, REFNEXTITERATION + STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, + LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX }; const std::set kMergeOpTypes{ @@ -420,6 +420,24 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); } +void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { + if (merge_index >= switch_groups_.size()) { + GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index); + return; + } + + // this is StreamMerge node, node_item is StreamActive node. + std::vector &switch_group = switch_groups_[merge_index]; + switch_group.emplace_back(node_item); + + node_item->ctrl_send_.emplace(this); + GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); +} + +size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const { + return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0; +} + OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { if (mu_ != nullptr) { GELOGD("lock for %s", name_.c_str()); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 606e58fe..af796753 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -98,6 +98,8 @@ struct NodeItem { void SetDataSend(NodeItem *node_item, int anchor_index); void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); + void SetMergeCtrl(NodeItem *node_item, uint32_t merge_index); + size_t GetMergeCtrl(uint32_t merge_index) const; OptionalMutexGuard MutexGuard(const std::string &name) const { return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index f6d6ddb6..5ad8eaf4 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -20,6 +20,7 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" @@ -201,6 +202,13 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio GE_CHECK_NOTNULL(in_x); GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y + const auto &node_state = task_context.GetNodeState(); + if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { + node_state->RunLoopNext(); + } else if (kExitOpTypes.count(node_state->GetType()) > 0) { + node_state->RunLoopExit(); + } + if (done_callback) { GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); }