/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "buffer_pool_graph_builder.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" namespace ge { namespace ut { BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) { graph_name_ = name; } BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name) { graph_ = std::make_shared(name); EXPECT_NE(graph_, nullptr); } NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, DataType data_type, std::vector shape) { auto tensor_desc = std::make_shared(); EXPECT_NE(tensor_desc, nullptr); tensor_desc->SetShape(GeShape(std::move(shape))); tensor_desc->SetFormat(format); tensor_desc->SetDataType(data_type); auto op_desc = std::make_shared(name, type); EXPECT_NE(op_desc, nullptr); for (int i = 0; i < in_cnt; ++i) { op_desc->AddInputDesc(tensor_desc->Clone()); } for (int i = 0; i < out_cnt; ++i) { op_desc->AddOutputDesc(tensor_desc->Clone()); } return graph_->AddNode(op_desc); } void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { EXPECT_NE(src_node, nullptr); EXPECT_NE(dst_node, nullptr); GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); } void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { EXPECT_NE(src_node, nullptr); EXPECT_NE(dst_node, nullptr); GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); } void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::string &batch_label) { EXPECT_NE(node, nullptr); (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id); (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size); if (!batch_label.empty()) { (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); } } void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) { EXPECT_NE(node, nullptr); (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label); } void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector &mem_size) { EXPECT_NE(node, nullptr); EXPECT_NE(node->GetOpDesc(), nullptr); size_t output_size = node->GetOpDesc()->GetOutputsSize(); EXPECT_EQ(output_size, mem_size.size()); for (size_t i = 0; i < output_size; ++i) { auto output_op_desc = node->GetOpDesc()->MutableOutputDesc(i); ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]); } } void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector &ws_bytes) { EXPECT_NE(node, nullptr); EXPECT_NE(node->GetOpDesc(), nullptr); node->GetOpDesc()->SetWorkspaceBytes(ws_bytes); } void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::vector &mem_size, const std::vector &ws_bytes, const std::string &batch_label) { SetBufferPool(node, pool_id, pool_size, batch_label); SetOutputMemSize(node, mem_size); SetWorkSpaceMemSize(node, ws_bytes); } /// /// Normal graph /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// \ \ \ \ \. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output /// /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto add4 = builder.AddNode("add4", ADD, 2, 1); auto add5 = builder.AddNode("add5", ADD, 2, 1); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(add1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(add2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(add3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(add4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add5, 0, net_output, 0); auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// Normal graph with multi buffer pool /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// (pool0) (pool1) (pool0) (pool0) (pool1) /// \ \ \ \ \. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output /// /// /// Memory distribution: /// /// |___w1__|__w3__|_________| /// |_____w4_____|___________| /// /// |___w2__|_____w5___|_____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id_0 = 0; const int64_t buffer_pool_id_1 = 1; const int64_t buffer_pool_size = 5000; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id_0, buffer_pool_size, {500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id_1, buffer_pool_size, {500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id_0, buffer_pool_size, {500}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id_0, buffer_pool_size, {1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id_1, buffer_pool_size, {1024}); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto add4 = builder.AddNode("add4", ADD, 2, 1); auto add5 = builder.AddNode("add5", ADD, 2, 1); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(add1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(add2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(add3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(add4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add5, 0, net_output, 0); auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// SerialGraph: Buffer pool size only can contain one prefetch node /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// \ \ \ \ \. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output /// /// /// Memory distribution: /// /// |____w1_____|__| /// /// |____w2_____|__| /// /// |____w3_____|__| /// /// |______w4______| /// /// |______w5______| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 2048; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto add4 = builder.AddNode("add4", ADD, 2, 1); auto add5 = builder.AddNode("add5", ADD, 2, 1); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(add1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(add2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(add3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(add4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add5, 0, net_output, 0); auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// GraphWithMultiPrefetch: Calc node with more prefetch node /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1 /// \ / \ / \ / /// \ / \ / \ / /// \ / \ / \ / /// add1 ------ c ------- add2 ----- c ----- add3 /// | | | /// | | | /// --------------- net_output ------------ /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 3, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(prefetch1, 0, add1, 0); builder.AddDataEdge(prefetch2, 0, add1, 1); builder.AddDataEdge(prefetch3, 0, add2, 0); builder.AddDataEdge(prefetch4, 0, add2, 1); builder.AddDataEdge(const1, 0, add3, 0); builder.AddDataEdge(prefetch5, 0, add3, 1); builder.AddDataEdge(add1, 0, net_output, 0); builder.AddDataEdge(add2, 0, net_output, 1); builder.AddDataEdge(add3, 0, net_output, 2); builder.AddControlEdge(add1, add2); builder.AddControlEdge(add2, add3); auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// GraphWithSubgraph: Calc node in different subgraph /// /// /// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output /// /// /// Subgraph1: Subgraph2: /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// \ \ \ \ \. /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out /// /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() { auto builder = InnerGraphBuilder(graph_name_); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; // Subgraph1 auto subgraph_builder1 = InnerGraphBuilder("Subgraph1"); auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1); auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1); auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1); auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0); auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1); auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1); auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1); auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1); subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0); subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0); subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0); subgraph_builder1.AddDataEdge(const1, 0, add1, 0); subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1); subgraph_builder1.AddDataEdge(add1, 0, add2, 0); subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1); subgraph_builder1.AddDataEdge(add2, 0, add3, 0); subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1); subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0); auto subgraph1 = subgraph_builder1.GetGraph(); for (auto &node : subgraph1->GetDirectNode()) { node->SetOwnerComputeGraph(subgraph1); } // Subgraph2 auto subgraph_builder2 = InnerGraphBuilder("Subgraph2"); auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1); auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1); auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1); auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1); auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1); auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1); subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0); subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0); subgraph_builder2.AddDataEdge(data1, 0, add4, 0); subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1); subgraph_builder2.AddDataEdge(add4, 0, add5, 0); subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1); subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0); auto subgraph2 = subgraph_builder2.GetGraph(); for (auto &node : subgraph2->GetDirectNode()) { node->SetOwnerComputeGraph(subgraph2); } // root graph auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1); auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0); auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(call_node1, 0, call_node2, 0); builder.AddDataEdge(call_node2, 0, net_output, 0); auto compute_graph = builder.GetGraph(); call_node1->SetOwnerComputeGraph(compute_graph); call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName()); call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName()); call_node2->SetOwnerComputeGraph(compute_graph); call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName()); call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName()); subgraph1->SetParentNode(call_node1); subgraph1->SetParentGraph(compute_graph); subgraph2->SetParentNode(call_node2); subgraph2->SetParentGraph(compute_graph); compute_graph->AddSubGraph(subgraph1); compute_graph->AddSubGraph(subgraph2); return compute_graph; } /// /// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency /// /// /// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output /// /// /// Subgraph1: Subgraph2: /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// \ \ \ \ \. /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out /// /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() { auto builder = InnerGraphBuilder(graph_name_); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; // Subgraph1 auto subgraph_builder1 = InnerGraphBuilder("Subgraph1"); auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1); auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1); auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0); auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1); auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1); auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1); subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0); subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0); subgraph_builder1.AddDataEdge(const1, 0, add1, 0); subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1); subgraph_builder1.AddDataEdge(add1, 0, add2, 0); subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1); subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0); auto subgraph1 = subgraph_builder1.GetGraph(); for (auto &node : subgraph1->GetDirectNode()) { node->SetOwnerComputeGraph(subgraph1); } // Subgraph2 auto subgraph_builder2 = InnerGraphBuilder("Subgraph2"); auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1); auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1); auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1); auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1); auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1); auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1); auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1); auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1); subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0); subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0); subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0); subgraph_builder2.AddDataEdge(data1, 0, add3, 0); subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1); subgraph_builder2.AddDataEdge(add3, 0, add4, 0); subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1); subgraph_builder2.AddDataEdge(add4, 0, add5, 0); subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1); subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0); auto subgraph2 = subgraph_builder2.GetGraph(); for (auto &node : subgraph2->GetDirectNode()) { node->SetOwnerComputeGraph(subgraph2); } // root graph auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1); auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0); auto net_output = subgraph_builder2.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(call_node1, 0, call_node2, 0); builder.AddDataEdge(call_node2, 0, net_output, 0); auto compute_graph = builder.GetGraph(); call_node1->SetOwnerComputeGraph(compute_graph); call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName()); call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName()); call_node2->SetOwnerComputeGraph(compute_graph); call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName()); call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName()); subgraph1->SetParentNode(call_node1); subgraph1->SetParentGraph(compute_graph); subgraph2->SetParentNode(call_node2); subgraph2->SetParentGraph(compute_graph); compute_graph->AddSubGraph(subgraph1); compute_graph->AddSubGraph(subgraph2); return compute_graph; } /// /// BuildGraphWithMultiBatch: Different batch label /// /// /// batch_label_128 /// /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- /// / / / / / / \. /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \. /// const1 switch_false / / / / / \. /// \ / / / / / / \. /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output /// / \ \ \ \ \ \ / /// const2 switch_true \ \ \ \ \ / /// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 / /// \ \ \ \ \ \ / /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 --- /// /// batch_label_256 /// /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1); auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2); auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1); auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1); auto merge1 = builder.AddNode("merge1", MERGE, 2, 2); auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0); builder.AddDataEdge(const1, 0, switch1, 0); builder.AddDataEdge(const2, 0, switch1, 1); builder.AddDataEdge(switch1, 0, switch_false, 0); builder.AddDataEdge(switch1, 1, switch_true, 0); builder.AddDataEdge(merge1, 0, net_output, 0); std::string batch_label_128 = "batch_128"; std::string batch_label_256 = "batch_256"; const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; { auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128); auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128); auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1); SetBatchLabel(add1, batch_label_128); auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1); SetBatchLabel(add2, batch_label_128); auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1); SetBatchLabel(add3, batch_label_128); auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1); SetBatchLabel(add4, batch_label_128); auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1); SetBatchLabel(add5, batch_label_128); auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1); SetBatchLabel(const1, batch_label_128); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(add1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(add2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(add3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(add4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add5, 0, merge1, 0); builder.AddControlEdge(switch_false, const1); } { auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256); auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256); auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1); SetBatchLabel(add1, batch_label_256); auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1); SetBatchLabel(add2, batch_label_256); auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1); SetBatchLabel(add3, batch_label_256); auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1); SetBatchLabel(add4, batch_label_256); auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1); SetBatchLabel(add5, batch_label_256); auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1); SetBatchLabel(const1, batch_label_128); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(add1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(add2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(add3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(add4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add5, 0, merge1, 1); builder.AddControlEdge(switch_true, const1); } auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// GraphWithMultiOutputPrefetch: Prefetch has more than one output /// /// w1 w2 w3 w4 w5 /// \ \ \ \ \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// / \ / \ / \ / \ / /// / \ / \ / \ / \ / /// const1 ----- add1 add2 add3 add4 add5 /// | \ | / | /// | \ | / | /// | \ | / | /// | \ | / | /// -------------- net_output --------------- /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto add4 = builder.AddNode("add4", ADD, 2, 1); auto add5 = builder.AddNode("add5", ADD, 2, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(prefetch1, 0, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(prefetch2, 0, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(prefetch3, 0, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(prefetch4, 0, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add1, 0, net_output, 0); builder.AddDataEdge(add2, 0, net_output, 1); builder.AddDataEdge(add3, 0, net_output, 2); builder.AddDataEdge(add4, 0, net_output, 3); builder.AddDataEdge(add5, 0, net_output, 4); auto compute_graph = builder.GetGraph(); return compute_graph; } /// /// GraphWithMultiOutputPrefetch: Prefetch has more than one output /// /// w1 w2 w3 w4 w5 /// \ / \ / \ / \ / \. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /// / \ / \ / \ / \ / /// / \ / \ / \ / \ / /// const1 ----- add1 add2 add3 add4 add5 /// | \ | / | /// | \ | / | /// | \ | / | /// | \ | / | /// -------------- net_output --------------- /// /// Memory distribution: /// /// |___w1__|__w2__|__w3__|__| /// /// |_____w4_____|_____w5____| /// ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() { auto builder = InnerGraphBuilder(graph_name_); auto w1 = builder.AddNode("w1", VARIABLE, 0, 1); auto w2 = builder.AddNode("w2", VARIABLE, 0, 1); auto w3 = builder.AddNode("w3", VARIABLE, 0, 1); auto w4 = builder.AddNode("w4", VARIABLE, 0, 1); auto w5 = builder.AddNode("w5", VARIABLE, 0, 1); const int64_t buffer_pool_id = 0; const int64_t buffer_pool_size = 5600; auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 2, 2); SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500, 500}); auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 2, 2); SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500, 500}); auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 2, 2); SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500, 1024}); auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 2, 2); SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024, 1024}); auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1); SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}); auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1); auto add1 = builder.AddNode("add1", ADD, 2, 1); auto add2 = builder.AddNode("add2", ADD, 2, 1); auto add3 = builder.AddNode("add3", ADD, 2, 1); auto add4 = builder.AddNode("add4", ADD, 2, 1); auto add5 = builder.AddNode("add5", ADD, 2, 1); auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0); builder.AddDataEdge(w1, 0, prefetch1, 0); builder.AddDataEdge(w2, 0, prefetch1, 1); builder.AddDataEdge(w2, 0, prefetch2, 0); builder.AddDataEdge(w3, 0, prefetch2, 1); builder.AddDataEdge(w3, 0, prefetch3, 0); builder.AddDataEdge(w4, 0, prefetch3, 1); builder.AddDataEdge(w4, 0, prefetch4, 0); builder.AddDataEdge(w5, 0, prefetch4, 1); builder.AddDataEdge(w5, 0, prefetch5, 0); builder.AddDataEdge(const1, 0, add1, 0); builder.AddDataEdge(prefetch1, 0, add1, 1); builder.AddDataEdge(prefetch1, 1, add2, 0); builder.AddDataEdge(prefetch2, 0, add2, 1); builder.AddDataEdge(prefetch2, 1, add3, 0); builder.AddDataEdge(prefetch3, 0, add3, 1); builder.AddDataEdge(prefetch3, 1, add4, 0); builder.AddDataEdge(prefetch4, 0, add4, 1); builder.AddDataEdge(prefetch4, 1, add5, 0); builder.AddDataEdge(prefetch5, 0, add5, 1); builder.AddDataEdge(add1, 0, net_output, 0); builder.AddDataEdge(add2, 0, net_output, 1); builder.AddDataEdge(add3, 0, net_output, 2); builder.AddDataEdge(add4, 0, net_output, 3); builder.AddDataEdge(add5, 0, net_output, 4); auto compute_graph = builder.GetGraph(); return compute_graph; } } // namespace ut } // namespace ge