Browse Source

Enable StreamActive for Loop NextIteration

pull/1691/head
zhangxiaokun 4 years ago
parent
commit
7001a2bbc1
13 changed files with 205 additions and 104 deletions
  1. +18
    -12
      ge/graph/common/omg_util.cc
  2. +3
    -3
      ge/graph/common/omg_util.h
  3. +1
    -1
      ge/graph/passes/next_iteration_pass.cc
  4. +50
    -37
      ge/hybrid/executor/node_state.cc
  5. +5
    -5
      ge/hybrid/executor/node_state.h
  6. +5
    -0
      ge/hybrid/executor/subgraph_context.cc
  7. +2
    -0
      ge/hybrid/executor/subgraph_context.h
  8. +0
    -2
      ge/hybrid/executor/subgraph_executor.cc
  9. +89
    -42
      ge/hybrid/model/hybrid_model_builder.cc
  10. +2
    -0
      ge/hybrid/model/hybrid_model_builder.h
  11. +20
    -2
      ge/hybrid/model/node_item.cc
  12. +2
    -0
      ge/hybrid/model/node_item.h
  13. +8
    -0
      ge/hybrid/node_executor/rts/rts_node_task.cc

+ 18
- 12
ge/graph/common/omg_util.cc View File

@@ -193,23 +193,29 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node) {


/// ///
/// @brief set op next_iteration name /// @brief set op next_iteration name
/// @param [in] node
/// @param [in] next
/// @param [in] Merge Node
/// @param [in] NextIteration Node
/// @return Status /// @return Status
/// ///
Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
Status SetNextIteration(const NodePtr &node, const NodePtr &next) {
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
OpDescPtr tmp_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(tmp_desc);
GE_CHECK_NOTNULL(next);
GE_CHECK_NOTNULL(node->GetOpDesc());
GE_CHECK_NOTNULL(next->GetOpDesc());


if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
node->GetName().c_str(), node->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
node->GetName().c_str(), node->GetType().c_str());
return FAILED;
}
const auto SetIterationName = [](const OpDescPtr &op_desc, const std::string &name) {
if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) {
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
op_desc->GetName().c_str(), op_desc->GetType().c_str());
return FAILED;
}
return SUCCESS;
};


GE_CHK_STATUS_RET_NOLOG(SetIterationName(node->GetOpDesc(), next->GetName()));
GE_CHK_STATUS_RET_NOLOG(SetIterationName(next->GetOpDesc(), node->GetName()));
return SUCCESS; return SUCCESS;
} }




+ 3
- 3
ge/graph/common/omg_util.h View File

@@ -96,11 +96,11 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node);


/// ///
/// @brief set op next_iteration name /// @brief set op next_iteration name
/// @param [in] node
/// @param [in] next
/// @param [in] Merge Node
/// @param [in] NextIteration Node
/// @return Status /// @return Status
/// ///
Status SetNextIteration(const ge::NodePtr &node, const std::string &next);
Status SetNextIteration(const NodePtr &node, const NodePtr &next);


/// ///
/// @brief Align the memory /// @brief Align the memory


+ 1
- 1
ge/graph/passes/next_iteration_pass.cc View File

@@ -354,7 +354,7 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr &
merge_node->GetName().c_str()); merge_node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) {
if (SetNextIteration(merge_node, next_node) != SUCCESS) {
REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed",
next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str());


+ 50
- 37
ge/hybrid/executor/node_state.cc View File

@@ -306,28 +306,15 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_; return task_context_;
} }


void NodeState::ResetContext(int group) {
SetGroup(group);
if (loop_count_ == 0) {
++loop_count_;
return;
}

++loop_count_;
if (loop_count_ == UINT64_MAX) {
loop_count_ = 1;
}
void NodeState::ResetContext(uint64_t loop_count) {
loop_count_ = loop_count;


switch_index_ = -1; switch_index_ = -1;
subgraph_context_->ResetContext(node_item_->node); subgraph_context_->ResetContext(node_item_->node);
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_);
}

void NodeState::ResetSchedule() {
std::lock_guard<std::mutex> lk(mu_);
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_);
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d",
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_);
} }


Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const {
@@ -335,14 +322,14 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea
for (const auto &node : node_item_->data_send_) { for (const auto &node : node_item_->data_send_) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state); GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetDataSchedule(node_item_, ready);
dst_node_state->SetDataSchedule(*this, ready);
} }


// Schedule ctrl output. // Schedule ctrl output.
for (const auto &node : node_item_->ctrl_send_) { for (const auto &node : node_item_->ctrl_send_) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state); GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetCtrlSchedule(node_item_, ready);
dst_node_state->SetCtrlSchedule(*this, ready);
} }


// Schedule switch group. // Schedule switch group.
@@ -351,7 +338,7 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea
for (const auto &node : node_item_->switch_groups_[switch_index_]) { for (const auto &node : node_item_->switch_groups_[switch_index_]) {
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node);
GE_CHECK_NOTNULL(dst_node_state); GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SetCtrlSchedule(node_item_, ready);
dst_node_state->SetCtrlSchedule(*this, ready);
} }
} }


@@ -359,36 +346,44 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea
} }


bool NodeState::IsScheduleReady() const { bool NodeState::IsScheduleReady() const {
GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(),
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_);
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) {
return false;
}

GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]",
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);
if (node_item_->IsMergeOp()) { if (node_item_->IsMergeOp()) {
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) {
return false;
}

return data_scheduled_ > 0; return data_scheduled_ > 0;
} }


if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) {
return false;
}

// Exit may feed loop times... // Exit may feed loop times...
return data_scheduled_ >= node_item_->data_recv_.size(); return data_scheduled_ >= node_item_->data_recv_.size();
} }


void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u",
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), ctrl_scheduled_);
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u",
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);


std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
ResetContext(node_state.loop_count_);
}
++data_scheduled_; ++data_scheduled_;


if (node_item_->IsMergeOp()) { if (node_item_->IsMergeOp()) {
const auto it = node_item_->data_recv_.find(node_item);
const auto it = node_item_->data_recv_.find(node_state.node_item_);
if (it != node_item_->data_recv_.end()) { if (it != node_item_->data_recv_.end()) {
merge_index_ = it->second; merge_index_ = it->second;
(void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); (void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second);
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second);
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_state.GetName().c_str(), GetName().c_str(), it->second);
} else { } else {
GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str());
GELOGW("[%s] scheduled, [%s] not followed", node_state.GetName().c_str(), GetName().c_str());
} }
} }


@@ -397,12 +392,15 @@ void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<v
} }
} }


void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u",
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), ctrl_scheduled_);
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) {
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u",
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_,
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_);


std::lock_guard<std::mutex> lk(mu_); std::lock_guard<std::mutex> lk(mu_);
if (loop_count_ != node_state.loop_count_) {
ResetContext(node_state.loop_count_);
}
++ctrl_scheduled_; ++ctrl_scheduled_;


if (IsScheduleReady()) { if (IsScheduleReady()) {
@@ -410,6 +408,21 @@ void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<v
} }
} }


void NodeState::RunLoopNext() {
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_);
std::lock_guard<std::mutex> lk(mu_);
++loop_count_;
if (loop_count_ == UINT64_MAX) {
loop_count_ = 1;
}
}

void NodeState::RunLoopExit() {
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_);
std::lock_guard<std::mutex> lk(mu_);
loop_count_ = 0;
}

void NodeState::SetScheduleFuture(std::future<Status> &&future) { void NodeState::SetScheduleFuture(std::future<Status> &&future) {
schedule_future_ = std::move(future); schedule_future_ = std::move(future);
} }


+ 5
- 5
ge/hybrid/executor/node_state.h View File

@@ -112,9 +112,8 @@ struct NodeState {
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE;
} }


void ResetContext(int group);

void ResetSchedule();
void RunLoopNext();
void RunLoopExit();


Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;


@@ -166,8 +165,9 @@ struct NodeState {


private: private:
bool IsScheduleReady() const; bool IsScheduleReady() const;
void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready);
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t loop_count);


const NodeItem *node_item_ = nullptr; const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr; std::shared_ptr<NodeTask> kernel_task_ = nullptr;


+ 5
- 0
ge/hybrid/executor/subgraph_context.cc View File

@@ -46,6 +46,10 @@ Status SubgraphContext::Init() {
return SUCCESS; return SUCCESS;
} }


