@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||
} | |||
/// | |||
/// @brief Set Op _force_unknown_shape flag | |||
/// @param [in] node | |||
/// @param [in] force_unknown, set attribute if true | |||
/// @param [in] group_index, condition group index of node. | |||
/// @return | |||
/// | |||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||
if (!force_unknown) { | |||
return; | |||
} | |||
SetControlFlowGroup(node, group_index); | |||
} | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||
/// | |||
/// @brief Set Op _force_unknown_shape flag | |||
/// @param [in] node | |||
/// @param [in] force_unknown, set attribute if true | |||
/// @param [in] group_index, condition group index of node. | |||
/// @return | |||
/// | |||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||
/// | |||
/// @brief Set Op _control_flow_group flag | |||
/// @param [in] node | |||
/// @param [in] group, condition group index of node. | |||
@@ -132,38 +132,18 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||
/// @return | |||
/// | |||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | |||
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||
}; | |||
for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||
const auto &op_node1 = it1->first; | |||
const auto &op_desc1 = op_node1->GetOpDesc(); | |||
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||
const auto &op_node = it->first; | |||
const auto &op_desc = op_node->GetOpDesc(); | |||
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
continue; | |||
} | |||
int64_t group_index = op_desc1->GetId(); | |||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||
SetControlFlowGroup(op_node1, group_index); | |||
for (const auto &n : it1->second) { | |||
int64_t group_index = op_desc->GetId(); | |||
SetControlFlowGroup(op_node, group_index); | |||
for (const auto &n : it->second) { | |||
SetControlFlowGroup(n, group_index); | |||
} | |||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||
const auto &op_node2 = it2->first; | |||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||
continue; | |||
} | |||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||
SetControlFlowGroup(op_node2, group_index); | |||
for (const auto &n : it2->second) { | |||
SetControlFlowGroup(n, group_index); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} // namespace ge |
@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||
return FAILED, "[Check][Param] Param of pre node is nullptr."); | |||
int64_t group_index = -1; | |||
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
MarkForceUnknownShape(node, force_unknown, group_index); | |||
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | |||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
MarkForceUnknownShape(active_node, force_unknown, group_index); | |||
SetControlFlowGroup(active_node, group_index); | |||
} | |||
return SUCCESS; | |||
@@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | |||
int64_t group_index = -1; | |||
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
SetControlFlowGroup(stream_switch, group_index); | |||
return stream_switch; | |||
} | |||
@@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | |||
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | |||
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | |||
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||
std::set<NodePtr> same_cond_switch; | |||
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | |||
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | |||
@@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||
}; | |||
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||
SetControlFlowGroup(active_node, group_index); | |||
const std::string &cond_group = cond_node->GetName(); | |||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | |||
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | |||
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||
GE_IF_BOOL_EXEC(switch_list.empty(), continue); | |||
// select first stream_switch | |||
@@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||
"[Add][Edge] between %s and %s failed.", | |||
cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | |||
MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||
SetControlFlowGroup(stream_switch, group_index); | |||
for (const NodePtr &node : switch_list) { | |||
GE_IF_BOOL_EXEC(node != stream_switch, { | |||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | |||
@@ -317,9 +317,9 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||
return task_context_; | |||
} | |||
void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||
if (node_item_->root_data_.count(input_idx) > 0) { | |||
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); | |||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||
root_tensor_values_[input_idx] = tensor; | |||
} | |||
@@ -329,7 +329,7 @@ void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||
} | |||
} | |||
void NodeState::UpdateRootTensor(int input_idx) { | |||
void NodeState::UpdatePersistTensor(int input_idx) { | |||
const auto it = root_tensor_values_.find(input_idx); | |||
if (it == root_tensor_values_.end()) { | |||
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | |||
@@ -355,14 +355,14 @@ void NodeState::ResetContext(uint64_t iteration) { | |||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | |||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | |||
for (auto item : node_item_->root_data_) { | |||
UpdateRootTensor(item.first); | |||
UpdatePersistTensor(item.first); | |||
} | |||
if (iteration > 0) { | |||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | |||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | |||
for (auto item : node_item_->enter_data_) { | |||
UpdateRootTensor(item.first); | |||
UpdatePersistTensor(item.first); | |||
} | |||
} | |||
@@ -129,7 +129,7 @@ struct NodeState { | |||
void RunStreamActive(); | |||
void RunNextIteration(); | |||
void SaveRootTensor(int input_idx, const TensorValue &tensor); | |||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | |||
@@ -189,7 +189,7 @@ struct NodeState { | |||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | |||
void ResetContext(uint64_t iteration); | |||
void ScheduleContext(const NodeState &node_state); | |||
void UpdateRootTensor(int input_idx); | |||
void UpdatePersistTensor(int input_idx); | |||
const NodeItem *node_item_ = nullptr; | |||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | |||
@@ -461,7 +461,7 @@ Status TaskContext::PropagateOutputs() { | |||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | |||
GE_CHECK_NOTNULL(dst_node_state); | |||
dst_node_state->SaveRootTensor(dst_input_idx, *tensor); | |||
dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||
} | |||
} | |||
(void)guard; | |||