@@ -193,23 +193,29 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node) { | |||||
/// | /// | ||||
/// @brief set op next_iteration name | /// @brief set op next_iteration name | ||||
/// @param [in] node | |||||
/// @param [in] next | |||||
/// @param [in] Merge Node | |||||
/// @param [in] NextIteration Node | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status SetNextIteration(const ge::NodePtr &node, const std::string &next) { | |||||
Status SetNextIteration(const NodePtr &node, const NodePtr &next) { | |||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
OpDescPtr tmp_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(tmp_desc); | |||||
GE_CHECK_NOTNULL(next); | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
GE_CHECK_NOTNULL(next->GetOpDesc()); | |||||
if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) { | |||||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||||
node->GetName().c_str(), node->GetType().c_str()); | |||||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||||
node->GetName().c_str(), node->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
const auto SetIterationName = [](const OpDescPtr &op_desc, const std::string &name) { | |||||
if (!AttrUtils::SetStr(op_desc, ATTR_NAME_NEXT_ITERATION, name)) { | |||||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(), | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
}; | |||||
GE_CHK_STATUS_RET_NOLOG(SetIterationName(node->GetOpDesc(), next->GetName())); | |||||
GE_CHK_STATUS_RET_NOLOG(SetIterationName(next->GetOpDesc(), node->GetName())); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -96,11 +96,11 @@ Status SetCyclicDependenceFlag(const ge::NodePtr &node); | |||||
/// | /// | ||||
/// @brief set op next_iteration name | /// @brief set op next_iteration name | ||||
/// @param [in] node | |||||
/// @param [in] next | |||||
/// @param [in] Merge Node | |||||
/// @param [in] NextIteration Node | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status SetNextIteration(const ge::NodePtr &node, const std::string &next); | |||||
Status SetNextIteration(const NodePtr &node, const NodePtr &next); | |||||
/// | /// | ||||
/// @brief Align the memory | /// @brief Align the memory | ||||
@@ -387,6 +387,9 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() { | |||||
if (!in_cluster->IsUnknownShape()) { | if (!in_cluster->IsUnknownShape()) { | ||||
continue; | continue; | ||||
} | } | ||||
if (!cluster->IsAdjoinNodes(in_cluster)) { | |||||
continue; | |||||
} | |||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | ||||
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | ||||
ToString(merged_clusters).c_str()); | ToString(merged_clusters).c_str()); | ||||
@@ -80,6 +80,10 @@ class DynamicShapePartitioner { | |||||
Status BuildPartitionSubgraph(); | Status BuildPartitionSubgraph(); | ||||
// Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
void Clear(); | void Clear(); | ||||
bool IsAdjoinNodes(const std::shared_ptr<Cluster> &other) const { | |||||
const auto &out_clusters = other->out_clusters_; | |||||
return std::find(out_clusters.begin(), out_clusters.end(), shared_from_this()) != out_clusters.end(); | |||||
} | |||||
private: | private: | ||||
static thread_local size_t unique_id_; | static thread_local size_t unique_id_; | ||||
@@ -354,7 +354,7 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & | |||||
merge_node->GetName().c_str()); | merge_node->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { | |||||
if (SetNextIteration(merge_node, next_node) != SUCCESS) { | |||||
REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", | REPORT_CALL_ERROR("E19999", "Set attr NEXT_ITERATION value:%s to node:%s(%s) failed", | ||||
next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); | next_node->GetName().c_str(), merge_node->GetName().c_str(), merge_node->GetType().c_str()); | ||||
GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); | GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); | ||||
@@ -306,28 +306,15 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
return task_context_; | return task_context_; | ||||
} | } | ||||
void NodeState::ResetContext(int group) { | |||||
SetGroup(group); | |||||
if (loop_count_ == 0) { | |||||
++loop_count_; | |||||
return; | |||||
} | |||||
++loop_count_; | |||||
if (loop_count_ == UINT64_MAX) { | |||||
loop_count_ = 1; | |||||
} | |||||
void NodeState::ResetContext(uint64_t loop_count) { | |||||
loop_count_ = loop_count; | |||||
switch_index_ = -1; | switch_index_ = -1; | ||||
subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | |||||
} | |||||
void NodeState::ResetSchedule() { | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
GELOGD("[%s] set schedule for root nodes, data: %u, ctrl: %u", GetName().c_str(), data_scheduled_, ctrl_scheduled_); | |||||
GELOGD("[%s] in while loop, loop count: %lu, data scheduled: %u, ctrl scheduled: %u, merge index: %d", | |||||
GetName().c_str(), loop_count_, data_scheduled_, ctrl_scheduled_, merge_index_); | |||||
} | } | ||||
Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &ready) const { | ||||
@@ -335,14 +322,14 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
for (const auto &node : node_item_->data_send_) { | for (const auto &node : node_item_->data_send_) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetDataSchedule(node_item_, ready); | |||||
dst_node_state->SetDataSchedule(*this, ready); | |||||
} | } | ||||
// Schedule ctrl output. | // Schedule ctrl output. | ||||
for (const auto &node : node_item_->ctrl_send_) { | for (const auto &node : node_item_->ctrl_send_) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
dst_node_state->SetCtrlSchedule(*this, ready); | |||||
} | } | ||||
// Schedule switch group. | // Schedule switch group. | ||||
@@ -351,7 +338,7 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
for (const auto &node : node_item_->switch_groups_[switch_index_]) { | for (const auto &node : node_item_->switch_groups_[switch_index_]) { | ||||
const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | const auto &dst_node_state = subgraph_context_->GetOrCreateNodeState(node); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SetCtrlSchedule(node_item_, ready); | |||||
dst_node_state->SetCtrlSchedule(*this, ready); | |||||
} | } | ||||
} | } | ||||
@@ -359,36 +346,44 @@ Status NodeState::NodeScheduled(const std::function<void(const NodeItem *)> &rea | |||||
} | } | ||||
bool NodeState::IsScheduleReady() const { | bool NodeState::IsScheduleReady() const { | ||||
GELOGD("[%s] data[input: %zu, scheduled: %u], ctrl[input: %zu, scheduled: %u]", GetName().c_str(), | |||||
node_item_->data_recv_.size(), data_scheduled_, node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
GELOGD("[%s] loop[%lu] data[input: %zu, scheduled: %u], ctrl[input: %zu+%zu, scheduled: %u]", | |||||
GetName().c_str(), loop_count_, node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
if (ctrl_scheduled_ != node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1) + node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
return data_scheduled_ > 0; | return data_scheduled_ > 0; | ||||
} | } | ||||
if (ctrl_scheduled_ != node_item_->ctrl_recv_.size()) { | |||||
return false; | |||||
} | |||||
// Exit may feed loop times... | // Exit may feed loop times... | ||||
return data_scheduled_ >= node_item_->data_recv_.size(); | return data_scheduled_ >= node_item_->data_recv_.size(); | ||||
} | } | ||||
void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
void NodeState::SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] data schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
++data_scheduled_; | ++data_scheduled_; | ||||
if (node_item_->IsMergeOp()) { | if (node_item_->IsMergeOp()) { | ||||
const auto it = node_item_->data_recv_.find(node_item); | |||||
const auto it = node_item_->data_recv_.find(node_state.node_item_); | |||||
if (it != node_item_->data_recv_.end()) { | if (it != node_item_->data_recv_.end()) { | ||||
merge_index_ = it->second; | merge_index_ = it->second; | ||||
(void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | (void)AttrUtils::SetInt(node_item_->node->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, it->second); | ||||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_item->node_name.c_str(), GetName().c_str(), it->second); | |||||
GELOGD("[%s] scheduled, [%s] set merge index: %d", node_state.GetName().c_str(), GetName().c_str(), it->second); | |||||
} else { | } else { | ||||
GELOGW("[%s] scheduled, [%s] not followed", node_item->node_name.c_str(), GetName().c_str()); | |||||
GELOGW("[%s] scheduled, [%s] not followed", node_state.GetName().c_str(), GetName().c_str()); | |||||
} | } | ||||
} | } | ||||
@@ -397,12 +392,15 @@ void NodeState::SetDataSchedule(const NodeItem *node_item, const std::function<v | |||||
} | } | ||||
} | } | ||||
void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu, current scheduled: %u", | |||||
node_item->node_name.c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), ctrl_scheduled_); | |||||
void NodeState::SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready) { | |||||
GELOGD("[%s] ctrl schedule node[%s], data num: %zu, current scheduled: %u, ctrl num: %zu+%zu, current scheduled: %u", | |||||
node_state.GetName().c_str(), GetName().c_str(), node_item_->data_recv_.size(), data_scheduled_, | |||||
node_item_->ctrl_recv_.size(), node_item_->GetMergeCtrl(loop_count_ == 0 ? 0 : 1), ctrl_scheduled_); | |||||
std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
if (loop_count_ != node_state.loop_count_) { | |||||
ResetContext(node_state.loop_count_); | |||||
} | |||||
++ctrl_scheduled_; | ++ctrl_scheduled_; | ||||
if (IsScheduleReady()) { | if (IsScheduleReady()) { | ||||
@@ -410,6 +408,21 @@ void NodeState::SetCtrlSchedule(const NodeItem *node_item, const std::function<v | |||||
} | } | ||||
} | } | ||||
void NodeState::RunLoopNext() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
++loop_count_; | |||||
if (loop_count_ == UINT64_MAX) { | |||||
loop_count_ = 1; | |||||
} | |||||
} | |||||
void NodeState::RunLoopExit() { | |||||
GELOGD("Node[%s] run in loop, current count: %lu", GetName().c_str(), loop_count_); | |||||
std::lock_guard<std::mutex> lk(mu_); | |||||
loop_count_ = 0; | |||||
} | |||||
void NodeState::SetScheduleFuture(std::future<Status> &&future) { | void NodeState::SetScheduleFuture(std::future<Status> &&future) { | ||||
schedule_future_ = std::move(future); | schedule_future_ = std::move(future); | ||||
} | } | ||||
@@ -112,9 +112,8 @@ struct NodeState { | |||||
return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | return node_item_->IsControlFlowOp() || node_item_->shape_inference_type >= DEPEND_SHAPE_RANGE; | ||||
} | } | ||||
void ResetContext(int group); | |||||
void ResetSchedule(); | |||||
void RunLoopNext(); | |||||
void RunLoopExit(); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -166,8 +165,9 @@ struct NodeState { | |||||
private: | private: | ||||
bool IsScheduleReady() const; | bool IsScheduleReady() const; | ||||
void SetDataSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
void SetCtrlSchedule(const NodeItem *node_item, const std::function<void(const NodeItem *)> &ready); | |||||
void SetDataSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||||
void ResetContext(uint64_t loop_count); | |||||
const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
@@ -46,6 +46,10 @@ Status SubgraphContext::Init() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void SubgraphContext::SetGroup(int group) { | |||||
group_ = group; | |||||
} | |||||
void SubgraphContext::ResetContext(const NodePtr &node) { | void SubgraphContext::ResetContext(const NodePtr &node) { | ||||
node_done_manager_.Reset(node); | node_done_manager_.Reset(node); | ||||
} | } | ||||
@@ -85,6 +89,7 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
if (node_state == nullptr) { | if (node_state == nullptr) { | ||||
const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | ||||
node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); | node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); | ||||
node_state->SetGroup(group_); | |||||
(void)guard; | (void)guard; | ||||
} | } | ||||
GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | ||||
@@ -34,6 +34,7 @@ class SubgraphContext { | |||||
~SubgraphContext(); | ~SubgraphContext(); | ||||
Status Init(); | Status Init(); | ||||
void SetGroup(int group); | |||||
void ResetContext(const NodePtr &node); | void ResetContext(const NodePtr &node); | ||||
void Reset(); | void Reset(); | ||||
NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | ||||
@@ -58,6 +59,7 @@ class SubgraphContext { | |||||
std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
NodeDoneManager node_done_manager_; | NodeDoneManager node_done_manager_; | ||||
std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | std::unordered_map<const NodeItem *, NodeStatePtr> node_states_; | ||||
int group_ = -1; | |||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge | ||||
@@ -242,7 +242,6 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); | ||||
GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
node_state->ResetContext(group); | |||||
auto p_node_state = node_state.get(); | auto p_node_state = node_state.get(); | ||||
if (node_item.node_type == NETOUTPUT) { | if (node_item.node_type == NETOUTPUT) { | ||||
@@ -367,7 +366,6 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { | |||||
}; | }; | ||||
GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | GE_CHK_STATUS_RET_NOLOG(node_state->NodeScheduled(callback)); | ||||
node_state->ResetSchedule(); | |||||
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] End"); | ||||
return SUCCESS; | return SUCCESS; | ||||
}); | }); | ||||
@@ -539,6 +537,7 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
Status SubgraphExecutor::ScheduleTasks(int group) { | Status SubgraphExecutor::ScheduleTasks(int group) { | ||||
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | ||||
subgraph_context_->SetGroup(group); | |||||
auto prepare_future = std::async(std::launch::async, [&]() -> Status { | auto prepare_future = std::async(std::launch::async, [&]() -> Status { | ||||
GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
GetContext().SetContextId(context_->context_id); | GetContext().SetContextId(context_->context_id); | ||||
@@ -21,6 +21,7 @@ | |||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/common/omg_util.h" | |||||
#include "graph/load/model_manager/model_utils.h" | #include "graph/load/model_manager/model_utils.h" | ||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
@@ -43,8 +44,9 @@ const uint64_t kProfilingBpEndLogid = 2U; | |||||
const uint64_t kProfilingIterEndLogid = 65535U; | const uint64_t kProfilingIterEndLogid = 65535U; | ||||
const int kBytes = 8; | const int kBytes = 8; | ||||
const int kDecimal = 10; | const int kDecimal = 10; | ||||
const uint8_t kStreamActiveIdx = 0; | |||||
const uint8_t kStreamActiveNum = 1; | |||||
const uint8_t kLoopEnterIdx = 0; | |||||
const uint8_t kLoopIterationIdx = 1; | |||||
const uint8_t kLoopMergeSize = 2; | |||||
const uint8_t kStreamSwitchIdx = 1; | const uint8_t kStreamSwitchIdx = 1; | ||||
const uint8_t kStreamSwitchNum = 2; | const uint8_t kStreamSwitchNum = 2; | ||||
const uint32_t kStringHeadElems = 2; | const uint32_t kStringHeadElems = 2; | ||||
@@ -57,6 +59,10 @@ const char *const kProfilingArNode = "ProfilingAllReduceNode"; | |||||
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | ||||
const char *const kForceInfershape = "_force_infershape_when_running"; | const char *const kForceInfershape = "_force_infershape_when_running"; | ||||
const std::set<std::string> kExecutionDependentTypes{ IF, STATELESSIF, CASE, STREAMSWITCH }; | |||||
const std::set<std::string> kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; | |||||
const std::set<std::string> kStreamActiveTypes{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
Status SetOutputNameAttr(ComputeGraph &graph) { | Status SetOutputNameAttr(ComputeGraph &graph) { | ||||
vector<string> output_names; | vector<string> output_names; | ||||
for (const auto &node : graph.GetDirectNode()) { | for (const auto &node : graph.GetDirectNode()) { | ||||
@@ -389,7 +395,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
} | } | ||||
// cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||||
if (kExecutionDependentTypes.count(node_item.node_type) > 0) { | |||||
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | ||||
GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
auto src_node_item = MutableNodeItem(src_node); | auto src_node_item = MutableNodeItem(src_node); | ||||
@@ -575,7 +581,7 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { | |||||
auto in_nodes = root_node->GetInAllNodes(); | auto in_nodes = root_node->GetInAllNodes(); | ||||
std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); | ||||
for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | for (auto &in_control_node : wrapped_node->GetInControlNodes()) { | ||||
if (in_node_set.count(in_control_node) == 0) { | |||||
if (in_node_set.count(in_control_node) == 0 && kMergeInputSkipTypes.count(root_node->GetType()) == 0) { | |||||
GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); | ||||
GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); | ||||
(void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | (void) in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); | ||||
@@ -2282,8 +2288,6 @@ Status HybridModelBuilder::RelinkNextIteration() { | |||||
} | } | ||||
} | } | ||||
stream_merge_op_nodes_.clear(); | |||||
next_iteration_op_nodes_.clear(); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2371,10 +2375,12 @@ Status HybridModelBuilder::BuildControlFlowGroup(GraphItem &graph_item, const No | |||||
} | } | ||||
Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item) { | ||||
const auto out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
if ((dst_node->GetType() == STREAMACTIVE) && (kStreamActiveTypes.count(node->GetType()) == 0)) { | |||||
GELOGI("[%s] ignore control to [%s]", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
@@ -2384,27 +2390,80 @@ Status HybridModelBuilder::CreateNormalNodeGroup(const NodePtr &node, NodeItem * | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
/// Registers `node_item` as the group-0 (kLoopEnterIdx) merge control of every
/// node reached through an outgoing control edge of `node`.
/// Intended topology per the original comment: Enter --> StreamActive --> StreamMerge.
/// @param [in] node       node whose out-control nodes are visited
/// @param [in] node_item  item passed to SetMergeCtrl on each destination item
/// @return SUCCESS, or the failure status of a null check / GetOrCreateNodeItem
Status HybridModelBuilder::CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item) {
  const auto out_ctrl_nodes = node->GetOutControlNodes();
  for (const auto &dst_node : out_ctrl_nodes) {
    GE_CHECK_NOTNULL(dst_node);

    NodeItem *target_item = nullptr;
    GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &target_item),
                      "[%s] failed to get or create node item", dst_node->GetName().c_str());

    // The destination gets kLoopMergeSize groups; the Enter control goes into group 0.
    target_item->switch_groups_.resize(kLoopMergeSize);
    target_item->SetMergeCtrl(node_item, kLoopEnterIdx);
  }

  return SUCCESS;
}
/// Registers `node_item` as the group-1 (kLoopIterationIdx) merge control of the
/// node named by each NextIteration input's ATTR_NAME_NEXT_ITERATION attribute.
/// Intended topology per the original comment: NextIteration --> StreamActive {-->} StreamMerge.
/// @param [in] node       node whose in-control nodes are scanned
/// @param [in] node_item  item passed to SetMergeCtrl on each resolved destination item
/// @return SUCCESS; INTERNAL_ERROR when the attribute is missing or the named node
///         is not in stream_merge_op_nodes_; otherwise the failing check's status
Status HybridModelBuilder::CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item) {
  const auto in_ctrl_nodes = node->GetInControlNodes();
  for (const auto &src_node : in_ctrl_nodes) {
    GE_CHECK_NOTNULL(src_node);

    // Only in-control nodes whose type is in kNextIterationOpTypes are processed.
    const bool is_next_iteration = kNextIterationOpTypes.count(src_node->GetType()) != 0;
    if (!is_next_iteration) {
      GELOGI("[%s] Skip Not NextIteration node [%s]", node->GetName().c_str(), src_node->GetName().c_str());
      continue;
    }

    // The attribute value is used as the key into stream_merge_op_nodes_.
    std::string merge_name;
    const bool has_attr = AttrUtils::GetStr(src_node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, merge_name);
    if (!has_attr) {
      GELOGE(INTERNAL_ERROR, "[%s] input node [%s] expect attribute[%s] not found",
             node->GetName().c_str(), src_node->GetName().c_str(), ATTR_NAME_NEXT_ITERATION.c_str());
      return INTERNAL_ERROR;
    }

    const auto found = stream_merge_op_nodes_.find(merge_name);
    if (found == stream_merge_op_nodes_.end()) {
      GELOGE(INTERNAL_ERROR, "[%s] expect StreamMerge[%s] not found", node->GetName().c_str(), merge_name.c_str());
      return INTERNAL_ERROR;
    }

    const auto &dst_node = found->second;
    GE_CHECK_NOTNULL(dst_node);

    NodeItem *target_item = nullptr;
    GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &target_item), "[%s] failed to get or create node item",
                      dst_node->GetName().c_str());

    // The NextIteration control goes into group 1 of the destination item.
    target_item->SetMergeCtrl(node_item, kLoopIterationIdx);
  }

  return SUCCESS;
}
Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item) { | ||||
if (node_item->node_type != STREAMACTIVE) { | if (node_item->node_type != STREAMACTIVE) { | ||||
GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | GELOGE(INTERNAL_ERROR, "Called by %s is invalid", node_item->node_type.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
node_item->switch_groups_.resize(kStreamActiveNum); | |||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(dst_node); | |||||
if (dst_node->GetType() == STREAMMERGE) { | |||||
GELOGI("[%s] skip control node: %s", node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
const auto ctrl_nodes = node->GetInControlNodes(); | |||||
if (ctrl_nodes.empty()) { | |||||
GELOGW("Skip no in control node: %s", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
NodeItem *dst_node_item = nullptr; | |||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | |||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | |||||
node_item->SetCtrlSend(dst_node_item, kStreamActiveIdx); | |||||
const auto IsEnterNode = [](const NodePtr &n) { | |||||
return kEnterOpTypes.count(n->GetType()) > 0; | |||||
}; | |||||
const auto IsIterationNode = [](const NodePtr &n) { | |||||
return kNextIterationOpTypes.count(n->GetType()) > 0; | |||||
}; | |||||
if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsEnterNode)) { | |||||
// Enter --> StreamActive --> StreamMerge | |||||
return CreateMergeEnterGroup(node, node_item); | |||||
} else if (std::any_of(ctrl_nodes.begin(), ctrl_nodes.end(), IsIterationNode)) { | |||||
// NextIteration --> StreamActive {-->} StreamMerge | |||||
return CreateMergeIterationGroup(node, node_item); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2416,11 +2475,8 @@ Status HybridModelBuilder::CreateStreamSwitchGroup(const NodePtr &node, NodeItem | |||||
// Consider as two groups, group[0] set empty for false, group[1] for true. | // Consider as two groups, group[0] set empty for false, group[1] for true. | ||||
node_item->switch_groups_.resize(kStreamSwitchNum); | node_item->switch_groups_.resize(kStreamSwitchNum); | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -2447,20 +2503,17 @@ Status HybridModelBuilder::CreateStreamSwitchNGroup(const NodePtr &node, NodeIte | |||||
} | } | ||||
node_item->switch_groups_.resize(batch_num); | node_item->switch_groups_.resize(batch_num); | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
std::string batch_label; | std::string batch_label; | ||||
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", node->GetName().c_str()); | |||||
if (!AttrUtils::GetStr(dst_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] Get ATTR_NAME_BATCH_LABEL failed", dst_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
std::string::size_type pos = batch_label.rfind("_"); | std::string::size_type pos = batch_label.rfind("_"); | ||||
if (pos == std::string::npos) { | if (pos == std::string::npos) { | ||||
GELOGW("[%s] Separator not found in batch label: %s.", node->GetName().c_str(), batch_label.c_str()); | |||||
GELOGW("[%s] Separator not found in batch label: %s.", dst_node->GetName().c_str(), batch_label.c_str()); | |||||
continue; | continue; | ||||
} | } | ||||
@@ -2486,7 +2539,7 @@ Status HybridModelBuilder::CreateNextIterationGroup(const NodePtr &node, NodeIte | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
return SUCCESS; | |||||
return CreateNormalNodeGroup(node, node_item); | |||||
} | } | ||||
Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node_item) { | ||||
@@ -2495,11 +2548,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
const auto &out_ctrl_anchor = node->GetOutControlAnchor(); | |||||
for (const auto &peer_in_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutControlNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -2509,11 +2559,8 @@ Status HybridModelBuilder::CreateSwitchGroup(const NodePtr &node, NodeItem *node | |||||
// Group switch flow by output data. | // Group switch flow by output data. | ||||
node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | node_item->switch_groups_.resize(SWITCH_OUTPUT_NUM); | ||||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
const auto &out_anchor = node->GetOutDataAnchor(i); | |||||
for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
const auto &dst_node = peer_in_anchor->GetOwnerNode(); | |||||
for (const auto &dst_node : node->GetOutDataNodes()) { | |||||
GE_CHECK_NOTNULL(dst_node); | GE_CHECK_NOTNULL(dst_node); | ||||
NodeItem *dst_node_item = nullptr; | NodeItem *dst_node_item = nullptr; | ||||
GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | GE_CHK_STATUS_RET(GetOrCreateNodeItem(dst_node, &dst_node_item), | ||||
"[%s] failed to get or create node item", dst_node->GetName().c_str()); | "[%s] failed to get or create node item", dst_node->GetName().c_str()); | ||||
@@ -99,6 +99,8 @@ class HybridModelBuilder { | |||||
Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | Status BuildProfilingControl(GraphItem &graph_item, const std::map<size_t, std::pair<uint32_t, uint32_t>> &nodes); | ||||
Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | Status BuildControlFlowGroup(GraphItem &graph_item, const NodePtr &node, NodeItem *node_item); | ||||
Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | Status CreateNormalNodeGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateMergeEnterGroup(const NodePtr &node, NodeItem *node_item); | |||||
Status CreateMergeIterationGroup(const NodePtr &node, NodeItem *node_item); | |||||
Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamActiveGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamSwitchGroup(const NodePtr &node, NodeItem *node_item); | ||||
Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | Status CreateStreamSwitchNGroup(const NodePtr &node, NodeItem *node_item); | ||||
@@ -34,8 +34,8 @@ const std::set<std::string> kControlOpTypes{ | |||||
}; | }; | ||||
const std::set<std::string> kControlFlowOpTypes{ | const std::set<std::string> kControlFlowOpTypes{ | ||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX, | |||||
NEXTITERATION, REFNEXTITERATION | |||||
STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT, | |||||
LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX | |||||
}; | }; | ||||
const std::set<std::string> kMergeOpTypes{ | const std::set<std::string> kMergeOpTypes{ | ||||
@@ -401,6 +401,11 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_data_.emplace(this); | node_item->root_data_.emplace(this); | ||||
} | } | ||||
// If an Enter node feeds a node that is not a StreamMerge, treat the Enter node as a root data input of that node. | |||||
if ((kEnterOpTypes.count(node_type) > 0) && (node_item->node_type != STREAMMERGE)) { | |||||
node_item->root_data_.emplace(this); | |||||
node_item->enter_inside_.emplace(anchor_index); | |||||
} | |||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
@@ -416,10 +421,31 @@ void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) { | |||||
if (is_root_node_) { | if (is_root_node_) { | ||||
node_item->root_ctrl_.emplace(this); | node_item->root_ctrl_.emplace(this); | ||||
} | } | ||||
// If an Enter node sends a control signal, treat the Enter node as a root control input of the target node. | |||||
if (kEnterOpTypes.count(node_type) > 0) { | |||||
node_item->root_ctrl_.emplace(this); | |||||
} | |||||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | ||||
} | } | ||||
/// | |||||
/// @brief Append node_item to the group at merge_index and record this node in node_item's ctrl_send_ set | |||||
/// @param [in] node_item: node added to switch_groups_[merge_index]; this node is inserted into its ctrl_send_ | |||||
/// @param [in] merge_index: index into switch_groups_; an out-of-range index is logged as an error and ignored | |||||
/// @return void (failure is only reported through the error log) | |||||
/// | |||||
void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) { | |||||
// Reject indices outside the groups that have been allocated; nothing is modified in that case. | |||||
if (merge_index >= switch_groups_.size()) { | |||||
GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index); | |||||
return; | |||||
} | |||||
// Here `this` is the StreamMerge node and node_item is the StreamActive node. | |||||
// NOTE(review): node_item is dereferenced below without a null check; confirm callers never pass nullptr. | |||||
std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index]; | |||||
switch_group.emplace_back(node_item); | |||||
// Register the reverse direction: node_item sends its control notification to this node. | |||||
node_item->ctrl_send_.emplace(this); | |||||
GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str()); | |||||
} | |||||
/// | |||||
/// @brief Get the number of nodes recorded in the group at merge_index | |||||
/// @param [in] merge_index: index into switch_groups_ | |||||
/// @return size of switch_groups_[merge_index], or 0 when merge_index is out of range | |||||
/// | |||||
size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const { | |||||
return (merge_index < switch_groups_.size()) ? switch_groups_[merge_index].size() : 0; | |||||
} | |||||
OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) { | ||||
if (mu_ != nullptr) { | if (mu_ != nullptr) { | ||||
GELOGD("lock for %s", name_.c_str()); | GELOGD("lock for %s", name_.c_str()); | ||||
@@ -98,6 +98,8 @@ struct NodeItem { | |||||
void SetDataSend(NodeItem *node_item, int anchor_index); | void SetDataSend(NodeItem *node_item, int anchor_index); | ||||
void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | void SetCtrlSend(NodeItem *node_item, uint32_t switch_index); | ||||
void SetMergeCtrl(NodeItem *node_item, uint32_t merge_index); | |||||
size_t GetMergeCtrl(uint32_t merge_index) const; | |||||
OptionalMutexGuard MutexGuard(const std::string &name) const { | OptionalMutexGuard MutexGuard(const std::string &name) const { | ||||
return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | return OptionalMutexGuard(copy_mu_.get(), name + "_" + node_name); | ||||
@@ -140,6 +142,7 @@ struct NodeItem { | |||||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | ||||
std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | std::set<const NodeItem *> ctrl_recv_; // Recv ctrl notify from | ||||
std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | std::vector<std::vector<const NodeItem *>> switch_groups_; // Send ctrl notify to | ||||
std::set<int> enter_inside_; // Enter feed loop inside Node, Not cross Merge. | |||||
std::shared_ptr<NodeTask> kernel_task; | std::shared_ptr<NodeTask> kernel_task; | ||||
std::unique_ptr<FusedSubgraph> fused_subgraph; | std::unique_ptr<FusedSubgraph> fused_subgraph; | ||||
@@ -20,6 +20,7 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
@@ -201,6 +202,13 @@ Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::functio | |||||
GE_CHECK_NOTNULL(in_x); | GE_CHECK_NOTNULL(in_x); | ||||
GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y | GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y | ||||
const auto &node_state = task_context.GetNodeState(); | |||||
if (kNextIterationOpTypes.count(node_state->GetType()) > 0) { | |||||
node_state->RunLoopNext(); | |||||
} else if (kExitOpTypes.count(node_state->GetType()) > 0) { | |||||
node_state->RunLoopExit(); | |||||
} | |||||
if (done_callback) { | if (done_callback) { | ||||
GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); | ||||
} | } | ||||
@@ -489,6 +489,11 @@ void TaskContext::ReleaseInputsAndOutputs() { | |||||
} | } | ||||
void TaskContext::ReleaseInput(int index) { | void TaskContext::ReleaseInput(int index) { | ||||
if (node_item_->enter_inside_.count(index) > 0) { | |||||
GELOGD("[%s] Tensor of input[%d] is enter, keep it", GetNodeName(), index); | |||||
return; | |||||
} | |||||
auto input_tensor = MutableInput(index); | auto input_tensor = MutableInput(index); | ||||
if (input_tensor != nullptr) { | if (input_tensor != nullptr) { | ||||
input_tensor->Destroy(); | input_tensor->Destroy(); | ||||
@@ -1 +1 @@ | |||||
Subproject commit 23718da69af64f8a57051ee64d5515ae1e103c70 | |||||
Subproject commit 7ef25103b99c322e77b1fa7e0d6bd7b68b4acb6b |
@@ -1 +1 @@ | |||||
Subproject commit 9bb03f21773f028b07d5a912db6f176268962c7d | |||||
Subproject commit a796624b45b01d3d216b9b0e1ac74915b8c483b9 |
@@ -86,7 +86,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
* | | * | | ||||
* Merge | * Merge | ||||
* / \. | * / \. | ||||
* / \. | |||||
* Active / \ Active | |||||
* / \. | * / \. | ||||
* Add Sub | * Add Sub | ||||
* | \ / | | * | \ / | | ||||
@@ -96,8 +96,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
* Switch Switch | * Switch Switch | ||||
* | \ / | | * | \ / | | ||||
* | \ / | | * | \ / | | ||||
* | \ / | | |||||
* | \ / | | |||||
* | Active | | |||||
* | \ / | | |||||
* | Less | | * | Less | | ||||
* | / \ | | * | / \ | | ||||
* | / \ | | * | / \ | | ||||
@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); | ||||
} | } | ||||
const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); | |||||
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. | |||||
const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); | ||||
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); | ||||
@@ -135,13 +135,14 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. | ||||
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); | ||||
const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); | |||||
const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); | |||||
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. | |||||
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. | |||||
const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); | ||||
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); | ||||
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); | ||||
const auto iteration1 = CreateNode(graph, "iteration1", NEXTITERATION, 1, 1); | |||||
const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); | ||||
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); | ||||
@@ -170,7 +171,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt | |||||
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), iteration1->GetInDataAnchor(0)); | |||||
GraphUtils::AddEdge(iteration1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | |||||
} | } | ||||
TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { | ||||
@@ -28,6 +28,7 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_local_context.h" | #include "graph/ge_local_context.h" | ||||
#include "graph/common/omg_util.h" | |||||
using namespace std; | using namespace std; | ||||
using namespace testing; | using namespace testing; | ||||
@@ -157,7 +158,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { | |||||
GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); | ||||
GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); | ||||
AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName()); | |||||
SetNextIteration(merge1, next1); | |||||
AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); | ||||
AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true); | ||||