Browse Source

Fix ut

pull/1782/head
zhangxiaokun 4 years ago
parent
commit
34d47b8e9a
8 changed files with 25 additions and 70 deletions
  1. +0
    -15
      ge/graph/common/omg_util.cc
  2. +0
    -9
      ge/graph/common/omg_util.h
  3. +7
    -27
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  4. +2
    -3
      ge/graph/passes/merge_to_stream_merge_pass.cc
  5. +8
    -8
      ge/graph/passes/switch_to_stream_switch_pass.cc
  6. +5
    -5
      ge/hybrid/executor/node_state.cc
  7. +2
    -2
      ge/hybrid/executor/node_state.h
  8. +1
    -1
      ge/hybrid/node_executor/task_context.cc

+ 0
- 15
ge/graph/common/omg_util.cc View File

@@ -275,21 +275,6 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
} }


/// ///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) {
if (!force_unknown) {
return;
}

SetControlFlowGroup(node, group_index);
}

///
/// @brief Set Op _control_flow_group flag /// @brief Set Op _control_flow_group flag
/// @param [in] node /// @param [in] node
/// @param [in] group, condition group index of node. /// @param [in] group, condition group index of node.


+ 0
- 9
ge/graph/common/omg_util.h View File

@@ -126,15 +126,6 @@ Status GetMemorySize(const NodePtr &node, int64_t &output_size);
bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc);


/// ///
/// @brief Set Op _force_unknown_shape flag
/// @param [in] node
/// @param [in] force_unknown, set attribute if true
/// @param [in] group_index, condition group index of node.
/// @return
///
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index);

///
/// @brief Set Op _control_flow_group flag /// @brief Set Op _control_flow_group flag
/// @param [in] node /// @param [in] node
/// @param [in] group, condition group index of node. /// @param [in] group, condition group index of node.


+ 7
- 27
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -132,38 +132,18 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// @return /// @return
/// ///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP);
};

for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) {
const auto &op_node1 = it1->first;
const auto &op_desc1 = op_node1->GetOpDesc();
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) {
const auto &op_node = it->first;
const auto &op_desc = op_node->GetOpDesc();
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue; continue;
} }


int64_t group_index = op_desc1->GetId();
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index);
SetControlFlowGroup(op_node1, group_index);
for (const auto &n : it1->second) {
int64_t group_index = op_desc->GetId();
SetControlFlowGroup(op_node, group_index);
for (const auto &n : it->second) {
SetControlFlowGroup(n, group_index); SetControlFlowGroup(n, group_index);
} }

for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) {
const auto &op_node2 = it2->first;
const auto &op_desc2 = op_node2->GetOpDesc();
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
}

if (std::any_of(it2->second.begin(), it2->second.end(), callback)) {
SetControlFlowGroup(op_node2, group_index);
for (const auto &n : it2->second) {
SetControlFlowGroup(n, group_index);
}
}
}
} }
} }
} // namespace ge } // namespace ge

+ 2
- 3
ge/graph/passes/merge_to_stream_merge_pass.cc View File

@@ -89,8 +89,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid");
return FAILED, "[Check][Param] Param of pre node is nullptr."); return FAILED, "[Check][Param] Param of pre node is nullptr.");
int64_t group_index = -1; int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(node, force_unknown, group_index);
(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);
@@ -109,7 +108,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons
GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str()); GELOGE(FAILED, "[Set][ActiveLabelList] for node %s failed.", active_node->GetName().c_str());
return FAILED; return FAILED;
} }
MarkForceUnknownShape(active_node, force_unknown, group_index);
SetControlFlowGroup(active_node, group_index);
} }


return SUCCESS; return SUCCESS;


+ 8
- 8
ge/graph/passes/switch_to_stream_switch_pass.cc View File

@@ -395,8 +395,8 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &
peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str()); peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());


int64_t group_index = -1; int64_t group_index = -1;
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
MarkForceUnknownShape(stream_switch, force_unknown, group_index);
(void)AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
SetControlFlowGroup(stream_switch, group_index);
return stream_switch; return stream_switch;
} }


