From db9078a7995f9bde1be9538bf52e2ed97462c204 Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 31 May 2021 13:42:27 +0800 Subject: [PATCH] Fix Merge to NextIteratoin control drive --- ge/graph/common/omg_util.cc | 9 +-------- ge/hybrid/executor/node_state.cc | 16 ++++++++++------ ge/hybrid/model/node_item.cc | 2 +- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index 670355b8..598677bd 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -291,14 +291,7 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou GE_RT_VOID_CHECK_NOTNULL(op_desc); // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. - GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index); - if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { - REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - } - + GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index); if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), node->GetName().c_str(), node->GetType().c_str()); diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index 617adaaf..fd47cfb2 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -366,9 +366,10 @@ bool NodeState::IsScheduleReady() const { } void NodeState::SetDataSchedule(const NodeState &node_state, const std::function &ready) { - GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", - node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, - node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); + GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", + node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, + node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), + node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); if (loop_count_ != node_state.loop_count_) { @@ -393,9 +394,10 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function } void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function &ready) { - GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", - node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, - node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); + GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]", + node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_, + node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), + node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); std::lock_guard lk(mu_); if (loop_count_ != node_state.loop_count_) { @@ -415,6 +417,8 @@ void NodeState::RunLoopNext() { if (loop_count_ == UINT64_MAX) { loop_count_ = 1; } + + ResetContext(loop_count_); } void NodeState::RunLoopExit() { diff --git a/ge/hybrid/model/node_item.cc b/ge/hybrid/model/node_item.cc index f3793abf..7054fd46 100644 --- a/ge/hybrid/model/node_item.cc +++ b/ge/hybrid/model/node_item.cc @@ -443,7 +443,7 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { } size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const { - return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0; + return ((node_type == STREAMMERGE) && (merge_index < switch_groups_.size())) ? switch_groups_[merge_index].size() : 0; } OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {