|
|
@@ -29,7 +29,21 @@ |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) { |
|
|
|
bool HasOneNonDataNode(const ComputeGraphPtr &graph) { |
|
|
|
GE_CHECK_NOTNULL(graph); |
|
|
|
int32_t non_data_nums = 0; |
|
|
|
for (const auto& n : graph->GetDirectNode()) { |
|
|
|
if (n->GetType() != parser::DATA) { |
|
|
|
non_data_nums++; |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("graph has non data node num is %d", non_data_nums); |
|
|
|
return (non_data_nums == 1); |
|
|
|
} |
|
|
|
Status HandleNewOp(const NodePtr &node, |
|
|
|
const ComputeGraphPtr &compute_graph, |
|
|
|
const NodePtr &new_node, |
|
|
|
bool no_need_change_name) { |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
GE_CHECK_NOTNULL(new_node); |
|
|
|
if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { |
|
|
@@ -37,8 +51,13 @@ Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, co |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto op_desc = new_node->GetOpDesc(); |
|
|
|
static std::atomic_long new_node_index(0); |
|
|
|
auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); |
|
|
|
string new_name; |
|
|
|
if (no_need_change_name) { |
|
|
|
new_name = node->GetName(); |
|
|
|
} else { |
|
|
|
static std::atomic_long new_node_index(0); |
|
|
|
new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); |
|
|
|
} |
|
|
|
op_desc->SetName(new_name); |
|
|
|
bool ret = ge::AttrUtils::SetListStr(op_desc, |
|
|
|
ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, |
|
|
@@ -91,11 +110,12 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n |
|
|
|
GE_CHECK_NOTNULL(compute_graph); |
|
|
|
|
|
|
|
// add subgraph node to graph. |
|
|
|
bool no_need_change_name = HasOneNonDataNode(sub_compute_graph); |
|
|
|
std::vector<NodePtr> input_nodes; |
|
|
|
for (const auto &n : sub_compute_graph->GetDirectNode()) { |
|
|
|
auto new_node = compute_graph->AddNode(n); |
|
|
|
GE_CHECK_NOTNULL(new_node); |
|
|
|
if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) { |
|
|
|
if (HandleNewOp(node, compute_graph, new_node, no_need_change_name) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|