|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978 |
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include <gtest/gtest.h>
- #include "buffer_pool_graph_builder.h"
- #include "common/ge_inner_error_codes.h"
- #include "common/types.h"
- #include "graph/debug/ge_attr_define.h"
- #include "graph/utils/attr_utils.h"
- #include "graph/utils/graph_utils.h"
- #include "graph/utils/tensor_utils.h"
- #include "graph/utils/graph_utils.h"
-
- namespace ge {
- namespace ut {
- BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) {
- graph_name_ = name;
- }
-
- BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name) {
- graph_ = std::make_shared<ComputeGraph>(name);
- EXPECT_NE(graph_, nullptr);
- }
-
- NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type,
- int in_cnt, int out_cnt,
- Format format, DataType data_type,
- std::vector<int64_t> shape) {
- auto tensor_desc = std::make_shared<GeTensorDesc>();
- EXPECT_NE(tensor_desc, nullptr);
- tensor_desc->SetShape(GeShape(std::move(shape)));
- tensor_desc->SetFormat(format);
- tensor_desc->SetDataType(data_type);
- auto op_desc = std::make_shared<OpDesc>(name, type);
- EXPECT_NE(op_desc, nullptr);
- for (int i = 0; i < in_cnt; ++i) {
- op_desc->AddInputDesc(tensor_desc->Clone());
- }
- for (int i = 0; i < out_cnt; ++i) {
- op_desc->AddOutputDesc(tensor_desc->Clone());
- }
- return graph_->AddNode(op_desc);
- }
-
- void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx,
- NodePtr &dst_node, int dst_idx) {
- EXPECT_NE(src_node, nullptr);
- EXPECT_NE(dst_node, nullptr);
- GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
- }
-
- void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
- EXPECT_NE(src_node, nullptr);
- EXPECT_NE(dst_node, nullptr);
- GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
- }
-
- void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size,
- const std::string &batch_label) {
- EXPECT_NE(node, nullptr);
- (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id);
- (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size);
- if (!batch_label.empty()) {
- (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
- }
- }
-
- void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) {
- EXPECT_NE(node, nullptr);
- (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
-
- }
-
- void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size) {
- EXPECT_NE(node, nullptr);
- EXPECT_NE(node->GetOpDesc(), nullptr);
- size_t output_size = node->GetOpDesc()->GetOutputsSize();
- EXPECT_EQ(output_size, mem_size.size());
- for (size_t i = 0; i < output_size; ++i) {
- auto output_op_desc = node->GetOpDesc()->MutableOutputDesc(i);
- ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]);
- }
- }
-
- void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes) {
- EXPECT_NE(node, nullptr);
- EXPECT_NE(node->GetOpDesc(), nullptr);
- node->GetOpDesc()->SetWorkspaceBytes(ws_bytes);
- }
-
- void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
- const std::vector<int64_t> &mem_size,
- const std::vector<int64_t> &ws_bytes,
- const std::string &batch_label) {
- SetBufferPool(node, pool_id, pool_size, batch_label);
- SetOutputMemSize(node, mem_size);
- SetWorkSpaceMemSize(node, ws_bytes);
- }
-
- ///
- /// Normal graph
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// \ \ \ \ \.
- /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
- ///
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() {
- auto builder = InnerGraphBuilder(graph_name_);
- auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 5600;
-
- auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
- auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
-
- auto add1 = builder.AddNode("add1", ADD, 2, 1);
- auto add2 = builder.AddNode("add2", ADD, 2, 1);
- auto add3 = builder.AddNode("add3", ADD, 2, 1);
- auto add4 = builder.AddNode("add4", ADD, 2, 1);
- auto add5 = builder.AddNode("add5", ADD, 2, 1);
- auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(const1, 0, add1, 0);
- builder.AddDataEdge(prefetch1, 0, add1, 1);
-
- builder.AddDataEdge(add1, 0, add2, 0);
- builder.AddDataEdge(prefetch2, 0, add2, 1);
-
- builder.AddDataEdge(add2, 0, add3, 0);
- builder.AddDataEdge(prefetch3, 0, add3, 1);
-
- builder.AddDataEdge(add3, 0, add4, 0);
- builder.AddDataEdge(prefetch4, 0, add4, 1);
-
- builder.AddDataEdge(add4, 0, add5, 0);
- builder.AddDataEdge(prefetch5, 0, add5, 1);
-
- builder.AddDataEdge(add5, 0, net_output, 0);
-
- auto compute_graph = builder.GetGraph();
-
- return compute_graph;
- }
-
- ///
- /// Normal graph with multi buffer pool
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// (pool0) (pool1) (pool0) (pool0) (pool1)
- /// \ \ \ \ \.
- /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
- ///
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w3__|_________|
- /// |_____w4_____|___________|
- ///
- /// |___w2__|_____w5___|_____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() {
- auto builder = InnerGraphBuilder(graph_name_);
- auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
- const int64_t buffer_pool_id_0 = 0;
- const int64_t buffer_pool_id_1 = 1;
- const int64_t buffer_pool_size = 5000;
-
- auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id_0, buffer_pool_size, {500});
- auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id_1, buffer_pool_size, {500});
- auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id_0, buffer_pool_size, {500});
- auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id_0, buffer_pool_size, {1024});
- auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id_1, buffer_pool_size, {1024});
-
- auto add1 = builder.AddNode("add1", ADD, 2, 1);
- auto add2 = builder.AddNode("add2", ADD, 2, 1);
- auto add3 = builder.AddNode("add3", ADD, 2, 1);
- auto add4 = builder.AddNode("add4", ADD, 2, 1);
- auto add5 = builder.AddNode("add5", ADD, 2, 1);
- auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(const1, 0, add1, 0);
- builder.AddDataEdge(prefetch1, 0, add1, 1);
-
- builder.AddDataEdge(add1, 0, add2, 0);
- builder.AddDataEdge(prefetch2, 0, add2, 1);
-
- builder.AddDataEdge(add2, 0, add3, 0);
- builder.AddDataEdge(prefetch3, 0, add3, 1);
-
- builder.AddDataEdge(add3, 0, add4, 0);
- builder.AddDataEdge(prefetch4, 0, add4, 1);
-
- builder.AddDataEdge(add4, 0, add5, 0);
- builder.AddDataEdge(prefetch5, 0, add5, 1);
-
- builder.AddDataEdge(add5, 0, net_output, 0);
-
- auto compute_graph = builder.GetGraph();
-
- return compute_graph;
- }
-
- ///
- /// SerialGraph: Buffer pool size only can contain one prefetch node
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// \ \ \ \ \.
- /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
- ///
- ///
- /// Memory distribution:
- ///
- /// |____w1_____|__|
- ///
- /// |____w2_____|__|
- ///
- /// |____w3_____|__|
- ///
- /// |______w4______|
- ///
- /// |______w5______|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() {
- auto builder = InnerGraphBuilder(graph_name_);
- auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 2048;
-
- auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
- auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
-
- auto add1 = builder.AddNode("add1", ADD, 2, 1);
- auto add2 = builder.AddNode("add2", ADD, 2, 1);
- auto add3 = builder.AddNode("add3", ADD, 2, 1);
- auto add4 = builder.AddNode("add4", ADD, 2, 1);
- auto add5 = builder.AddNode("add5", ADD, 2, 1);
- auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(const1, 0, add1, 0);
- builder.AddDataEdge(prefetch1, 0, add1, 1);
-
- builder.AddDataEdge(add1, 0, add2, 0);
- builder.AddDataEdge(prefetch2, 0, add2, 1);
-
- builder.AddDataEdge(add2, 0, add3, 0);
- builder.AddDataEdge(prefetch3, 0, add3, 1);
-
- builder.AddDataEdge(add3, 0, add4, 0);
- builder.AddDataEdge(prefetch4, 0, add4, 1);
-
- builder.AddDataEdge(add4, 0, add5, 0);
- builder.AddDataEdge(prefetch5, 0, add5, 1);
-
- builder.AddDataEdge(add5, 0, net_output, 0);
-
- auto compute_graph = builder.GetGraph();
-
- return compute_graph;
- }
-
- ///
- /// GraphWithMultiPrefetch: Calc node with more prefetch node
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
- /// \ / \ / \ /
- /// \ / \ / \ /
- /// \ / \ / \ /
- /// add1 ------ c ------- add2 ----- c ----- add3
- /// | | |
- /// | | |
- /// --------------- net_output ------------
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() {
- auto builder = InnerGraphBuilder(graph_name_);
- auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 5600;
-
- auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
- auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
-
- auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
- auto add1 = builder.AddNode("add1", ADD, 2, 1);
- auto add2 = builder.AddNode("add2", ADD, 2, 1);
- auto add3 = builder.AddNode("add3", ADD, 2, 1);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 3, 0);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(prefetch1, 0, add1, 0);
- builder.AddDataEdge(prefetch2, 0, add1, 1);
-
- builder.AddDataEdge(prefetch3, 0, add2, 0);
- builder.AddDataEdge(prefetch4, 0, add2, 1);
-
- builder.AddDataEdge(const1, 0, add3, 0);
- builder.AddDataEdge(prefetch5, 0, add3, 1);
-
- builder.AddDataEdge(add1, 0, net_output, 0);
- builder.AddDataEdge(add2, 0, net_output, 1);
- builder.AddDataEdge(add3, 0, net_output, 2);
-
- builder.AddControlEdge(add1, add2);
- builder.AddControlEdge(add2, add3);
-
- auto compute_graph = builder.GetGraph();
-
- return compute_graph;
- }
-
- ///
- /// GraphWithSubgraph: Calc node in different subgraph
- ///
- ///
- /// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
- ///
- ///
- /// Subgraph1: Subgraph2:
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// \ \ \ \ \.
- /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
- ///
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() {
- auto builder = InnerGraphBuilder(graph_name_);
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 5600;
-
- // Subgraph1
- auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
- auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1);
-
- auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
- auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
- auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
-
- auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
- auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
- auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1);
-
- subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
- subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
- subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0);
- subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
- subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
- subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
- subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
- subgraph_builder1.AddDataEdge(add2, 0, add3, 0);
- subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1);
- subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0);
- auto subgraph1 = subgraph_builder1.GetGraph();
- for (auto &node : subgraph1->GetDirectNode()) {
- node->SetOwnerComputeGraph(subgraph1);
- }
-
- // Subgraph2
- auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
- auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
-
- auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
- auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
-
- auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
- auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
- auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
- auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
-
- subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
- subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
- subgraph_builder2.AddDataEdge(data1, 0, add4, 0);
- subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
- subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
- subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
- subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
-
- auto subgraph2 = subgraph_builder2.GetGraph();
- for (auto &node : subgraph2->GetDirectNode()) {
- node->SetOwnerComputeGraph(subgraph2);
- }
-
- // root graph
- auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
- auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
- builder.AddDataEdge(call_node1, 0, call_node2, 0);
- builder.AddDataEdge(call_node2, 0, net_output, 0);
- auto compute_graph = builder.GetGraph();
- call_node1->SetOwnerComputeGraph(compute_graph);
- call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
- call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
- call_node2->SetOwnerComputeGraph(compute_graph);
- call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
- call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
-
- subgraph1->SetParentNode(call_node1);
- subgraph1->SetParentGraph(compute_graph);
- subgraph2->SetParentNode(call_node2);
- subgraph2->SetParentGraph(compute_graph);
- compute_graph->AddSubGraph(subgraph1);
- compute_graph->AddSubGraph(subgraph2);
-
- return compute_graph;
- }
-
- ///
- /// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency
- ///
- ///
- /// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
- ///
- ///
- /// Subgraph1: Subgraph2:
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// \ \ \ \ \.
- /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
- ///
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() {
- auto builder = InnerGraphBuilder(graph_name_);
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 5600;
-
- // Subgraph1
- auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
- auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
-
- auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
- auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
- auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
-
- auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
- auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
-
- subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
- subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
- subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
- subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
- subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
- subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
- subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0);
- auto subgraph1 = subgraph_builder1.GetGraph();
- for (auto &node : subgraph1->GetDirectNode()) {
- node->SetOwnerComputeGraph(subgraph1);
- }
-
- // Subgraph2
- auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
- auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
-
- auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
- auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
- auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
-
- auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1);
- auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
- auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
- auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
- auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
-
- subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0);
- subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
- subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
- subgraph_builder2.AddDataEdge(data1, 0, add3, 0);
- subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1);
- subgraph_builder2.AddDataEdge(add3, 0, add4, 0);
- subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
- subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
- subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
- subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
-
- auto subgraph2 = subgraph_builder2.GetGraph();
- for (auto &node : subgraph2->GetDirectNode()) {
- node->SetOwnerComputeGraph(subgraph2);
- }
-
- // root graph
- auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
- auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
- auto net_output = subgraph_builder2.AddNode("net_output", NETOUTPUT, 1, 0);
- builder.AddDataEdge(call_node1, 0, call_node2, 0);
- builder.AddDataEdge(call_node2, 0, net_output, 0);
- auto compute_graph = builder.GetGraph();
- call_node1->SetOwnerComputeGraph(compute_graph);
- call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
- call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
- call_node2->SetOwnerComputeGraph(compute_graph);
- call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
- call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
-
- subgraph1->SetParentNode(call_node1);
- subgraph1->SetParentGraph(compute_graph);
- subgraph2->SetParentNode(call_node2);
- subgraph2->SetParentGraph(compute_graph);
- compute_graph->AddSubGraph(subgraph1);
- compute_graph->AddSubGraph(subgraph2);
-
- return compute_graph;
- }
-
- ///
- /// BuildGraphWithMultiBatch: Different batch label
- ///
- ///
- /// batch_label_128
- ///
- /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
- /// / / / / / / \.
- /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \.
- /// const1 switch_false / / / / / \.
- /// \ / / / / / / \.
- /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
- /// / \ \ \ \ \ \ /
- /// const2 switch_true \ \ \ \ \ /
- /// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /
- /// \ \ \ \ \ \ /
- /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
- ///
- /// batch_label_256
- ///
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() {
- auto builder = InnerGraphBuilder(graph_name_);
- auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
- auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
- auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
- auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
- auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
- auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
- auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1);
- auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2);
- auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1);
- auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1);
- auto merge1 = builder.AddNode("merge1", MERGE, 2, 2);
- auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
-
- builder.AddDataEdge(const1, 0, switch1, 0);
- builder.AddDataEdge(const2, 0, switch1, 1);
- builder.AddDataEdge(switch1, 0, switch_false, 0);
- builder.AddDataEdge(switch1, 1, switch_true, 0);
- builder.AddDataEdge(merge1, 0, net_output, 0);
-
- std::string batch_label_128 = "batch_128";
- std::string batch_label_256 = "batch_256";
-
- const int64_t buffer_pool_id = 0;
- const int64_t buffer_pool_size = 5600;
-
- {
- auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
- auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
- auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
- auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);
- auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);
-
- auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1);
- SetBatchLabel(add1, batch_label_128);
- auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1);
- SetBatchLabel(add2, batch_label_128);
- auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1);
- SetBatchLabel(add3, batch_label_128);
- auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1);
- SetBatchLabel(add4, batch_label_128);
- auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1);
- SetBatchLabel(add5, batch_label_128);
- auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1);
- SetBatchLabel(const1, batch_label_128);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(const1, 0, add1, 0);
- builder.AddDataEdge(prefetch1, 0, add1, 1);
-
- builder.AddDataEdge(add1, 0, add2, 0);
- builder.AddDataEdge(prefetch2, 0, add2, 1);
-
- builder.AddDataEdge(add2, 0, add3, 0);
- builder.AddDataEdge(prefetch3, 0, add3, 1);
-
- builder.AddDataEdge(add3, 0, add4, 0);
- builder.AddDataEdge(prefetch4, 0, add4, 1);
-
- builder.AddDataEdge(add4, 0, add5, 0);
- builder.AddDataEdge(prefetch5, 0, add5, 1);
-
- builder.AddDataEdge(add5, 0, merge1, 0);
- builder.AddControlEdge(switch_false, const1);
- }
-
- {
- auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
- auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
- auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
- auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);
- auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1);
- SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);
-
- auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1);
- SetBatchLabel(add1, batch_label_256);
- auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1);
- SetBatchLabel(add2, batch_label_256);
- auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1);
- SetBatchLabel(add3, batch_label_256);
- auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1);
- SetBatchLabel(add4, batch_label_256);
- auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1);
- SetBatchLabel(add5, batch_label_256);
- auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1);
- SetBatchLabel(const1, batch_label_128);
-
- builder.AddDataEdge(w1, 0, prefetch1, 0);
- builder.AddDataEdge(w2, 0, prefetch2, 0);
- builder.AddDataEdge(w3, 0, prefetch3, 0);
- builder.AddDataEdge(w4, 0, prefetch4, 0);
- builder.AddDataEdge(w5, 0, prefetch5, 0);
-
- builder.AddDataEdge(const1, 0, add1, 0);
- builder.AddDataEdge(prefetch1, 0, add1, 1);
-
- builder.AddDataEdge(add1, 0, add2, 0);
- builder.AddDataEdge(prefetch2, 0, add2, 1);
-
- builder.AddDataEdge(add2, 0, add3, 0);
- builder.AddDataEdge(prefetch3, 0, add3, 1);
-
- builder.AddDataEdge(add3, 0, add4, 0);
- builder.AddDataEdge(prefetch4, 0, add4, 1);
-
- builder.AddDataEdge(add4, 0, add5, 0);
- builder.AddDataEdge(prefetch5, 0, add5, 1);
-
- builder.AddDataEdge(add5, 0, merge1, 1);
-
- builder.AddControlEdge(switch_true, const1);
- }
-
- auto compute_graph = builder.GetGraph();
-
- return compute_graph;
- }
-
- ///
- /// GraphWithMultiOutputPrefetch: Prefetch has more than one output
- ///
- /// w1 w2 w3 w4 w5
- /// \ \ \ \ \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// / \ / \ / \ / \ /
- /// / \ / \ / \ / \ /
- /// const1 ----- add1 add2 add3 add4 add5
- /// | \ | / |
- /// | \ | / |
- /// | \ | / |
- /// | \ | / |
- /// -------------- net_output ---------------
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() {
-   // Builds a chain of five ADD nodes fed by five single-input/single-output prefetch
-   // nodes that all use buffer pool 0. Prefetch i feeds input 1 of add i and, for i < 5,
-   // input 0 of add i+1. const1 feeds input 0 of add1. Each add result goes to its own
-   // net_output input.
-   // NOTE(review): node creation and edge insertion follow the original order exactly;
-   // that order may influence node / anchor ordering in the resulting graph.
-   constexpr int64_t kBufferPoolId = 0;
-   constexpr int64_t kBufferPoolSize = 5600;
-   constexpr int kChainLength = 5;
-
-   auto builder = InnerGraphBuilder(graph_name_);
-
-   // Weights, one per prefetch.
-   auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
-   auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
-   auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
-   auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
-   auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
-   // Prefetch nodes; the last argument is the per-node size list.
-   auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch1, kBufferPoolId, kBufferPoolSize, {500});
-   auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch2, kBufferPoolId, kBufferPoolSize, {500});
-   auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch3, kBufferPoolId, kBufferPoolSize, {500});
-   auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch4, kBufferPoolId, kBufferPoolSize, {1024});
-   auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch5, kBufferPoolId, kBufferPoolSize, {1024});
-
-   // Consumers.
-   auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
-   auto add1 = builder.AddNode("add1", ADD, 2, 1);
-   auto add2 = builder.AddNode("add2", ADD, 2, 1);
-   auto add3 = builder.AddNode("add3", ADD, 2, 1);
-   auto add4 = builder.AddNode("add4", ADD, 2, 1);
-   auto add5 = builder.AddNode("add5", ADD, 2, 1);
-   auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);
-
-   // Index tables so the regular wiring pattern can be written as loops.
-   decltype(w1) weights[kChainLength] = {w1, w2, w3, w4, w5};
-   decltype(w1) prefetches[kChainLength] = {prefetch1, prefetch2, prefetch3, prefetch4, prefetch5};
-   decltype(w1) adds[kChainLength] = {add1, add2, add3, add4, add5};
-
-   for (int i = 0; i < kChainLength; ++i) {
-     builder.AddDataEdge(weights[i], 0, prefetches[i], 0);
-   }
-
-   for (int i = 0; i < kChainLength; ++i) {
-     // Input 0: const1 for the first add, otherwise the previous prefetch's output,
-     // which is what gives each prefetch (but the last) two consumers.
-     builder.AddDataEdge(i == 0 ? const1 : prefetches[i - 1], 0, adds[i], 0);
-     builder.AddDataEdge(prefetches[i], 0, adds[i], 1);
-   }
-
-   for (int i = 0; i < kChainLength; ++i) {
-     builder.AddDataEdge(adds[i], 0, net_output, i);
-   }
-
-   return builder.GetGraph();
- }
-
- ///
- /// GraphWithMultiInputOutputPrefetch: Prefetch has more than one input and more than one output
- ///
- /// w1 w2 w3 w4 w5
- /// \ / \ / \ / \ / \.
- /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
- /// / \ / \ / \ / \ /
- /// / \ / \ / \ / \ /
- /// const1 ----- add1 add2 add3 add4 add5
- /// | \ | / |
- /// | \ | / |
- /// | \ | / |
- /// | \ | / |
- /// -------------- net_output ---------------
- ///
- /// Memory distribution:
- ///
- /// |___w1__|__w2__|__w3__|__|
- ///
- /// |_____w4_____|_____w5____|
- ///
- ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() {
-   // Builds a chain of five ADD nodes fed by prefetch nodes on buffer pool 0.
-   // prefetch1..prefetch4 have two inputs and two outputs; prefetch5 has one of each.
-   // Wiring:
-   //   - prefetch i reads weight i on input 0 and weight i+1 on input 1 (when it exists).
-   //   - add i reads const1 (i == 1) or output 1 of prefetch i-1 on input 0.
-   //   - add i reads output 0 of prefetch i on input 1.
-   //   - each add result goes to its own net_output input.
-   // NOTE(review): node creation and edge insertion follow the original order exactly;
-   // that order may influence node / anchor ordering in the resulting graph.
-   constexpr int64_t kBufferPoolId = 0;
-   constexpr int64_t kBufferPoolSize = 5600;
-   constexpr int kChainLength = 5;
-
-   auto builder = InnerGraphBuilder(graph_name_);
-
-   // Weights.
-   auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
-   auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
-   auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
-   auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
-   auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
-
-   // Prefetch nodes; each size list has one entry per input of the node.
-   auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 2, 2);
-   SetPrefetchNodeInfo(prefetch1, kBufferPoolId, kBufferPoolSize, {500, 500});
-   auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 2, 2);
-   SetPrefetchNodeInfo(prefetch2, kBufferPoolId, kBufferPoolSize, {500, 500});
-   auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 2, 2);
-   SetPrefetchNodeInfo(prefetch3, kBufferPoolId, kBufferPoolSize, {500, 1024});
-   auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 2, 2);
-   SetPrefetchNodeInfo(prefetch4, kBufferPoolId, kBufferPoolSize, {1024, 1024});
-   auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
-   SetPrefetchNodeInfo(prefetch5, kBufferPoolId, kBufferPoolSize, {1024});
-
-   // Consumers.
-   auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
-   auto add1 = builder.AddNode("add1", ADD, 2, 1);
-   auto add2 = builder.AddNode("add2", ADD, 2, 1);
-   auto add3 = builder.AddNode("add3", ADD, 2, 1);
-   auto add4 = builder.AddNode("add4", ADD, 2, 1);
-   auto add5 = builder.AddNode("add5", ADD, 2, 1);
-   auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);
-
-   // Index tables so the regular wiring pattern can be written as loops.
-   decltype(w1) weights[kChainLength] = {w1, w2, w3, w4, w5};
-   decltype(w1) prefetches[kChainLength] = {prefetch1, prefetch2, prefetch3, prefetch4, prefetch5};
-   decltype(w1) adds[kChainLength] = {add1, add2, add3, add4, add5};
-
-   for (int i = 0; i < kChainLength; ++i) {
-     builder.AddDataEdge(weights[i], 0, prefetches[i], 0);
-     if (i + 1 < kChainLength) {
-       // Every prefetch except the last also reads the next weight on input 1.
-       builder.AddDataEdge(weights[i + 1], 0, prefetches[i], 1);
-     }
-   }
-
-   for (int i = 0; i < kChainLength; ++i) {
-     if (i == 0) {
-       builder.AddDataEdge(const1, 0, adds[i], 0);
-     } else {
-       // Second output of the previous prefetch.
-       builder.AddDataEdge(prefetches[i - 1], 1, adds[i], 0);
-     }
-     builder.AddDataEdge(prefetches[i], 0, adds[i], 1);
-   }
-
-   for (int i = 0; i < kChainLength; ++i) {
-     builder.AddDataEdge(adds[i], 0, net_output, i);
-   }
-
-   return builder.GetGraph();
- }
- } // namespace ut
- } // namespace ge
|