|
|
@@ -32,16 +32,17 @@ namespace { |
|
|
|
bool HasOneNonDataNode(const ComputeGraphPtr &graph) { |
|
|
|
GE_CHECK_NOTNULL(graph); |
|
|
|
int32_t non_data_nums = 0; |
|
|
|
for (const auto& n : graph->GetDirectNode()) { |
|
|
|
if (n->GetType() != parser::DATA) { |
|
|
|
for (const auto& node : graph->GetDirectNode()) { |
|
|
|
if (node->GetType() != parser::DATA) { |
|
|
|
non_data_nums++; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("graph has non data node num is %d", non_data_nums); |
|
|
|
GELOGD("Graph has non data node num is %d", non_data_nums); |
|
|
|
return (non_data_nums == 1); |
|
|
|
} |
|
|
|
Status HandleNewOp(const NodePtr &node, |
|
|
|
const ComputeGraphPtr &compute_graph, |
|
|
|
const ComputeGraphPtr &sub_compute_graph, |
|
|
|
const NodePtr &new_node, |
|
|
|
bool no_need_change_name) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
@@ -60,35 +61,43 @@ Status HandleNewOp(const NodePtr &node, |
|
|
|
new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); |
|
|
|
} |
|
|
|
op_desc->SetName(new_name); |
|
|
|
bool ret = ge::AttrUtils::SetListStr(op_desc, |
|
|
|
ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, |
|
|
|
std::move(std::vector<std::string>{node->GetName()})); |
|
|
|
if (!ret) { |
|
|
|
std::vector<std::string> node_name_vec = { node->GetName() }; |
|
|
|
if (!ge::AttrUtils::SetListStr(op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, |
|
|
|
std::move(node_name_vec))) { |
|
|
|
GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); |
|
|
|
} |
|
|
|
GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
// handle control op |
|
|
|
const auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); |
|
|
|
for (size_t i = 0UL; i < sub_graph_names.size(); i++) { |
|
|
|
auto branch_graph = sub_compute_graph->GetSubgraph(sub_graph_names[i]); |
|
|
|
GE_CHECK_NOTNULL(branch_graph); |
|
|
|
branch_graph->SetParentNode(new_node); |
|
|
|
branch_graph->SetParentGraph(compute_graph); |
|
|
|
compute_graph->AddSubGraph(branch_graph); |
|
|
|
} |
|
|
|
GELOGD("Handle new node[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &output_mapping) { |
|
|
|
GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); |
|
|
|
for (const auto &gn : graph.GetDirectNode()) { |
|
|
|
NodePtr n = NodeAdapter::GNode2Node(gn); |
|
|
|
GE_CHECK_NOTNULL(n); |
|
|
|
GELOGD("Begin to run ParserUtils::ExpandOneToManyGraph."); |
|
|
|
for (const auto &ge_node : graph.GetDirectNode()) { |
|
|
|
NodePtr node = NodeAdapter::GNode2Node(ge_node); |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
std::string ori_type; |
|
|
|
(void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); |
|
|
|
(void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); |
|
|
|
domi::ParseOpToGraphFunc parse_op_to_graph_func = |
|
|
|
domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type); |
|
|
|
domi::OpRegistry::Instance()->GetParseOpToGraphFunc(node->GetType(), ori_type); |
|
|
|
if (parse_op_to_graph_func == nullptr) { |
|
|
|
GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.", |
|
|
|
n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); |
|
|
|
node->GetName().c_str(), node->GetType().c_str(), ori_type.c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.", |
|
|
|
n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); |
|
|
|
node->GetName().c_str(), node->GetType().c_str(), ori_type.c_str()); |
|
|
|
Graph subgraph("one_to_many_graph"); |
|
|
|
Operator op = OpDescUtils::CreateOperatorFromNode(n); |
|
|
|
Operator op = OpDescUtils::CreateOperatorFromNode(node); |
|
|
|
Status ret = parse_op_to_graph_func(op, subgraph); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get one to many graph failed for op:%s.", GetOperatorName(op).c_str()); |
|
|
@@ -96,14 +105,14 @@ Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &outp |
|
|
|
GetOperatorName(op).c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping); |
|
|
|
ret = ExpandNodeToSubgraph(subgraph, node, graph, output_mapping); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", |
|
|
|
GetOperatorName(op).c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("run ParserUtils::ExpandOneToManyGraph success."); |
|
|
|
GELOGD("Run ParserUtils::ExpandOneToManyGraph success."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
@@ -117,10 +126,10 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n |
|
|
|
// add subgraph node to graph. |
|
|
|
bool no_need_change_name = HasOneNonDataNode(sub_compute_graph); |
|
|
|
std::vector<NodePtr> input_nodes; |
|
|
|
for (const auto &n : sub_compute_graph->GetDirectNode()) { |
|
|
|
auto new_node = compute_graph->AddNode(n); |
|
|
|
for (const auto &sub_node : sub_compute_graph->GetDirectNode()) { |
|
|
|
auto new_node = compute_graph->AddNode(sub_node); |
|
|
|
GE_CHECK_NOTNULL(new_node); |
|
|
|
if (HandleNewOp(node, compute_graph, new_node, no_need_change_name) != SUCCESS) { |
|
|
|
if (HandleNewOp(node, compute_graph, sub_compute_graph, new_node, no_need_change_name) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Handle][NewOp][%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
@@ -240,7 +249,7 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, |
|
|
|
const std::vector<std::pair<NodePtr, int32_t>> &out_node_index, |
|
|
|
OutputMapping &output_mapping) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
GELOGD("The size of out node is %zu", out_node_index.size()); |
|
|
|
GELOGD("The size of output node is %zu", out_node_index.size()); |
|
|
|
for (size_t index = 0; index < out_node_index.size(); index++) { |
|
|
|
auto node_out_anchor = node->GetOutDataAnchor(index); |
|
|
|
if (node_out_anchor == nullptr) { |
|
|
@@ -249,7 +258,7 @@ Status ParserUtils::HandleOutputContext(const NodePtr &node, |
|
|
|
|
|
|
|
NodePtr out_node = out_node_index[index].first; |
|
|
|
int32_t out_index = out_node_index[index].second; |
|
|
|
GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); |
|
|
|
GELOGD("Begin to handle output node: %s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index); |
|
|
|
std::string key = GenOutputKey({node->GetName(), index}); |
|
|
|
output_mapping[key] = std::make_pair(out_node->GetName(), out_index); |
|
|
|
auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. |
|
|
|