void SubgraphContext::SetGroup(int group) {
group_ = group;
}

void SubgraphContext::ResetContext(const NodePtr &node) { void SubgraphContext::ResetContext(const NodePtr &node) {
node_done_manager_.Reset(node); node_done_manager_.Reset(node);
} }
@@ -85,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) {
if (node_state == nullptr) { if (node_state == nullptr) {
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); const auto &guard = node_item->MutexGuard("GetOrCreateNodeState");
node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this)));
node_state->SetGroup(group_);
(void)guard; (void)guard;
} }
GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); GELOGD("[%s] unlock for write", node_item->NodeName().c_str());


+ 2
- 0
ge/hybrid/executor/subgraph_context.h View File

@@ -34,6 +34,7 @@ class SubgraphContext {
~SubgraphContext(); ~SubgraphContext();


Status Init(); Status Init();
void SetGroup(int group);
void ResetContext(const NodePtr &node); void ResetContext(const NodePtr &node);
void Reset(); void Reset();
NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item);
@@ -58,6 +59,7 @@ class SubgraphContext {
std::vector<TensorValue> all_outputs_; std::vector<TensorValue> all_outputs_;
NodeDoneManager node_done_manager_; NodeDoneManager node_done_manager_;
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; std::unordered_map<const NodeItem *, NodeStatePtr> node_states_;
int group_ = -1;
}; };
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge


+ 0
- 2
ge/hybrid/executor/subgraph_executor.cc View File

@@ -242,7 +242,6 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) {


auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item);
GE_CHECK_NOTNULL(node_state); GE_CHECK_NOTNULL(node_state);
node_state->ResetContext(group);
auto p_node_state = node_state.get(); auto p_node_state = node_state.get();


if (node_item.node_type == NETOUTPUT) { if (node_item.node_type == NETOUTPUT) {
@@ -367,7 +366,6 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) {
}; };


GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback));
node_state->ResetSchedule();
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End");
return SUCCESS; return SUCCESS;
}); });


+ 89
- 42
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -21,6 +21,7 @@
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "graph/build/memory/var_mem_assign_util.h" #include "graph/build/memory/var_mem_assign_util.h"
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/common/omg_util.h"
#include "graph/load/model_manager/model_utils.h" #include "graph/load/model_manager/model_utils.h"
#include "graph/load/model_manager/model_manager.h" #include "graph/load/model_manager/model_manager.h"
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
@@ -43,8 +44,9 @@ const uint64_t kProfilingBpEndLogid = 2U;
const uint64_t kProfilingIterEndLogid = 65535U; const uint64_t kProfilingIterEndLogid = 65535U;
const int kBytes = 8; const int kBytes = 8;
const int kDecimal = 10; const int kDecimal = 10;
const uint8_t kStreamActiveIdx = 0;
const uint8_t kStreamActiveNum = 1;
const uint8_t kLoopEnterIdx = 0;
const uint8_t kLoopIterationIdx = 1;
const uint8_t kLoopMergeSize = 2;
const uint8_t kStreamSwitchIdx = 1; const uint8_t kStreamSwitchIdx = 1;
const uint8_t kStreamSwitchNum = 2; const uint8_t kStreamSwitchNum = 2;
const uint32_t kStringHeadElems = 2; const uint32_t kStringHeadElems = 2;
@@ -57,6 +59,10 @@ const char *const kProfilingArNode = "ProfilingAllReduceNode";
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";
const char *const kForceInfershape = "_force_infershape_when_running"; const char *const kForceInfershape = "_force_infershape_when_running";


const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH };
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP };
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION };

