Browse Source

Fix Merge to NextIteratoin control drive

tags/v1.3.0
zhangxiaokun 4 years ago
parent
commit
db9078a799
3 changed files with 12 additions and 15 deletions
  1. +1
    -8
      ge/graph/common/omg_util.cc
  2. +10
    -6
      ge/hybrid/executor/node_state.cc
  3. +1
    -1
      ge/hybrid/model/node_item.cc

+ 1
- 8
ge/graph/common/omg_util.cc View File

@@ -291,14 +291,7 @@ void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t grou
GE_RT_VOID_CHECK_NOTNULL(op_desc);

// op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check.
GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index);
if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
node->GetName().c_str(), node->GetType().c_str());
}

GELOGD("[%s] Set control flow group index: %ld", node->GetName().c_str(), group_index);
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(),
node->GetName().c_str(), node->GetType().c_str());


+ 10
- 6
ge/hybrid/executor/node_state.cc View File

@@ -366,9 +366,10 @@ bool NodeState::IsScheduleReady() const {
}

void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u",
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_,
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(),
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);

std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
@@ -393,9 +394,10 @@ void NodeState::SetDataSchedule(const NodeState &node_state, const std::function
}

void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u",
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
GELOGD("[%s] schedule [%s], loop[%lu -> %lu], data[num: %zu, scheduled: %u], ctrl[num: %zu+%zu, scheduled: %u]",
node_state.GetName().c_str(), GetName().c_str(), loop_count_, node_state.loop_count_,
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(),
node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);

std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
@@ -415,6 +417,8 @@ void NodeState::RunLoopNext() {
if (loop_count_ == UINT64_MAX) {
loop_count_ = 1;
}

ResetContext(loop_count_);
}

void NodeState::RunLoopExit() {


+ 1
- 1
ge/hybrid/model/node_item.cc View File

@@ -443,7 +443,7 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
}

size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const {
return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0;
return ((node_type == STREAMMERGE) && (merge_index < switch_groups_.size())) ? switch_groups_[merge_index].size() : 0;
}

OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {


Loading…
Cancel
Save