@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||||
} | } | ||||
/// | /// | ||||
/// @brief Set Op _force_unknown_shape flag | |||||
/// @param [in] node | |||||
/// @param [in] force_unknown, set attribute if true | |||||
/// @param [in] group_index, condition group index of node. | |||||
/// @return | |||||
/// | |||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||||
if (!force_unknown) { | |||||
return; | |||||
} | |||||
SetControlFlowGroup(node, group_index); | |||||
} | |||||
/// | |||||
/// @brief Set Op _control_flow_group flag | /// @brief Set Op _control_flow_group flag | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] group, condition group index of node. | /// @param [in] group, condition group index of node. | ||||
@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size); | |||||
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | ||||
/// | /// | ||||
/// @brief Set Op _force_unknown_shape flag | |||||
/// @param [in] node | |||||
/// @param [in] force_unknown, set attribute if true | |||||
/// @param [in] group_index, condition group index of node. | |||||
/// @return | |||||
/// | |||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||||
/// | |||||
/// @brief Set Op _control_flow_group flag | /// @brief Set Op _control_flow_group flag | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] group, condition group index of node. | /// @param [in] group, condition group index of node. | ||||
@@ -132,38 +132,18 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
/// @return | /// @return | ||||
/// | /// | ||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | ||||
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||||
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||||
}; | |||||
for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||||
const auto &op_node1 = it1->first; | |||||
const auto &op_desc1 = op_node1->GetOpDesc(); | |||||
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { | |||||
const auto &op_node = it->first; | |||||
const auto &op_desc = op_node->GetOpDesc(); | |||||
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | continue; | ||||
} | } | ||||
int64_t group_index = op_desc1->GetId(); | |||||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
SetControlFlowGroup(op_node1, group_index); | |||||
for (const auto &n : it1->second) { | |||||
int64_t group_index = op_desc->GetId(); | |||||
SetControlFlowGroup(op_node, group_index); | |||||
for (const auto &n : it->second) { | |||||
SetControlFlowGroup(n, group_index); | SetControlFlowGroup(n, group_index); | ||||
} | } | ||||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
const auto &op_node2 = it2->first; | |||||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | |||||
} | |||||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
SetControlFlowGroup(op_node2, group_index); | |||||
for (const auto &n : it2->second) { | |||||
SetControlFlowGroup(n, group_index); | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | ||||
return FAILED, "[Check][Param] Param of pre node is nullptr."); | return FAILED, "[Check][Param] Param of pre node is nullptr."); | ||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
MarkForceUnknownShape(node, force_unknown, group_index); | |||||
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
MarkForceUnknownShape(active_node, force_unknown, group_index); | |||||
SetControlFlowGroup(active_node, group_index); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
int64_t group_index = -1; | int64_t group_index = -1; | ||||
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||||
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
return stream_switch; | return stream_switch; | ||||
} | } | ||||
@@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { | |||||
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { | ||||
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { | ||||
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { | ||||
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||||
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||||
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; | |||||
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; | |||||
std::set<NodePtr> same_cond_switch; | std::set<NodePtr> same_cond_switch; | ||||
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); | ||||
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); | ||||
@@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | ||||
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | ||||
}; | }; | ||||
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||||
MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||||
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | |||||
SetControlFlowGroup(active_node, group_index); | |||||
const std::string &cond_group = cond_node->GetName(); | const std::string &cond_group = cond_node->GetName(); | ||||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); | ||||
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||||
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); | |||||
GE_IF_BOOL_EXEC(switch_list.empty(), continue); | GE_IF_BOOL_EXEC(switch_list.empty(), continue); | ||||
// select first stream_switch | // select first stream_switch | ||||
@@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
"[Add][Edge] between %s and %s failed.", | "[Add][Edge] between %s and %s failed.", | ||||
cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | cast_node->GetName().c_str(), stream_switch->GetName().c_str()); | ||||
MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||||
SetControlFlowGroup(stream_switch, group_index); | |||||
for (const NodePtr &node : switch_list) { | for (const NodePtr &node : switch_list) { | ||||
GE_IF_BOOL_EXEC(node != stream_switch, { | GE_IF_BOOL_EXEC(node != stream_switch, { | ||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | ||||
@@ -317,9 +317,9 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
return task_context_; | return task_context_; | ||||
} | } | ||||
void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||||
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) { | |||||
if (node_item_->root_data_.count(input_idx) > 0) { | if (node_item_->root_data_.count(input_idx) > 0) { | ||||
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx); | |||||
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx); | |||||
root_tensor_values_[input_idx] = tensor; | root_tensor_values_[input_idx] = tensor; | ||||
} | } | ||||
@@ -329,7 +329,7 @@ void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) { | |||||
} | } | ||||
} | } | ||||
void NodeState::UpdateRootTensor(int input_idx) { | |||||
void NodeState::UpdatePersistTensor(int input_idx) { | |||||
const auto it = root_tensor_values_.find(input_idx); | const auto it = root_tensor_values_.find(input_idx); | ||||
if (it == root_tensor_values_.end()) { | if (it == root_tensor_values_.end()) { | ||||
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); | ||||
@@ -355,14 +355,14 @@ void NodeState::ResetContext(uint64_t iteration) { | |||||
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); | ||||
for (auto item : node_item_->root_data_) { | for (auto item : node_item_->root_data_) { | ||||
UpdateRootTensor(item.first); | |||||
UpdatePersistTensor(item.first); | |||||
} | } | ||||
if (iteration > 0) { | if (iteration > 0) { | ||||
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); | ||||
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); | ||||
for (auto item : node_item_->enter_data_) { | for (auto item : node_item_->enter_data_) { | ||||
UpdateRootTensor(item.first); | |||||
UpdatePersistTensor(item.first); | |||||
} | } | ||||
} | } | ||||
@@ -129,7 +129,7 @@ struct NodeState { | |||||
void RunStreamActive(); | void RunStreamActive(); | ||||
void RunNextIteration(); | void RunNextIteration(); | ||||
void SaveRootTensor(int input_idx, const TensorValue &tensor); | |||||
void SavePersistTensor(int input_idx, const TensorValue &tensor); | |||||
Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; | ||||
@@ -189,7 +189,7 @@ struct NodeState { | |||||
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); | ||||
void ResetContext(uint64_t iteration); | void ResetContext(uint64_t iteration); | ||||
void ScheduleContext(const NodeState &node_state); | void ScheduleContext(const NodeState &node_state); | ||||
void UpdateRootTensor(int input_idx); | |||||
void UpdatePersistTensor(int input_idx); | |||||
const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
std::shared_ptr<NodeTask> kernel_task_ = nullptr; | std::shared_ptr<NodeTask> kernel_task_ = nullptr; | ||||
@@ -461,7 +461,7 @@ Status TaskContext::PropagateOutputs() { | |||||
auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); | ||||
GE_CHECK_NOTNULL(dst_node_state); | GE_CHECK_NOTNULL(dst_node_state); | ||||
dst_node_state->SaveRootTensor(dst_input_idx, *tensor); | |||||
dst_node_state->SavePersistTensor(dst_input_idx, *tensor); | |||||
} | } | ||||
} | } | ||||
(void)guard; | (void)guard; | ||||