Status SetOutputNameAttr(ComputeGraph &graph) { Status SetOutputNameAttr(ComputeGraph &graph) {
vector<string> output_names; vector<string> output_names;
for (const auto &node : graph.GetDirectNode()) { for (const auto &node : graph.GetDirectNode()) {
@@ -389,7 +395,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
} }


// cond or branch need to be prepared before the execution of IF or CASE // cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
if (kExecutionDependentTypes.count(node_item.node_type) > 0) {
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input
GE_CHECK_NOTNULL(src_node); GE_CHECK_NOTNULL(src_node);
auto src_node_item = MutableNodeItem(src_node); auto src_node_item = MutableNodeItem(src_node);
@@ -575,7 +581,7 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) {
auto in_nodes = root_node->GetInAllNodes(); auto in_nodes = root_node->GetInAllNodes();
std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end());
for (auto &in_control_node : wrapped_node->GetInControlNodes()) { for (auto &in_control_node : wrapped_node->GetInControlNodes()) {
if (in_node_set.count(in_control_node) == 0) {
if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) {
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str());
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor());
(void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor());
@@ -2282,8 +2288,6 @@ Status HybridModelBuilder::RelinkNextIteration() {
} }
} }


stream_merge_op_nodes_.clear();
next_iteration_op_nodes_.clear();
return SUCCESS; return SUCCESS;
} }


@@ -2371,10 +2375,12 @@ Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const No
} }


Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) {
const auto out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node);
if ((dst_node->GetType() == STREAMACTIVE) && (kStreamActiveTypes.count(node->GetType()) == 0)) {
GELOGI("[%s] ignore control to [%s]", node->GetName().c_str(), dst_node->GetName().c_str());
continue;
}


NodeItem *dst_node_item = nullptr; NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
@@ -2384,27 +2390,80 @@ Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *
return SUCCESS; return SUCCESS;
} }


Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item) {
// Enter --> StreamActive --> StreamMerge
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node);
NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
// Set Enter Control to StreamMerge as Group 0.
dst_node_item->switch_groups_.resize(kLoopMergeSize);
dst_node_item->SetMergeCtrl(node_item, kLoopEnterIdx);
}
return SUCCESS;
}

Status HybridModelBuilder::CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item) {
// NextIteration --> StreamActive {-->} StreamMerge
std::string node_name;
for (const auto &src_node : node->GetInControlNodes()) {
GE_CHECK_NOTNULL(src_node);
if (kNextIterationOpTypes.count(src_node->GetType()) == 0) {
GELOGI("[%s] Skip Not NextIteration node [%s]", node->GetName().c_str(), src_node->GetName().c_str());
continue;
}

if (!AttrUtils::GetStr(src_node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, node_name)) {
GELOGE(INTERNAL_ERROR, "[%s] input node [%s] expect attribute[%s] not found",
node->GetName().c_str(), src_node->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str());
return INTERNAL_ERROR;
}

const auto it = stream_merge_op_nodes_.find(node_name);
if (it == stream_merge_op_nodes_.end()) {
GELOGE(INTERNAL_ERROR, "[%s] expect StreamMerge[%s] not found", node->GetName().c_str(), node_name.c_str());
return INTERNAL_ERROR;
}

const auto &dst_node = it->second;
GE_CHECK_NOTNULL(dst_node);
NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), "[%s] failed to get or create node item",
dst_node->GetName().c_str());
// Set NextIteration Control to StreamMerge as Group 1.
dst_node_item->SetMergeCtrl(node_item, kLoopIterationIdx);
}
return SUCCESS;
}

Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) {
if (node_item->node_type != STREAMACTIVE) { if (node_item->node_type != STREAMACTIVE) {
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


node_item->switch_groups_.resize(kStreamActiveNum);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(dst_node);
if (dst_node->GetType() == STREAMMERGE) {
GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str());
continue;
}
const auto ctrl_nodes = node->GetInControlNodes();
if (ctrl_nodes.empty()) {
GELOGW("Skip no in control node: %s", node->GetName().c_str());
return SUCCESS;
}


NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str());
node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx);
const auto IsEnterNode = [](const NodePtr &n) {
return kEnterOpTypes.count(n->GetType()) > 0;
};
const auto IsIterationNode = [](const NodePtr &n) {
return kNextIterationOpTypes.count(n->GetType()) > 0;
};

if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) {
// Enter --> StreamActive --> StreamMerge
return CreateMergeEnterGroup(node, node_item);
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) {
// NextIteration --> StreamActive {-->} StreamMerge
return CreateMergeIterationGroup(node, node_item);
} }

return SUCCESS; return SUCCESS;
} }


@@ -2416,11 +2475,8 @@ Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem


