|
|
@@ -48,26 +48,42 @@ void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Grap |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { |
|
|
|
bool StreamGraphOptimizer::IsSameStreamIdOrBatchLabel(const ComputeGraphPtr &comp_graph) { |
|
|
|
if (comp_graph == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::set<int64_t> stream_set; |
|
|
|
std::set<std::string> label_set; |
|
|
|
for (const ge::NodePtr &cur_node : comp_graph->GetDirectNode()) { |
|
|
|
GE_IF_BOOL_EXEC(cur_node->GetOpDesc() == nullptr, continue); |
|
|
|
int64_t stream_id = cur_node->GetOpDesc()->GetStreamId(); |
|
|
|
if (stream_id == kInvalidStream) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
GELOGD("Node %s in subgraph %s stream id is: %ld, node num: %zu", cur_node->GetName().c_str(), |
|
|
|
stream_set.insert(stream_id); |
|
|
|
|
|
|
|
std::string batch_label; |
|
|
|
if (AttrUtils::GetStr(cur_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label)) { |
|
|
|
label_set.insert(batch_label); |
|
|
|
} else { |
|
|
|
GELOGD("Node %s[%s] has no batch label, subgraph %s, stream id: %ld", cur_node->GetName().c_str(), |
|
|
|
cur_node->GetType().c_str(), comp_graph->GetName().c_str(), stream_id); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
GELOGD("Node %s in subgraph %s stream id: %ld, node num: %zu", cur_node->GetName().c_str(), |
|
|
|
comp_graph->GetName().c_str(), stream_id, comp_graph->GetDirectNodesSize()); |
|
|
|
stream_set.insert(stream_id); |
|
|
|
} |
|
|
|
if (stream_set.size() > 1) { |
|
|
|
GELOGI("Nodes of graph: %s have different stream id, node num: %zu, different stream num: %zu.", |
|
|
|
if (stream_set.size() > 1 || label_set.size() > 1) { |
|
|
|
GELOGI("Nodes of graph: %s have different stream id or batch_label, node num: %zu, different stream num: %zu.", |
|
|
|
comp_graph->GetName().c_str(), comp_graph->GetDirectNodesSize(), stream_set.size()); |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (!label_set.empty()) { |
|
|
|
(void)AttrUtils::SetStr(comp_graph, ATTR_NAME_BATCH_LABEL, *label_set.begin()); |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
@@ -99,8 +115,8 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!IsSameStreamId(subgraph)) { |
|
|
|
GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); |
|
|
|
if (!IsSameStreamIdOrBatchLabel(subgraph)) { |
|
|
|
GELOGI("There are more than one stream or batch_label in subgraph %s", subgraph->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); |
|
|
@@ -112,9 +128,11 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
run_context.stream = run_context.graphStreamList[stream_id]; |
|
|
|
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", |
|
|
|
subgraph->GetName().c_str(), engine_name.c_str(), stream_id, |
|
|
|
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream))); |
|
|
|
std::string batch_label; |
|
|
|
(void)AttrUtils::GetStr(subgraph, ATTR_NAME_BATCH_LABEL, batch_label); |
|
|
|
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu, " |
|
|
|
"batch_label: %s", subgraph->GetName().c_str(), engine_name.c_str(), stream_id, |
|
|
|
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream)), batch_label.c_str()); |
|
|
|
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { |
|
|
|
GE_CHECK_NOTNULL(*iter); |
|
|
|
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); |
|
|
|