diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 77fa64fb..00a3dad9 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -21,7 +21,7 @@ namespace ge { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto node_type = NodeUtils::GetNodeType(*node); - if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { + if (node_type == SWITCH || node_type == SWITCHN) { GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); const OpDescPtr op_desc = node->GetOpDesc(); const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); @@ -37,10 +37,15 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { if (node_type == IDENTITY) { GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); + continue; + } + if (node_type == REFMERGE || node_type == REFSWITCH) { + GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); + AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); continue; } - if (node_type == MERGE || node_type == REFMERGE) { + if (node_type == MERGE) { GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); const OpDescPtr op_desc = node->GetOpDesc(); const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index f90c0d80..98371426 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -117,6 +117,7 @@ #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/passes/mark_agnostic_pass.h" namespace ge { @@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { try { (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); + (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 66235d52..a1eb104d 100755 --- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -40,8 +40,6 @@ using domi::AippOpParams; namespace ge { namespace { const char *const kMbatchSwitchnName = "mbatch-switch-name"; -const int64_t kFormatAgnosticSwitch = 1; -const int64_t kFormatDependInputIndex = 1; } // namespace static void ConvertShape2Nhwc(Format &format, vector &shape_vec) { if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast(NORMAL_TENSOR_SIZE))) { @@ -269,23 +267,6 @@ Status InsertNewOpUtil::GetAippParams(const std::unique_ptr return SUCCESS; } -Status InsertNewOpUtil::AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node) { - GE_CHECK_NOTNULL(aipp_node); - auto next_nodes = aipp_node->GetOutDataNodes(); - for (const auto next_node : next_nodes) { - GE_CHECK_NOTNULL(next_node); - auto op_desc = next_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == SWITCHN) { - GELOGI("Find switchn node [%s] after aipp [%s]", op_desc->GetName().c_str(), aipp_node->GetName().c_str()); - (void)AttrUtils::SetInt(op_desc, "_format_agnostic", kFormatAgnosticSwitch); - (void)AttrUtils::SetListInt(op_desc, "_format_agnostic_except_input", - std::vector({kFormatDependInputIndex})); - } - } - return SUCCESS; -} - Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { std::map switchn_names_to_data; std::set updated_switchn; @@ -300,9 +281,6 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { } if (node->GetType() == AIPP) { GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); - // In dynamic batch/HW and dynamic aipp scend, switchn should be set format agnostic, otherwise transdata maybe - // inserted between aipp and switchn which introduce performance and memory increase problem. - GE_RETURN_IF_ERROR(AddFormatAgnosticAttrToSwitchn(node)); } if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { multbatch_case = node; diff --git a/ge/graph/preprocess/insert_op/util_insert_aipp_op.h b/ge/graph/preprocess/insert_op/util_insert_aipp_op.h index 52e7ed5d..d917b74d 100644 --- a/ge/graph/preprocess/insert_op/util_insert_aipp_op.h +++ b/ge/graph/preprocess/insert_op/util_insert_aipp_op.h @@ -68,7 +68,6 @@ class InsertNewOpUtil { void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); Status UpdatePrevNodeByAipp(NodePtr &node, std::set &switchns); Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); - Status AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node); Status GetDataRelatedNode(NodePtr &node, std::map> &data_next_node_map); Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector &aipps); Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output);