// Consider as two groups, group[0] set empty for false, group[1] for true. // Consider as two groups, group[0] set empty for false, group[1] for true.
node_item->switch_groups_.resize(kStreamSwitchNum); node_item->switch_groups_.resize(kStreamSwitchNum);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr; NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str()); "[%s] failed to get or create node item", dst_node->GetName().c_str());
@@ -2447,20 +2503,17 @@ Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeIte
} }


node_item->switch_groups_.resize(batch_num); node_item->switch_groups_.resize(batch_num);
const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node);

std::string batch_label; std::string batch_label;
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str());
if (!AttrUtils::GetStr(dst_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) {
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", dst_node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


std::string::size_type pos = batch_label.rfind("_"); std::string::size_type pos = batch_label.rfind("_");
if (pos == std::string::npos) { if (pos == std::string::npos) {
GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str());
GELOGW("[%s] Separator not found in batch label: %s.", dst_node->GetName().c_str(), batch_label.c_str());
continue; continue;
} }


@@ -2486,7 +2539,7 @@ Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeIte
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


return SUCCESS;
return CreateNormalNodeGroup(node, node_item);
} }


Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) {
@@ -2495,11 +2548,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }


const auto &out_ctrl_anchor = node->GetOutControlAnchor();
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
for (const auto &dst_node : node->GetOutControlNodes()) {
GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr; NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str()); "[%s] failed to get or create node item", dst_node->GetName().c_str());
@@ -2509,11 +2559,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node
// Group switch flow by out put data. // Group switch flow by out put data.
node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM);
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
const auto &out_anchor = node->GetOutDataAnchor(i);
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
const auto &dst_node = peer_in_anchor->GetOwnerNode();
for (const auto &dst_node : node->GetOutDataNodes()) {
GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node);

NodeItem *dst_node_item = nullptr; NodeItem *dst_node_item = nullptr;
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item),
"[%s] failed to get or create node item", dst_node->GetName().c_str()); "[%s] failed to get or create node item", dst_node->GetName().c_str());


+ 2
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -99,6 +99,8 @@ class HybridModelBuilder {
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes);
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item);
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item);
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item);
Status CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item);
Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item);


+ 20
- 2
ge/hybrid/model/node_item.cc View File

@@ -34,8 +34,8 @@ const std::set<std::string> kControlOpTypes{
}; };


const std::set<std::string> kControlFlowOpTypes{ const std::set<std::string> kControlFlowOpTypes{
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX,
NEXTITERATION, REFNEXTITERATION
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT,
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX
}; };


const std::set<std::string> kMergeOpTypes{ const std::set<std::string> kMergeOpTypes{
@@ -420,6 +420,24 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
} }


void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
if (merge_index >= switch_groups_.size()) {
GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index);
return;
}

// this is StreamMerge node, node_item is StreamActive node.
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index];
switch_group.emplace_back(node_item);

node_item->ctrl_send_.emplace(this);
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());
}

size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const {
return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0;
}

OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {
if (mu_ != nullptr) { if (mu_ != nullptr) {
GELOGD("lock for %s", name_.c_str()); GELOGD("lock for %s", name_.c_str());


+ 2
- 0
ge/hybrid/model/node_item.h View File

@@ -98,6 +98,8 @@ struct NodeItem {


void SetDataSend(NodeItem *node_item, int anchor_index); void SetDataSend(NodeItem *node_item, int anchor_index);
void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); void SetCtrlSend(NodeItem *node_item, uint32_t switch_index);
void SetMergeCtrl(NodeItem *node_item, uint32_t merge_index);
size_t GetMergeCtrl(uint32_t merge_index) const;


OptionalMutexGuard MutexGuard(const std::string &name) const { OptionalMutexGuard MutexGuard(const std::string &name) const {
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name);


+ 8
- 0
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -20,6 +20,7 @@
#include "graph/debug/ge_attr_define.h" #include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h" #include "common/op/ge_op_utils.h"


@@ -201,6 +202,13 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio
GE_CHECK_NOTNULL(in_x); GE_CHECK_NOTNULL(in_x);
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y


const auto &node_state = task_context.GetNodeState();
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) {
node_state->RunLoopNext();
} else if (kExitOpTypes.count(node_state->GetType()) > 0) {
node_state->RunLoopExit();
}

if (done_callback) { if (done_callback) {
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
} }


Loading…
Cancel
Save