@@ -491,8 +491,8 @@ int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
std::set<NodePtr> same_cond_switch; std::set<NodePtr> same_cond_switch;
same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
@@ -524,13 +524,13 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
}; };
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
MarkForceUnknownShape(active_node, is_unknown_shape, group_index);
(void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
SetControlFlowGroup(active_node, group_index);


const std::string &cond_group = cond_node->GetName(); const std::string &cond_group = cond_node->GetName();
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
GE_IF_BOOL_EXEC(switch_list.empty(), continue); GE_IF_BOOL_EXEC(switch_list.empty(), continue);


// select first stream_switch // select first stream_switch
@@ -559,7 +559,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph)
"[Add][Edge] between %s and %s failed.", "[Add][Edge] between %s and %s failed.",
cast_node->GetName().c_str(), stream_switch->GetName().c_str()); cast_node->GetName().c_str(), stream_switch->GetName().c_str());


MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index);
SetControlFlowGroup(stream_switch, group_index);
for (const NodePtr &node : switch_list) { for (const NodePtr &node : switch_list) {
GE_IF_BOOL_EXEC(node != stream_switch, { GE_IF_BOOL_EXEC(node != stream_switch, {
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),


+ 5
- 5
ge/hybrid/executor/node_state.cc View File

@@ -317,9 +317,9 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() {
return task_context_; return task_context_;
} }


void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) {
void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
if (node_item_->root_data_.count(input_idx) > 0) { if (node_item_->root_data_.count(input_idx) > 0) {
GELOGD("[%s] Save Const input tensor: %d", GetName().c_str(), input_idx);
GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
root_tensor_values_[input_idx] = tensor; root_tensor_values_[input_idx] = tensor;
} }


@@ -329,7 +329,7 @@ void NodeState::SaveRootTensor(int input_idx, const TensorValue &tensor) {
} }
} }


void NodeState::UpdateRootTensor(int input_idx) {
void NodeState::UpdatePersistTensor(int input_idx) {
const auto it = root_tensor_values_.find(input_idx); const auto it = root_tensor_values_.find(input_idx);
if (it == root_tensor_values_.end()) { if (it == root_tensor_values_.end()) {
GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx); GELOGW("[%s] Not found saved tensor: %d", GetName().c_str(), input_idx);
@@ -355,14 +355,14 @@ void NodeState::ResetContext(uint64_t iteration) {
data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size());
ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size()); ctrl_scheduled_ = static_cast<uint32_t>(node_item_->root_ctrl_.size());
for (auto item : node_item_->root_data_) { for (auto item : node_item_->root_data_) {
UpdateRootTensor(item.first);
UpdatePersistTensor(item.first);
} }


if (iteration > 0) { if (iteration > 0) {
data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size()); data_scheduled_ += static_cast<uint32_t>(node_item_->enter_data_.size());
ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size()); ctrl_scheduled_ += static_cast<uint32_t>(node_item_->enter_ctrl_.size());
for (auto item : node_item_->enter_data_) { for (auto item : node_item_->enter_data_) {
UpdateRootTensor(item.first);
UpdatePersistTensor(item.first);
} }
} }




+ 2
- 2
ge/hybrid/executor/node_state.h View File

@@ -129,7 +129,7 @@ struct NodeState {
void RunStreamActive(); void RunStreamActive();
void RunNextIteration(); void RunNextIteration();


void SaveRootTensor(int input_idx, const TensorValue &tensor);
void SavePersistTensor(int input_idx, const TensorValue &tensor);


Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const; Status NodeScheduled(const std::function<void(const NodeItem *)> &ready) const;


@@ -189,7 +189,7 @@ struct NodeState {
void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready); void SetCtrlSchedule(const NodeState &node_state, const std::function<void(const NodeItem *)> &ready);
void ResetContext(uint64_t iteration); void ResetContext(uint64_t iteration);
void ScheduleContext(const NodeState &node_state); void ScheduleContext(const NodeState &node_state);
void UpdateRootTensor(int input_idx);
void UpdatePersistTensor(int input_idx);


const NodeItem *node_item_ = nullptr; const NodeItem *node_item_ = nullptr;
std::shared_ptr<NodeTask> kernel_task_ = nullptr; std::shared_ptr<NodeTask> kernel_task_ = nullptr;


+ 1
- 1
ge/hybrid/node_executor/task_context.cc View File

@@ -461,7 +461,7 @@ Status TaskContext::PropagateOutputs() {


auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
GE_CHECK_NOTNULL(dst_node_state); GE_CHECK_NOTNULL(dst_node_state);
dst_node_state->SaveRootTensor(dst_input_idx, *tensor);
dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
} }
} }
(void)guard; (void)guard;


Loading…
Cancel
Save