diff --git a/ge/graph/common/omg_util.cc b/ge/graph/common/omg_util.cc index 52e6cb9c..b2017e4d 100644 --- a/ge/graph/common/omg_util.cc +++ b/ge/graph/common/omg_util.cc @@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { } /// -/// @brief Set Op _force_unknown_shape flag -/// @param [in] node -/// @param [in] force_unknown, set attribute if true -/// @param [in] group_index, condition group index of node. -/// @return -/// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { - if (!force_unknown) { - return; - } - - SetControlFlowGroup(node, group_index); -} - -/// /// @brief Set Op _control_flow_group flag /// @param [in] node /// @param [in] group, condition group index of node. diff --git a/ge/graph/common/omg_util.h b/ge/graph/common/omg_util.h index 148e4102..edaafa45 100644 --- a/ge/graph/common/omg_util.h +++ b/ge/graph/common/omg_util.h @@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); /// -/// @brief Set Op _force_unknown_shape flag -/// @param [in] node -/// @param [in] force_unknown, set attribute if true -/// @param [in] group_index, condition group index of node. -/// @return -/// -void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); - -/// /// @brief Set Op _control_flow_group flag /// @param [in] node /// @param [in] group, condition group index of node. diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index b9ee26db..74babadc 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -132,38 +132,18 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: /// @return /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map> &switch_groups) { - std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); - }; - - for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { - const auto &op_node1 = it1->first; - const auto &op_desc1 = op_node1->GetOpDesc(); - if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { + for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { + const auto &op_node = it->first; + const auto &op_desc = op_node->GetOpDesc(); + if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { continue; } - int64_t group_index = op_desc1->GetId(); - GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); - SetControlFlowGroup(op_node1, group_index); - for (const auto &n : it1->second) { + int64_t group_index = op_desc->GetId(); + SetControlFlowGroup(op_node, group_index); + for (const auto &n : it->second) { SetControlFlowGroup(n, group_index); } - - for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { - const auto &op_node2 = it2->first; - const auto &op_desc2 = op_node2->GetOpDesc(); - if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { - continue; - } - - if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { - SetControlFlowGroup(op_node2, group_index); - for (const auto &n : it2->second) { - SetControlFlowGroup(n, group_index); - } - } - } } } } // namespace ge diff --git a/ge/graph/passes/merge_to_stream_merge_pass.cc b/ge/graph/passes/merge_to_stream_merge_pass.cc index 0b383911..dbcff620 100644 --- a/ge/graph/passes/merge_to_stream_merge_pass.cc +++ b/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); return FAILED, "[Check][Param] Param of pre node is nullptr."); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(node, force_unknown, group_index); + (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); @@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); return FAILED; } - MarkForceUnknownShape(active_node, force_unknown, group_index); + SetControlFlowGroup(active_node, group_index); } return SUCCESS; diff --git a/ge/graph/passes/switch_to_stream_switch_pass.cc b/ge/graph/passes/switch_to_stream_switch_pass.cc index e7743130..e4ab0111 100644 --- a/ge/graph/passes/switch_to_stream_switch_pass.cc +++ b/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); int64_t group_index = -1; - bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); - MarkForceUnknownShape(stream_switch, force_unknown, group_index); + (void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); + SetControlFlowGroup(stream_switch, group_index); return stream_switch; } @@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { - std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + const std::list &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + const std::list &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; std::set same_cond_switch; same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); @@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) std::function callback = [&group_index](const NodePtr &n) { return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); }; - bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); - MarkForceUnknownShape(active_node, is_unknown_shape, group_index); + (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); + SetControlFlowGroup(active_node, group_index); const std::string &cond_group = cond_node->GetName(); for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); - std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); + const std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); GE_IF_BOOL_EXEC(switch_list.empty(), continue); // select first stream_switch @@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) "[Add][Edge] between %s and %s failed.", cast_node->GetName().c_str(), stream_switch->GetName().c_str()); - MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); + SetControlFlowGroup(stream_switch, group_index); for (const NodePtr &node : switch_list) { GE_IF_BOOL_EXEC(node != stream_switch, { GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), diff --git a/ge/hybrid/executor/node_state.cc b/ge/hybrid/executor/node_state.cc index c2760f04..42e08811 100644 --- a/ge/hybrid/executor/node_state.cc +++ b/ge/hybrid/executor/node_state.cc @@ -317,9 +317,9 @@ std::shared_ptr NodeState::GetTaskContext() { return task_context_; } -void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { +void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { if (node_item_->root_data_.count(input_idx) > 0) { - GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); + GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); root_tensor_values_[input_idx] = tensor; } @@ -329,7 +329,7 @@ void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { } } -void NodeState::UpdateRootTensor(int input_idx) { +void NodeState::UpdatePersistTensor(int input_idx) { const auto it = root_tensor_values_.find(input_idx); if (it == root_tensor_values_.end()) { GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); @@ -355,14 +355,14 @@ void NodeState::ResetContext(uint64_t iteration) { data_scheduled_ = static_cast(node_item_->root_data_.size()); ctrl_scheduled_ = static_cast(node_item_->root_ctrl_.size()); for (auto item : node_item_->root_data_) { - UpdateRootTensor(item.first); + UpdatePersistTensor(item.first); } if (iteration > 0) { data_scheduled_ += static_cast(node_item_->enter_data_.size()); ctrl_scheduled_ += static_cast(node_item_->enter_ctrl_.size()); for (auto item : node_item_->enter_data_) { - UpdateRootTensor(item.first); + UpdatePersistTensor(item.first); } } diff --git a/ge/hybrid/executor/node_state.h b/ge/hybrid/executor/node_state.h index c6468bbd..72e2b90e 100644 --- a/ge/hybrid/executor/node_state.h +++ b/ge/hybrid/executor/node_state.h @@ -129,7 +129,7 @@ struct NodeState { void RunStreamActive(); void RunNextIteration(); - void SaveRootTensor(int input_idx, const TensorValue &tensor); + void SavePersistTensor(int input_idx, const TensorValue &tensor); Status NodeScheduled(const std::function &ready) const; @@ -189,7 +189,7 @@ struct NodeState { void SetCtrlSchedule(const NodeState &node_state, const std::function &ready); void ResetContext(uint64_t iteration); void ScheduleContext(const NodeState &node_state); - void UpdateRootTensor(int input_idx); + void UpdatePersistTensor(int input_idx); const NodeItem *node_item_ = nullptr; std::shared_ptr kernel_task_ = nullptr; diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index 417fcd96..fe580c1e 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -461,7 +461,7 @@ Status TaskContext::PropagateOutputs() { auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); GE_CHECK_NOTNULL(dst_node_state); - dst_node_state->SaveRootTensor(dst_input_idx, *tensor); + dst_node_state->SavePersistTensor(dst_input_idx, *tensor); } } (void)guard;