@@ -137,6 +137,7 @@ set(TRAIN_SRC_LIST | |||||
"graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
"graph/passes/mark_same_addr_pass.cc" | "graph/passes/mark_same_addr_pass.cc" | ||||
"graph/passes/mark_graph_unknown_status_pass.cc" | "graph/passes/mark_graph_unknown_status_pass.cc" | ||||
"graph/passes/mark_agnostic_pass.cc" | |||||
"graph/partition/dynamic_shape_partition.cc" | "graph/partition/dynamic_shape_partition.cc" | ||||
"graph/partition/stage_partition.cc" | "graph/partition/stage_partition.cc" | ||||
"graph/passes/base_pass.cc" | "graph/passes/base_pass.cc" | ||||
@@ -488,6 +489,7 @@ set(INFER_SRC_LIST | |||||
"graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
"graph/passes/mark_same_addr_pass.cc" | "graph/passes/mark_same_addr_pass.cc" | ||||
"graph/passes/mark_graph_unknown_status_pass.cc" | "graph/passes/mark_graph_unknown_status_pass.cc" | ||||
"graph/passes/mark_agnostic_pass.cc" | |||||
"graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
"graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
"graph/common/local_context.cc" | "graph/common/local_context.cc" | ||||
@@ -109,6 +109,7 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
graph/passes/mark_agnostic_pass.cc \ | |||||
graph/common/omg_util.cc \ | graph/common/omg_util.cc \ | ||||
graph/common/bcast.cc \ | graph/common/bcast.cc \ | ||||
graph/common/local_context.cc \ | graph/common/local_context.cc \ | ||||
@@ -110,6 +110,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
graph/passes/mark_agnostic_pass.cc \ | |||||
graph/partition/dynamic_shape_partition.cc \ | graph/partition/dynamic_shape_partition.cc \ | ||||
graph/partition/stage_partition.cc \ | graph/partition/stage_partition.cc \ | ||||
graph/passes/base_pass.cc \ | graph/passes/base_pass.cc \ | ||||
@@ -15,20 +15,40 @@ | |||||
*/ | */ | ||||
#include "graph/passes/mark_agnostic_pass.h" | #include "graph/passes/mark_agnostic_pass.h" | ||||
#include "utils/node_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
namespace ge { | namespace ge { | ||||
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | ||||
GELOGD("Mark format agnostic for switch ndoe %s", node->GetName().c_str()); | |||||
GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | |||||
const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | |||||
if (op_tensor == nullptr) { | |||||
GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | |||||
continue; | |||||
} | |||||
if (node_type == IDENTITY) { | |||||
GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
} | } | ||||
if (node_type == MERGE || node_type == REFMERGE) { | if (node_type == MERGE || node_type == REFMERGE) { | ||||
GELOGD("Mark format agnostic for merge node %s", node->GetName().c_str()); | |||||
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | |||||
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | |||||
if (op_tensor == nullptr) { | |||||
GELOGD("Op: %s, Index:0,has no output", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
@@ -36,4 +56,4 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | |||||
} // namespace ge |
@@ -92,6 +92,7 @@ | |||||
#include "graph/passes/unused_op_remove_pass.h" | #include "graph/passes/unused_op_remove_pass.h" | ||||
#include "graph/passes/var_is_initialized_op_pass.h" | #include "graph/passes/var_is_initialized_op_pass.h" | ||||
#include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
#include "graph/passes/mark_agnostic_pass.h" | |||||
#include "graph/preprocess/insert_op/util_insert_aipp_op.h" | #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
@@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { | |||||
try { | try { | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); | |||||
} catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||