@@ -16,8 +16,6 @@ | |||
#include "mark_force_unknown_for_cond_pass.h" | |||
#include <queue> | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/common/omg_util.h" | |||
@@ -26,17 +24,7 @@ namespace { | |||
inline bool IsMergeInLoop(const NodePtr &node) { | |||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
return kLoopMergeInputs.count(node_type) > 0; | |||
} | |||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||
std::string node_type; | |||
(void)GetOriginalType(node, node_type); | |||
return kLoopSwitchInputs.count(node_type) > 0; | |||
return kLoopMergeInputs.count(NodeUtils::GetNodeType(node)) > 0; | |||
} | |||
} | |||
@@ -44,10 +32,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
GELOGD("MarkForceUnknownForCondPass Enter"); | |||
std::map<NodePtr, std::vector<NodePtr>> switch_groups; | |||
for (const auto &node : graph->GetDirectNode()) { | |||
std::string node_type; | |||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), | |||
"[Get][OriginalType] of node in graph:%s failed.", graph->GetName().c_str()); | |||
if (kMergeOpTypes.count(node_type) == 0) { | |||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(node)) == 0) { | |||
continue; | |||
} | |||
@@ -65,6 +50,51 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
} | |||
/// | |||
/// @brief Deal with Switch node for LoopCond | |||
/// @param [in] Switch node | |||
/// @param [in] dest span | |||
/// @param [out] Search queue | |||
/// @return true: Switch In while loop / false: Not in while Loop. | |||
/// | |||
bool MarkForceUnknownForCondPass::DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, | |||
std::queue<std::pair<NodePtr, uint32_t>> search_queue) { | |||
/// LoopCond --->\. | |||
/// \. | |||
/// Enter-----------+ \. | |||
/// +--> Merge --> Switch --> Exit | |||
/// NextIteration---+ | |||
const auto is_loop_op = [](const NodePtr &n) { | |||
return NodeUtils::GetNodeType(n) == LOOPCOND; | |||
}; | |||
const auto is_exit_op = [](const NodePtr &n) { | |||
return kExitOpTypes.count(NodeUtils::GetNodeType(n)) > 0; | |||
}; | |||
const auto src_nodes = node->GetInAllNodes(); | |||
const auto dst_nodes = node->GetOutAllNodes(); | |||
if (std::none_of(src_nodes.begin(), src_nodes.end(), is_loop_op) && | |||
std::none_of(dst_nodes.begin(), dst_nodes.end(), is_exit_op)) { | |||
return false; | |||
} | |||
for (const auto &m : src_nodes) { | |||
if (kMergeOpTypes.count(NodeUtils::GetNodeType(m)) > 0) { | |||
for (const auto &n : m->GetInAllNodes()) { | |||
if (kNextIterationOpTypes.count(NodeUtils::GetNodeType(n)) > 0) { | |||
continue; | |||
} | |||
search_queue.push({n, dst_span}); | |||
GELOGD("Travel in Loop: %s <-- %s <-- %s, span is: %u", node->GetName().c_str(), m->GetName().c_str(), | |||
n->GetName().c_str(), dst_span); | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
/// | |||
/// @brief Mark force unknown shape for Switch node | |||
/// @param [in] merge node | |||
/// @param [out] switch group | |||
@@ -72,6 +102,7 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { | |||
/// | |||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { | |||
// Switch --> {Switch --> Merge} --> Merge | |||
GELOGD("Search Switch node for Merge: %s", node->GetName().c_str()); | |||
std::unordered_set<NodePtr> nodes_seen; | |||
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); | |||
while (!search_queue.empty()) { | |||
@@ -79,43 +110,25 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
const auto dst_span = search_queue.front().second; | |||
search_queue.pop(); | |||
// Switch --> Identity --> Constant | |||
for (const auto &in_node : dst_node->GetInControlNodes()) { | |||
if (nodes_seen.count(in_node) > 0) { | |||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
continue; | |||
} | |||
nodes_seen.insert(in_node); | |||
if (in_node->GetType() == IDENTITY) { | |||
GELOGD("Travel node: %s, In control: %s, span is: %u", dst_node->GetName().c_str(), | |||
in_node->GetName().c_str(), dst_span); | |||
search_queue.push({in_node, dst_span}); | |||
} | |||
} | |||
for (const auto &in_node : dst_node->GetInDataNodes()) { | |||
for (const auto &in_node : dst_node->GetInAllNodes()) { | |||
if (nodes_seen.count(in_node) > 0) { | |||
GELOGD("Travel node: %s, Skip already seen node: %s", dst_node->GetName().c_str(), in_node->GetName().c_str()); | |||
continue; | |||
} | |||
nodes_seen.insert(in_node); | |||
std::string node_type; | |||
(void)GetOriginalType(in_node, node_type); | |||
const std::string node_type = NodeUtils::GetNodeType(in_node); | |||
GELOGD("Travel node: %s, %s node: %s, span is: %u", dst_node->GetName().c_str(), node_type.c_str(), | |||
in_node->GetName().c_str(), dst_span); | |||
if (kSwitchOpTypes.count(node_type) > 0) { // Switch input node. | |||
if (DealWithLoopSwitch(in_node, dst_span, search_queue)) { | |||
continue; | |||
} | |||
if (dst_span > 0) { | |||
search_queue.push({in_node, dst_span - 1}); | |||
} else { | |||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||
in_node->GetName().c_str()); | |||
} else { | |||
switch_group.emplace_back(in_node); | |||
} | |||
switch_group.emplace_back(in_node); | |||
} | |||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | |||
search_queue.push({in_node, dst_span + 1}); | |||
@@ -19,6 +19,8 @@ | |||
#include "inc/graph_pass.h" | |||
#include <queue> | |||
namespace ge { | |||
class MarkForceUnknownForCondPass : public GraphPass { | |||
public: | |||
@@ -26,6 +28,15 @@ class MarkForceUnknownForCondPass : public GraphPass { | |||
private: | |||
/// | |||
/// @brief Deal with Switch node for LoopCond | |||
/// @param [in] Switch node | |||
/// @param [in] dest span | |||
/// @param [out] Search queue | |||
/// @return true: Switch In while loop / false: Not in while Loop. | |||
/// | |||
bool DealWithLoopSwitch(const NodePtr &node, uint32_t dst_span, std::queue<std::pair<NodePtr, uint32_t>> search_queue); | |||
/// | |||
/// @brief Mark force unknown shape for Switch node | |||
/// @param [in] merge node | |||
/// @param [out] switch group | |||
@@ -395,8 +395,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
int64_t group_index = -1; | |||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
SetControlFlowGroup(stream_switch, group_index); | |||
if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||
SetControlFlowGroup(stream_switch, group_index); | |||
} | |||
return stream_switch; | |||
} | |||
@@ -326,17 +326,37 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
} | |||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
if (node_item_->root_data_.count(input_idx) > 0) { | |||
const auto is_persist_tensor = [](const std::map<const NodeItem *, std::set<int>> &items, int idx) { | |||
const auto is_exist = [&idx](const std::pair<const NodeItem *, std::set<int>> &items) { | |||
return items.second.count(idx) > 0; | |||
}; | |||
return std::any_of(items.begin(), items.end(), is_exist); | |||
}; | |||
if (is_persist_tensor(node_item_->root_data_, input_idx)) { | |||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
root_tensor_values_[input_idx] = tensor; | |||
} | |||
if (node_item_->enter_data_.count(input_idx) > 0) { | |||
} else if (is_persist_tensor(node_item_->enter_data_, input_idx)) { | |||
GELOGD("[%s] Save Enter input tensor: %d", GetName().c_str(), input_idx); | |||
root_tensor_values_[input_idx] = tensor; | |||
} | |||
} | |||
void NodeState::UpdatePersistTensor() { | |||
const auto update_tensor = [&](const std::map<const NodeItem *, std::set<int>> &items) { | |||
for (const auto &item : items) { | |||
for (const auto idx : item.second) { | |||
UpdatePersistTensor(idx); | |||
} | |||
} | |||
}; | |||
update_tensor(node_item_->root_data_); | |||
if (iteration_count_ > 0) { | |||
update_tensor(node_item_->enter_data_); | |||
} | |||
} | |||
void NodeState::UpdatePersistTensor(int input_idx) { | |||
const auto it = root_tensor_values_.find(input_idx); | |||
if (it == root_tensor_values_.end()) { | |||
@@ -363,16 +383,9 @@ void NodeState::ResetContext(uint64_t iteration) { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
for (auto item : node_item_->root_data_) { | |||
UpdatePersistTensor(item.first); | |||
} | |||
if (iteration > 0) { | |||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
for (auto item : node_item_->enter_data_) { | |||
UpdatePersistTensor(item.first); | |||
} | |||
} | |||
iteration_count_ = iteration; | |||
@@ -132,6 +132,7 @@ struct NodeState { | |||
void RunNextIteration(); | |||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
void UpdatePersistTensor(); | |||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
@@ -395,11 +395,13 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) { | |||
data_send_.emplace(node_item); | |||
node_item->data_recv_[this] = anchor_index; | |||
if (is_root_node_) { | |||
node_item->root_data_[anchor_index] = this; | |||
auto &data_anchors = node_item->root_data_[this]; | |||
data_anchors.emplace(anchor_index); | |||
} | |||
// If Enter feed Not Merge, take as root Node. | |||
if (IsEnterOp() && (node_item->node_type != STREAMMERGE)) { | |||
node_item->enter_data_[anchor_index] = this; | |||
auto &data_anchors = node_item->enter_data_[this]; | |||
data_anchors.emplace(anchor_index); | |||
} | |||
GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str()); | |||
} | |||
@@ -148,9 +148,9 @@ struct NodeItem { | |||
int64_t frame_index_ = -1; | |||
int64_t parent_frame_ = -1; | |||
std::set<const NodeItem *> root_ctrl_; // Recv ctrl from root node | |||
std::map<int, const NodeItem *> root_data_; // Recv data from root node | |||
std::map<const NodeItem *, std::set<int>> root_data_; // Recv data from root node | |||
std::set<const NodeItem *> enter_ctrl_; // Recv ctrl from Enter node | |||
std::map<int, const NodeItem *> enter_data_; // Recv data from Enter node | |||
std::map<const NodeItem *, std::set<int>> enter_data_; // Recv data from Enter node | |||
std::set<const NodeItem *> data_send_; // Send data notify to | |||
std::map<const NodeItem *, int> data_recv_; // Recv data notify from | |||
std::set<const NodeItem *> ctrl_send_; // Send ctrl notify to | |||
@@ -39,6 +39,7 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; | |||
// Get a task ready to run. In order: allocate the context's outputs, allocate
// its workspaces, refresh its persist tensors, then let the task update its args.
// Returns SUCCESS when the end of the sequence is reached.
// NOTE(review): GE_CHK_STATUS_RET_NOLOG is defined elsewhere; presumably it
// returns early when the wrapped call does not yield SUCCESS - confirm.
Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { | |||
GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); | |||
GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); | |||
// Persist tensors are refreshed before the task's args are updated.
GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor()); | |||
GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); | |||
return SUCCESS; | |||
} | |||
@@ -468,6 +468,12 @@ Status TaskContext::PropagateOutputs() { | |||
return SUCCESS; | |||
} | |||
// Calls UpdatePersistTensor() on node_state_ and then returns SUCCESS.
// NOTE(review): GE_CHECK_NOTNULL is defined elsewhere; presumably it returns an
// error status when node_state_ is null - confirm against the macro.
Status TaskContext::UpdatePersistTensor() { | |||
GE_CHECK_NOTNULL(node_state_); | |||
node_state_->UpdatePersistTensor(); | |||
return SUCCESS; | |||
} | |||
const void *TaskContext::GetVarBaseAddr() { | |||
return execution_context_->model->GetVarMemBase(); | |||
} | |||
@@ -78,6 +78,7 @@ class TaskContext { | |||
Status AllocateOutputs(AllocationAttr *attr = nullptr); | |||
Status AllocateWorkspaces(); | |||
Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); | |||
Status UpdatePersistTensor(); | |||
bool IsTraceEnabled() const; | |||