diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 7b624027..87070e79 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -131,6 +131,8 @@ const int64_t kInvalidDynaimcDimsType = -1; const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext"; const char *const kShapeDataName = "ascend_mbatch_shape_data"; const char *const kGetNextName = "IteratorV2"; +const char *const kExtAttrDataNodes = "data_nodes"; +const char *const kExtAttrGetNextNoSink = "getnext_no_sink"; bool IsTailingOptimization() { string is_tailing_optimization_option; @@ -2731,37 +2733,6 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { } } -Status GraphManager::DistinguishGetNextAndData(ComputeGraphPtr &graph, vector &data_nodes, - vector &getnext_nosink_nodes, - vector &getnext_sink_nodes) { - GELOGD("Start distinguish getnext and data node."); - for (NodePtr &input_node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(input_node); - OpDescPtr op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == DATA && op_desc->GetName() != kShapeDataName) { - if (op_desc->GetName().find(kSubstrOfGetNextNosinkName) == string::npos) { - data_nodes.emplace_back(input_node); - } else { - getnext_nosink_nodes.emplace_back(input_node); - } - } - std::string op_type; - auto ret = GetOriginalType(input_node, op_type); - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed to get node %s original type.", input_node->GetName().c_str()); - return FAILED; - } - if (op_type == kGetNextName) { - GELOGD("Name of getnext sink is %s.", op_desc->GetName().c_str()); - getnext_sink_nodes.emplace_back(input_node); - } - } - GELOGI("data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), - getnext_nosink_nodes.size(), getnext_sink_nodes.size()); - return SUCCESS; -} - void GraphManager::ParseInputsDimsForData(const std::vector &input_tensor) { GELOGD("Start parse input dims from data."); for (size_t i = 0; i < input_tensor.size(); ++i) { @@ -2804,11 +2775,8 @@ Status GraphManager::ParseInputsDims(const std::vector &input_t if (!GetLocalOmgContext().dynamic_node_type.empty()) { vector data_nodes; vector getnext_nosink_nodes; - vector getnext_sink_nodes; - if (DistinguishGetNextAndData(compute_graph_, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { - GELOGE(PARAM_INVALID, "Failed to distinguish getnext and data node."); - return PARAM_INVALID; - } + data_nodes = compute_graph_->TryGetExtAttr(kExtAttrDataNodes, data_nodes); + getnext_nosink_nodes = compute_graph_->TryGetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes); if (GetLocalOmgContext().dynamic_node_type == DATA) { if (getnext_nosink_nodes.empty()) { // just data or data+getnext_sink diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 83aebeb6..feca02fc 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -222,8 +222,6 @@ class GraphManager { const ComputeGraphPtr &compute_graph, uint64_t session_id, const GEThreadLocalContext &ge_context); Status ParseInputsDims(const std::vector &input_tensor); - Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector &data_nodes, - vector &getnext_nosink_nodes, vector &getnext_sink_nodes); void ParseInputsDimsForData(const std::vector &input_tensor); Status ParseInputsDimsForGetNexNosinkAndData(const vector &dynamic_nodes, const std::vector &input_tensor); diff --git a/ge/graph/preprocess/multi_batch_options.cc b/ge/graph/preprocess/multi_batch_options.cc index f33c2983..c26b08bc 100644 --- a/ge/graph/preprocess/multi_batch_options.cc +++ b/ge/graph/preprocess/multi_batch_options.cc @@ -46,6 +46,8 @@ const int kDivisionConst = 2; const char *const kSubstrOfGetNextNosinkName = "IteratorGetNext"; const char *const kShapeDataName = "ascend_mbatch_shape_data"; const char *const kGetNextName = "IteratorV2"; +const char *const kExtAttrDataNodes = "data_nodes"; +const char *const kExtAttrGetNextNoSink = "getnext_no_sink"; inline bool IsGetNextType(const NodePtr &node) { std::string original_type; @@ -97,6 +99,9 @@ Status DistinguishGetNextAndData(ComputeGraphPtr &graph, vector &data_n } GELOGI("Data count is %zu, getnext nosink count is %zu, getnext sink count is %zu.", data_nodes.size(), getnext_nosink_nodes.size(), getnext_sink_nodes.size()); + GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrDataNodes, data_nodes), GELOGW("Set data nodes attr failed.");) + GE_IF_BOOL_EXEC(!graph->SetExtAttr(kExtAttrGetNextNoSink, getnext_nosink_nodes), + GELOGW("Set getnext nosink nodes attr failed.");) return SUCCESS; }