@@ -21,7 +21,7 @@ namespace ge { | |||||
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | |||||
if (node_type == SWITCH || node_type == SWITCHN) { | |||||
GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | ||||
const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | ||||
@@ -37,10 +37,15 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
if (node_type == IDENTITY) { | if (node_type == IDENTITY) { | ||||
GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | ||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
continue; | |||||
} | |||||
if (node_type == REFMERGE || node_type == REFSWITCH) { | |||||
GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
} | } | ||||
if (node_type == MERGE || node_type == REFMERGE) { | |||||
if (node_type == MERGE) { | |||||
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | ||||
const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | ||||
@@ -117,6 +117,7 @@ | |||||
#include "graph/passes/variable_op_pass.h" | #include "graph/passes/variable_op_pass.h" | ||||
#include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
#include "graph/passes/variable_ref_delete_op_pass.h" | #include "graph/passes/variable_ref_delete_op_pass.h" | ||||
#include "graph/passes/mark_agnostic_pass.h" | |||||
namespace ge { | namespace ge { | ||||
@@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { | |||||
try { | try { | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); | |||||
} catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -40,8 +40,6 @@ using domi::AippOpParams; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const char *const kMbatchSwitchnName = "mbatch-switch-name"; | const char *const kMbatchSwitchnName = "mbatch-switch-name"; | ||||
const int64_t kFormatAgnosticSwitch = 1; | |||||
const int64_t kFormatDependInputIndex = 1; | |||||
} // namespace | } // namespace | ||||
static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | ||||
if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | ||||
@@ -269,23 +267,6 @@ Status InsertNewOpUtil::GetAippParams(const std::unique_ptr<domi::AippOpParams> | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status InsertNewOpUtil::AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node) { | |||||
GE_CHECK_NOTNULL(aipp_node); | |||||
auto next_nodes = aipp_node->GetOutDataNodes(); | |||||
for (const auto next_node : next_nodes) { | |||||
GE_CHECK_NOTNULL(next_node); | |||||
auto op_desc = next_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetType() == SWITCHN) { | |||||
GELOGI("Find switchn node [%s] after aipp [%s]", op_desc->GetName().c_str(), aipp_node->GetName().c_str()); | |||||
(void)AttrUtils::SetInt(op_desc, "_format_agnostic", kFormatAgnosticSwitch); | |||||
(void)AttrUtils::SetListInt(op_desc, "_format_agnostic_except_input", | |||||
std::vector<int64_t>({kFormatDependInputIndex})); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | ||||
std::map<std::string, NodePtr> switchn_names_to_data; | std::map<std::string, NodePtr> switchn_names_to_data; | ||||
std::set<NodePtr> updated_switchn; | std::set<NodePtr> updated_switchn; | ||||
@@ -300,9 +281,6 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | |||||
} | } | ||||
if (node->GetType() == AIPP) { | if (node->GetType() == AIPP) { | ||||
GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | ||||
// In dynamic batch/HW and dynamic aipp scend, switchn should be set format agnostic, otherwise transdata maybe | |||||
// inserted between aipp and switchn which introduce performance and memory increase problem. | |||||
GE_RETURN_IF_ERROR(AddFormatAgnosticAttrToSwitchn(node)); | |||||
} | } | ||||
if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | ||||
multbatch_case = node; | multbatch_case = node; | ||||
@@ -68,7 +68,6 @@ class InsertNewOpUtil { | |||||
void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | ||||
Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | ||||
Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | ||||
Status AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node); | |||||
Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | ||||
Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | ||||
Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | ||||