@@ -16,6 +16,7 @@ | |||||
#include "graph/passes/mark_agnostic_pass.h" | #include "graph/passes/mark_agnostic_pass.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | |||||
namespace ge { | namespace ge { | ||||
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
@@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
} | } | ||||
if (node_type == MERGE) { | if (node_type == MERGE) { | ||||
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | ||||
auto in_nodes = node->GetInAllNodes(); | |||||
vector<NodePtr> input_nodes(in_nodes.begin(), in_nodes.end()); | |||||
/// Enter-----------+ | |||||
/// +-> Merge | |||||
/// NextIteration---+ | |||||
if (input_nodes.size() == 2) { | |||||
if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) { | |||||
continue; | |||||
} | |||||
} | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | ||||
if (op_tensor == nullptr) { | if (op_tensor == nullptr) { | ||||
@@ -117,6 +117,7 @@ | |||||
#include "graph/passes/variable_op_pass.h" | #include "graph/passes/variable_op_pass.h" | ||||
#include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
#include "graph/passes/variable_ref_delete_op_pass.h" | #include "graph/passes/variable_ref_delete_op_pass.h" | ||||
#include "graph/passes/mark_agnostic_pass.h" | |||||
namespace ge { | namespace ge { | ||||
@@ -1700,6 +1701,7 @@ Status GraphPrepare::PrepareOptimize() { | |||||
try { | try { | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); | ||||
(void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass" , new MarkAgnosticPass); | |||||
} catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||