diff --git a/ge/graph/passes/mark_branch_force_unknown_pass.cc b/ge/graph/passes/mark_branch_force_unknown_pass.cc index 4b00b24d..c4c5d1dd 100644 --- a/ge/graph/passes/mark_branch_force_unknown_pass.cc +++ b/ge/graph/passes/mark_branch_force_unknown_pass.cc @@ -40,7 +40,7 @@ Status MarkBranchForceUnknownPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); - if ((node_type != MERGE) && (node_type != REFMERGE)) { + if (kMergeOpTypes.count(node_type) == 0) { continue; } diff --git a/ge/graph/passes/merge_input_memcpy_pass.cc b/ge/graph/passes/merge_input_memcpy_pass.cc index ce38a3dd..c4273584 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.cc +++ b/ge/graph/passes/merge_input_memcpy_pass.cc @@ -16,18 +16,11 @@ #include "graph/passes/merge_input_memcpy_pass.h" -#include - #include "common/ge/ge_util.h" #include "ge/ge_api_types.h" #include "graph/common/omg_util.h" namespace ge { -namespace { -const std::set kLoopMergeInputs{ - ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION -}; -} Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { GELOGD("MergeInputMemcpyPass Enter"); std::unordered_map> switch_groups; @@ -41,10 +34,8 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(node->GetOpDesc()); GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH)), "Merge add memcpy node failed."); - CollectSwitchGroup(node, switch_groups); } - MarkUnknownForSwitch(switch_groups); GELOGD("MergeInputMemcpyPass Leave"); return SUCCESS; } @@ -114,94 +105,4 @@ NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph return graph->AddNode(op_desc); } - -/// -/// @brief Mark force unknown shape for Switch node -/// @param [in] merge node -/// @param [out] switch_groups -/// @return -/// -void MergeInputMemcpyPass::CollectSwitchGroup(const NodePtr &node, - std::unordered_map> &switch_groups) { - const auto &op_desc = node->GetOpDesc(); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto &src_out_anchor = in_anchor->GetPeerOutAnchor(); - if (src_out_anchor == nullptr) { - continue; - } - - std::string node_type; - GetOriginalType(src_out_anchor->GetOwnerNode(), node_type); - if (kLoopMergeInputs.count(node_type) > 0) { - return; - } - } - - // Switch --> {Switch --> Merge} --> Merge - std::queue> search_queue; - search_queue.push({node, 0}); - std::vector &switch_group = switch_groups[node]; - while (!search_queue.empty()) { - const auto dst_node = search_queue.front().first; - const auto dst_span = search_queue.front().second; - search_queue.pop(); - - // Switch --> Identity --> Constant - for (const auto &in_ctrl_node : dst_node->GetInControlNodes()) { - if (in_ctrl_node->GetType() == IDENTITY) { - GELOGD("Travel node: %s, In control: %s, span is: %u", - dst_node->GetName().c_str(), in_ctrl_node->GetName().c_str(), dst_span); - search_queue.push({in_ctrl_node, dst_span}); - } - } - - for (const auto &in_data_node : dst_node->GetInDataNodes()) { - std::string node_type; - GetOriginalType(in_data_node, node_type); - GELOGD("Travel node: %s, %s node: %s, span is: %u", - dst_node->GetName().c_str(), node_type.c_str(), in_data_node->GetName().c_str(), dst_span); - if (node_type == SWITCH || node_type == REFSWITCH) { - if (dst_span > 0) { - search_queue.push({in_data_node, dst_span - 1}); - } else { - switch_group.emplace_back(in_data_node); - } - } else if (node_type == MERGE || node_type == REFMERGE) { - search_queue.push({in_data_node, dst_span + 1}); - } else { - search_queue.push({in_data_node, dst_span}); - } - } - } - - if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { - GELOGI("Mark [%s] as for unknown shape, switch groups: %zu", node->GetName().c_str(), switch_groups.size()); - MarkForceUnknownShape(node, true); - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); - } - } -} - -void MergeInputMemcpyPass::MarkUnknownForSwitch(const std::unordered_map> &switch_groups) { - std::function callback = [](const NodePtr &n) { - return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); - }; - - for (const auto &item : switch_groups) { - const auto &node = item.first; - if (node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)) { - continue; - } - - const std::vector &switch_group = item.second; - if (std::any_of(switch_group.begin(), switch_group.end(), callback)) { - GELOGI("Mark [%s] as force unknown shape, switch nodes: %zu", node->GetName().c_str(), switch_group.size()); - MarkForceUnknownShape(node, true); - for (const auto &n : switch_group) { - MarkForceUnknownShape(n, true); - } - } - } -} } // namespace ge diff --git a/ge/graph/passes/merge_input_memcpy_pass.h b/ge/graph/passes/merge_input_memcpy_pass.h index 2c7636ea..b8c6f0b8 100644 --- a/ge/graph/passes/merge_input_memcpy_pass.h +++ b/ge/graph/passes/merge_input_memcpy_pass.h @@ -44,21 +44,6 @@ class MergeInputMemcpyPass : public GraphPass { /// NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); - - /// - /// @brief Mark force unknown shape for Switch node - /// @param [in] merge node - /// @param [out] switch_groups - /// @return - /// - void CollectSwitchGroup(const NodePtr &node, std::unordered_map> &switch_groups); - - /// - /// @brief Mark force unknown shape for Switch node - /// @param [in] switch_groups - /// @return - /// - void MarkUnknownForSwitch(const std::unordered_map> &switch_groups); }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_