You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

buffer_pool_graph_builder.cc 42 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "buffer_pool_graph_builder.h"
  18. #include "common/ge_inner_error_codes.h"
  19. #include "common/types.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/attr_utils.h"
  22. #include "graph/utils/graph_utils.h"
  23. #include "graph/utils/tensor_utils.h"
  24. #include "graph/utils/graph_utils.h"
  25. namespace ge {
  26. namespace ut {
  27. BufferPoolGraphBuilder::BufferPoolGraphBuilder(const std::string &name) {
  28. graph_name_ = name;
  29. }
  30. BufferPoolGraphBuilder::InnerGraphBuilder::InnerGraphBuilder(const std::string &name) {
  31. graph_ = std::make_shared<ComputeGraph>(name);
  32. EXPECT_NE(graph_, nullptr);
  33. }
  34. NodePtr BufferPoolGraphBuilder::InnerGraphBuilder::AddNode(const std::string &name, const std::string &type,
  35. int in_cnt, int out_cnt,
  36. Format format, DataType data_type,
  37. std::vector<int64_t> shape) {
  38. auto tensor_desc = std::make_shared<GeTensorDesc>();
  39. EXPECT_NE(tensor_desc, nullptr);
  40. tensor_desc->SetShape(GeShape(std::move(shape)));
  41. tensor_desc->SetFormat(format);
  42. tensor_desc->SetDataType(data_type);
  43. auto op_desc = std::make_shared<OpDesc>(name, type);
  44. EXPECT_NE(op_desc, nullptr);
  45. for (int i = 0; i < in_cnt; ++i) {
  46. op_desc->AddInputDesc(tensor_desc->Clone());
  47. }
  48. for (int i = 0; i < out_cnt; ++i) {
  49. op_desc->AddOutputDesc(tensor_desc->Clone());
  50. }
  51. return graph_->AddNode(op_desc);
  52. }
  53. void BufferPoolGraphBuilder::InnerGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx,
  54. NodePtr &dst_node, int dst_idx) {
  55. EXPECT_NE(src_node, nullptr);
  56. EXPECT_NE(dst_node, nullptr);
  57. GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx));
  58. }
  59. void BufferPoolGraphBuilder::InnerGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) {
  60. EXPECT_NE(src_node, nullptr);
  61. EXPECT_NE(dst_node, nullptr);
  62. GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor());
  63. }
  64. void BufferPoolGraphBuilder::SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size,
  65. const std::string &batch_label) {
  66. EXPECT_NE(node, nullptr);
  67. (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, pool_id);
  68. (void) AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, pool_size);
  69. if (!batch_label.empty()) {
  70. (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  71. }
  72. }
  73. void BufferPoolGraphBuilder::SetBatchLabel(NodePtr &node, const std::string &batch_label) {
  74. EXPECT_NE(node, nullptr);
  75. (void) AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  76. }
  77. void BufferPoolGraphBuilder::SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size) {
  78. EXPECT_NE(node, nullptr);
  79. EXPECT_NE(node->GetOpDesc(), nullptr);
  80. size_t output_size = node->GetOpDesc()->GetOutputsSize();
  81. EXPECT_EQ(output_size, mem_size.size());
  82. for (size_t i = 0; i < output_size; ++i) {
  83. auto output_op_desc = node->GetOpDesc()->MutableOutputDesc(i);
  84. ge::TensorUtils::SetSize(*output_op_desc, mem_size[i]);
  85. }
  86. }
  87. void BufferPoolGraphBuilder::SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes) {
  88. EXPECT_NE(node, nullptr);
  89. EXPECT_NE(node->GetOpDesc(), nullptr);
  90. node->GetOpDesc()->SetWorkspaceBytes(ws_bytes);
  91. }
  92. void BufferPoolGraphBuilder::SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
  93. const std::vector<int64_t> &mem_size,
  94. const std::vector<int64_t> &ws_bytes,
  95. const std::string &batch_label) {
  96. SetBufferPool(node, pool_id, pool_size, batch_label);
  97. SetOutputMemSize(node, mem_size);
  98. SetWorkSpaceMemSize(node, ws_bytes);
  99. }
  100. ///
  101. /// Normal graph
  102. ///
  103. /// w1 w2 w3 w4 w5
  104. /// \ \ \ \ \.
  105. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  106. /// \ \ \ \ \.
  107. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  108. ///
  109. ///
  110. /// Memory distribution:
  111. ///
  112. /// |___w1__|__w2__|__w3__|__|
  113. ///
  114. /// |_____w4_____|_____w5____|
  115. ///
  116. ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraph() {
  117. auto builder = InnerGraphBuilder(graph_name_);
  118. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  119. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  120. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  121. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  122. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  123. const int64_t buffer_pool_id = 0;
  124. const int64_t buffer_pool_size = 5600;
  125. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  126. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  127. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  128. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  129. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  130. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  131. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  132. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  133. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  134. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  135. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  136. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  137. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  138. auto add4 = builder.AddNode("add4", ADD, 2, 1);
  139. auto add5 = builder.AddNode("add5", ADD, 2, 1);
  140. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  141. auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  142. builder.AddDataEdge(w1, 0, prefetch1, 0);
  143. builder.AddDataEdge(w2, 0, prefetch2, 0);
  144. builder.AddDataEdge(w3, 0, prefetch3, 0);
  145. builder.AddDataEdge(w4, 0, prefetch4, 0);
  146. builder.AddDataEdge(w5, 0, prefetch5, 0);
  147. builder.AddDataEdge(const1, 0, add1, 0);
  148. builder.AddDataEdge(prefetch1, 0, add1, 1);
  149. builder.AddDataEdge(add1, 0, add2, 0);
  150. builder.AddDataEdge(prefetch2, 0, add2, 1);
  151. builder.AddDataEdge(add2, 0, add3, 0);
  152. builder.AddDataEdge(prefetch3, 0, add3, 1);
  153. builder.AddDataEdge(add3, 0, add4, 0);
  154. builder.AddDataEdge(prefetch4, 0, add4, 1);
  155. builder.AddDataEdge(add4, 0, add5, 0);
  156. builder.AddDataEdge(prefetch5, 0, add5, 1);
  157. builder.AddDataEdge(add5, 0, net_output, 0);
  158. auto compute_graph = builder.GetGraph();
  159. return compute_graph;
  160. }
  161. ///
  162. /// Normal graph with multi buffer pool
  163. ///
  164. /// w1 w2 w3 w4 w5
  165. /// \ \ \ \ \.
  166. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  167. /// (pool0) (pool1) (pool0) (pool0) (pool1)
  168. /// \ \ \ \ \.
  169. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  170. ///
  171. ///
  172. /// Memory distribution:
  173. ///
  174. /// |___w1__|__w3__|_________|
  175. /// |_____w4_____|___________|
  176. ///
  177. /// |___w2__|_____w5___|_____|
  178. ///
  179. ComputeGraphPtr BufferPoolGraphBuilder::BuildNormalGraphWithMultiBufferPool() {
  180. auto builder = InnerGraphBuilder(graph_name_);
  181. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  182. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  183. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  184. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  185. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  186. const int64_t buffer_pool_id_0 = 0;
  187. const int64_t buffer_pool_id_1 = 1;
  188. const int64_t buffer_pool_size = 5000;
  189. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  190. SetPrefetchNodeInfo(prefetch1, buffer_pool_id_0, buffer_pool_size, {500});
  191. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  192. SetPrefetchNodeInfo(prefetch2, buffer_pool_id_1, buffer_pool_size, {500});
  193. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  194. SetPrefetchNodeInfo(prefetch3, buffer_pool_id_0, buffer_pool_size, {500});
  195. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  196. SetPrefetchNodeInfo(prefetch4, buffer_pool_id_0, buffer_pool_size, {1024});
  197. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  198. SetPrefetchNodeInfo(prefetch5, buffer_pool_id_1, buffer_pool_size, {1024});
  199. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  200. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  201. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  202. auto add4 = builder.AddNode("add4", ADD, 2, 1);
  203. auto add5 = builder.AddNode("add5", ADD, 2, 1);
  204. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  205. auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  206. builder.AddDataEdge(w1, 0, prefetch1, 0);
  207. builder.AddDataEdge(w2, 0, prefetch2, 0);
  208. builder.AddDataEdge(w3, 0, prefetch3, 0);
  209. builder.AddDataEdge(w4, 0, prefetch4, 0);
  210. builder.AddDataEdge(w5, 0, prefetch5, 0);
  211. builder.AddDataEdge(const1, 0, add1, 0);
  212. builder.AddDataEdge(prefetch1, 0, add1, 1);
  213. builder.AddDataEdge(add1, 0, add2, 0);
  214. builder.AddDataEdge(prefetch2, 0, add2, 1);
  215. builder.AddDataEdge(add2, 0, add3, 0);
  216. builder.AddDataEdge(prefetch3, 0, add3, 1);
  217. builder.AddDataEdge(add3, 0, add4, 0);
  218. builder.AddDataEdge(prefetch4, 0, add4, 1);
  219. builder.AddDataEdge(add4, 0, add5, 0);
  220. builder.AddDataEdge(prefetch5, 0, add5, 1);
  221. builder.AddDataEdge(add5, 0, net_output, 0);
  222. auto compute_graph = builder.GetGraph();
  223. return compute_graph;
  224. }
  225. ///
  226. /// SerialGraph: Buffer pool size only can contain one prefetch node
  227. ///
  228. /// w1 w2 w3 w4 w5
  229. /// \ \ \ \ \.
  230. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  231. /// \ \ \ \ \.
  232. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  233. ///
  234. ///
  235. /// Memory distribution:
  236. ///
  237. /// |____w1_____|__|
  238. ///
  239. /// |____w2_____|__|
  240. ///
  241. /// |____w3_____|__|
  242. ///
  243. /// |______w4______|
  244. ///
  245. /// |______w5______|
  246. ///
  247. ComputeGraphPtr BufferPoolGraphBuilder::BuildSerialGraph() {
  248. auto builder = InnerGraphBuilder(graph_name_);
  249. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  250. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  251. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  252. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  253. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  254. const int64_t buffer_pool_id = 0;
  255. const int64_t buffer_pool_size = 2048;
  256. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  257. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  258. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  259. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  260. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  261. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  262. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  263. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  264. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  265. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  266. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  267. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  268. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  269. auto add4 = builder.AddNode("add4", ADD, 2, 1);
  270. auto add5 = builder.AddNode("add5", ADD, 2, 1);
  271. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  272. auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  273. builder.AddDataEdge(w1, 0, prefetch1, 0);
  274. builder.AddDataEdge(w2, 0, prefetch2, 0);
  275. builder.AddDataEdge(w3, 0, prefetch3, 0);
  276. builder.AddDataEdge(w4, 0, prefetch4, 0);
  277. builder.AddDataEdge(w5, 0, prefetch5, 0);
  278. builder.AddDataEdge(const1, 0, add1, 0);
  279. builder.AddDataEdge(prefetch1, 0, add1, 1);
  280. builder.AddDataEdge(add1, 0, add2, 0);
  281. builder.AddDataEdge(prefetch2, 0, add2, 1);
  282. builder.AddDataEdge(add2, 0, add3, 0);
  283. builder.AddDataEdge(prefetch3, 0, add3, 1);
  284. builder.AddDataEdge(add3, 0, add4, 0);
  285. builder.AddDataEdge(prefetch4, 0, add4, 1);
  286. builder.AddDataEdge(add4, 0, add5, 0);
  287. builder.AddDataEdge(prefetch5, 0, add5, 1);
  288. builder.AddDataEdge(add5, 0, net_output, 0);
  289. auto compute_graph = builder.GetGraph();
  290. return compute_graph;
  291. }
  292. ///
  293. /// GraphWithMultiPrefetch: Calc node with more prefetch node
  294. ///
  295. /// w1 w2 w3 w4 w5
  296. /// \ \ \ \ \.
  297. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
  298. /// \ / \ / \ /
  299. /// \ / \ / \ /
  300. /// \ / \ / \ /
  301. /// add1 ------ c ------- add2 ----- c ----- add3
  302. /// | | |
  303. /// | | |
  304. /// --------------- net_output ------------
  305. ///
  306. /// Memory distribution:
  307. ///
  308. /// |___w1__|__w2__|__w3__|__|
  309. ///
  310. /// |_____w4_____|_____w5____|
  311. ///
  312. ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiPrefetch() {
  313. auto builder = InnerGraphBuilder(graph_name_);
  314. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  315. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  316. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  317. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  318. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  319. const int64_t buffer_pool_id = 0;
  320. const int64_t buffer_pool_size = 5600;
  321. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  322. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  323. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  324. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  325. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  326. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  327. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  328. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  329. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  330. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  331. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  332. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  333. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  334. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  335. auto net_output = builder.AddNode("net_output", NETOUTPUT, 3, 0);
  336. builder.AddDataEdge(w1, 0, prefetch1, 0);
  337. builder.AddDataEdge(w2, 0, prefetch2, 0);
  338. builder.AddDataEdge(w3, 0, prefetch3, 0);
  339. builder.AddDataEdge(w4, 0, prefetch4, 0);
  340. builder.AddDataEdge(w5, 0, prefetch5, 0);
  341. builder.AddDataEdge(prefetch1, 0, add1, 0);
  342. builder.AddDataEdge(prefetch2, 0, add1, 1);
  343. builder.AddDataEdge(prefetch3, 0, add2, 0);
  344. builder.AddDataEdge(prefetch4, 0, add2, 1);
  345. builder.AddDataEdge(const1, 0, add3, 0);
  346. builder.AddDataEdge(prefetch5, 0, add3, 1);
  347. builder.AddDataEdge(add1, 0, net_output, 0);
  348. builder.AddDataEdge(add2, 0, net_output, 1);
  349. builder.AddDataEdge(add3, 0, net_output, 2);
  350. builder.AddControlEdge(add1, add2);
  351. builder.AddControlEdge(add2, add3);
  352. auto compute_graph = builder.GetGraph();
  353. return compute_graph;
  354. }
  355. ///
  356. /// GraphWithSubgraph: Calc node in different subgraph
  357. ///
  358. ///
  359. /// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
  360. ///
  361. ///
  362. /// Subgraph1: Subgraph2:
  363. ///
  364. /// w1 w2 w3 w4 w5
  365. /// \ \ \ \ \.
  366. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  367. /// \ \ \ \ \.
  368. /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
  369. ///
  370. ///
  371. /// Memory distribution:
  372. ///
  373. /// |___w1__|__w2__|__w3__|__|
  374. ///
  375. /// |_____w4_____|_____w5____|
  376. ///
  377. ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithSubgraph() {
  378. auto builder = InnerGraphBuilder(graph_name_);
  379. const int64_t buffer_pool_id = 0;
  380. const int64_t buffer_pool_size = 5600;
  381. // Subgraph1
  382. auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
  383. auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
  384. auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
  385. auto w3 = subgraph_builder1.AddNode("w3", VARIABLE, 0, 1);
  386. auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  387. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  388. auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  389. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  390. auto prefetch3 = subgraph_builder1.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  391. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  392. auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
  393. auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
  394. auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
  395. auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
  396. auto add3 = subgraph_builder1.AddNode("add3", ADD, 2, 1);
  397. subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
  398. subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
  399. subgraph_builder1.AddDataEdge(w3, 0, prefetch3, 0);
  400. subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
  401. subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
  402. subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
  403. subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
  404. subgraph_builder1.AddDataEdge(add2, 0, add3, 0);
  405. subgraph_builder1.AddDataEdge(prefetch3, 0, add3, 1);
  406. subgraph_builder1.AddDataEdge(add3, 0, subgraph1_out, 0);
  407. auto subgraph1 = subgraph_builder1.GetGraph();
  408. for (auto &node : subgraph1->GetDirectNode()) {
  409. node->SetOwnerComputeGraph(subgraph1);
  410. }
  411. // Subgraph2
  412. auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
  413. auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
  414. auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
  415. auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  416. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  417. auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  418. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  419. auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
  420. auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
  421. auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
  422. auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
  423. subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
  424. subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
  425. subgraph_builder2.AddDataEdge(data1, 0, add4, 0);
  426. subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
  427. subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
  428. subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
  429. subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
  430. auto subgraph2 = subgraph_builder2.GetGraph();
  431. for (auto &node : subgraph2->GetDirectNode()) {
  432. node->SetOwnerComputeGraph(subgraph2);
  433. }
  434. // root graph
  435. auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
  436. auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
  437. auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  438. builder.AddDataEdge(call_node1, 0, call_node2, 0);
  439. builder.AddDataEdge(call_node2, 0, net_output, 0);
  440. auto compute_graph = builder.GetGraph();
  441. call_node1->SetOwnerComputeGraph(compute_graph);
  442. call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
  443. call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
  444. call_node2->SetOwnerComputeGraph(compute_graph);
  445. call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
  446. call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
  447. subgraph1->SetParentNode(call_node1);
  448. subgraph1->SetParentGraph(compute_graph);
  449. subgraph2->SetParentNode(call_node2);
  450. subgraph2->SetParentGraph(compute_graph);
  451. compute_graph->AddSubGraph(subgraph1);
  452. compute_graph->AddSubGraph(subgraph2);
  453. return compute_graph;
  454. }
  455. ///
  456. /// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency
  457. ///
  458. ///
  459. /// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
  460. ///
  461. ///
  462. /// Subgraph1: Subgraph2:
  463. ///
  464. /// w1 w2 w3 w4 w5
  465. /// \ \ \ \ \.
  466. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  467. /// \ \ \ \ \.
  468. /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
  469. ///
  470. ///
  471. /// Memory distribution:
  472. ///
  473. /// |___w1__|__w2__|__w3__|__|
  474. ///
  475. /// |_____w4_____|_____w5____|
  476. ///
  477. ComputeGraphPtr BufferPoolGraphBuilder::BuildSubgraphWithInnerDependency() {
  478. auto builder = InnerGraphBuilder(graph_name_);
  479. const int64_t buffer_pool_id = 0;
  480. const int64_t buffer_pool_size = 5600;
  481. // Subgraph1
  482. auto subgraph_builder1 = InnerGraphBuilder("Subgraph1");
  483. auto w1 = subgraph_builder1.AddNode("w1", VARIABLE, 0, 1);
  484. auto w2 = subgraph_builder1.AddNode("w2", VARIABLE, 0, 1);
  485. auto prefetch1 = subgraph_builder1.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  486. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  487. auto prefetch2 = subgraph_builder1.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  488. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  489. auto subgraph1_out = subgraph_builder1.AddNode("subgraph1_out", NETOUTPUT, 1, 0);
  490. auto const1 = subgraph_builder1.AddNode("const1", CONSTANTOP, 0, 1);
  491. auto add1 = subgraph_builder1.AddNode("add1", ADD, 2, 1);
  492. auto add2 = subgraph_builder1.AddNode("add2", ADD, 2, 1);
  493. subgraph_builder1.AddDataEdge(w1, 0, prefetch1, 0);
  494. subgraph_builder1.AddDataEdge(w2, 0, prefetch2, 0);
  495. subgraph_builder1.AddDataEdge(const1, 0, add1, 0);
  496. subgraph_builder1.AddDataEdge(prefetch1, 0, add1, 1);
  497. subgraph_builder1.AddDataEdge(add1, 0, add2, 0);
  498. subgraph_builder1.AddDataEdge(prefetch2, 0, add2, 1);
  499. subgraph_builder1.AddDataEdge(add2, 0, subgraph1_out, 0);
  500. auto subgraph1 = subgraph_builder1.GetGraph();
  501. for (auto &node : subgraph1->GetDirectNode()) {
  502. node->SetOwnerComputeGraph(subgraph1);
  503. }
  504. // Subgraph2
  505. auto subgraph_builder2 = InnerGraphBuilder("Subgraph2");
  506. auto w3 = subgraph_builder2.AddNode("w3", VARIABLE, 0, 1);
  507. auto w4 = subgraph_builder2.AddNode("w4", VARIABLE, 0, 1);
  508. auto w5 = subgraph_builder2.AddNode("w5", VARIABLE, 0, 1);
  509. auto prefetch3 = subgraph_builder2.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  510. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  511. auto prefetch4 = subgraph_builder2.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  512. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  513. auto prefetch5 = subgraph_builder2.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  514. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  515. auto add3 = subgraph_builder2.AddNode("add3", ADD, 2, 1);
  516. auto add4 = subgraph_builder2.AddNode("add4", ADD, 2, 1);
  517. auto add5 = subgraph_builder2.AddNode("add5", ADD, 2, 1);
  518. auto data1 = subgraph_builder2.AddNode("data1", DATA, 0, 1);
  519. auto subgraph2_out = subgraph_builder2.AddNode("subgraph2_out", NETOUTPUT, 1, 1);
  520. subgraph_builder2.AddDataEdge(w3, 0, prefetch3, 0);
  521. subgraph_builder2.AddDataEdge(w4, 0, prefetch4, 0);
  522. subgraph_builder2.AddDataEdge(w5, 0, prefetch5, 0);
  523. subgraph_builder2.AddDataEdge(data1, 0, add3, 0);
  524. subgraph_builder2.AddDataEdge(prefetch3, 0, add3, 1);
  525. subgraph_builder2.AddDataEdge(add3, 0, add4, 0);
  526. subgraph_builder2.AddDataEdge(prefetch4, 0, add4, 1);
  527. subgraph_builder2.AddDataEdge(add4, 0, add5, 0);
  528. subgraph_builder2.AddDataEdge(prefetch5, 0, add5, 1);
  529. subgraph_builder2.AddDataEdge(add5, 0, subgraph2_out, 0);
  530. auto subgraph2 = subgraph_builder2.GetGraph();
  531. for (auto &node : subgraph2->GetDirectNode()) {
  532. node->SetOwnerComputeGraph(subgraph2);
  533. }
  534. // root graph
  535. auto call_node1 = builder.AddNode("call_node1", PARTITIONEDCALL, 0, 1);
  536. auto call_node2 = builder.AddNode("call_node2", PARTITIONEDCALL, 1, 0);
  537. auto net_output = subgraph_builder2.AddNode("net_output", NETOUTPUT, 1, 0);
  538. builder.AddDataEdge(call_node1, 0, call_node2, 0);
  539. builder.AddDataEdge(call_node2, 0, net_output, 0);
  540. auto compute_graph = builder.GetGraph();
  541. call_node1->SetOwnerComputeGraph(compute_graph);
  542. call_node1->GetOpDesc()->AddSubgraphName(subgraph1->GetName());
  543. call_node1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName());
  544. call_node2->SetOwnerComputeGraph(compute_graph);
  545. call_node2->GetOpDesc()->AddSubgraphName(subgraph2->GetName());
  546. call_node2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph2->GetName());
  547. subgraph1->SetParentNode(call_node1);
  548. subgraph1->SetParentGraph(compute_graph);
  549. subgraph2->SetParentNode(call_node2);
  550. subgraph2->SetParentGraph(compute_graph);
  551. compute_graph->AddSubGraph(subgraph1);
  552. compute_graph->AddSubGraph(subgraph2);
  553. return compute_graph;
  554. }
  555. ///
  556. /// BuildGraphWithMultiBatch: Different batch label
  557. ///
  558. ///
  559. /// batch_label_128
  560. ///
  561. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
  562. /// / / / / / / \.
  563. /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \.
  564. /// const1 switch_false / / / / / \.
  565. /// \ / / / / / / \.
  566. /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
  567. /// / \ \ \ \ \ \ /
  568. /// const2 switch_true \ \ \ \ \ /
  569. /// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /
  570. /// \ \ \ \ \ \ /
  571. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
  572. ///
  573. /// batch_label_256
  574. ///
  575. ///
  576. /// Memory distribution:
  577. ///
  578. /// |___w1__|__w2__|__w3__|__|
  579. ///
  580. /// |_____w4_____|_____w5____|
  581. ///
  582. ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiBatch() {
  583. auto builder = InnerGraphBuilder(graph_name_);
  584. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  585. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  586. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  587. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  588. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  589. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  590. auto const2 = builder.AddNode("const2", CONSTANTOP, 0, 1);
  591. auto switch1 = builder.AddNode("switch1", SWITCH, 2, 2);
  592. auto switch_false = builder.AddNode("switch_false", IDENTITY, 1, 1);
  593. auto switch_true = builder.AddNode("switch_true", IDENTITY, 1, 1);
  594. auto merge1 = builder.AddNode("merge1", MERGE, 2, 2);
  595. auto net_output = builder.AddNode("net_output", NETOUTPUT, 1, 0);
  596. builder.AddDataEdge(const1, 0, switch1, 0);
  597. builder.AddDataEdge(const2, 0, switch1, 1);
  598. builder.AddDataEdge(switch1, 0, switch_false, 0);
  599. builder.AddDataEdge(switch1, 1, switch_true, 0);
  600. builder.AddDataEdge(merge1, 0, net_output, 0);
  601. std::string batch_label_128 = "batch_128";
  602. std::string batch_label_256 = "batch_256";
  603. const int64_t buffer_pool_id = 0;
  604. const int64_t buffer_pool_size = 5600;
  605. {
  606. auto prefetch1 = builder.AddNode("batch_label_128/prefetch1", HCOMALLGATHER, 1, 1);
  607. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
  608. auto prefetch2 = builder.AddNode("batch_label_128/prefetch2", HCOMALLGATHER, 1, 1);
  609. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
  610. auto prefetch3 = builder.AddNode("batch_label_128/prefetch3", HCOMALLGATHER, 1, 1);
  611. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_128);
  612. auto prefetch4 = builder.AddNode("batch_label_128/prefetch4", HCOMALLGATHER, 1, 1);
  613. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);
  614. auto prefetch5 = builder.AddNode("batch_label_128/prefetch5", HCOMALLGATHER, 1, 1);
  615. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_128);
  616. auto add1 = builder.AddNode("batch_label_128/add1", ADD, 2, 1);
  617. SetBatchLabel(add1, batch_label_128);
  618. auto add2 = builder.AddNode("batch_label_128/add2", ADD, 2, 1);
  619. SetBatchLabel(add2, batch_label_128);
  620. auto add3 = builder.AddNode("batch_label_128/add3", ADD, 2, 1);
  621. SetBatchLabel(add3, batch_label_128);
  622. auto add4 = builder.AddNode("batch_label_128/add4", ADD, 2, 1);
  623. SetBatchLabel(add4, batch_label_128);
  624. auto add5 = builder.AddNode("batch_label_128/add5", ADD, 2, 1);
  625. SetBatchLabel(add5, batch_label_128);
  626. auto const1 = builder.AddNode("batch_label_128/const1", CONSTANTOP, 0, 1);
  627. SetBatchLabel(const1, batch_label_128);
  628. builder.AddDataEdge(w1, 0, prefetch1, 0);
  629. builder.AddDataEdge(w2, 0, prefetch2, 0);
  630. builder.AddDataEdge(w3, 0, prefetch3, 0);
  631. builder.AddDataEdge(w4, 0, prefetch4, 0);
  632. builder.AddDataEdge(w5, 0, prefetch5, 0);
  633. builder.AddDataEdge(const1, 0, add1, 0);
  634. builder.AddDataEdge(prefetch1, 0, add1, 1);
  635. builder.AddDataEdge(add1, 0, add2, 0);
  636. builder.AddDataEdge(prefetch2, 0, add2, 1);
  637. builder.AddDataEdge(add2, 0, add3, 0);
  638. builder.AddDataEdge(prefetch3, 0, add3, 1);
  639. builder.AddDataEdge(add3, 0, add4, 0);
  640. builder.AddDataEdge(prefetch4, 0, add4, 1);
  641. builder.AddDataEdge(add4, 0, add5, 0);
  642. builder.AddDataEdge(prefetch5, 0, add5, 1);
  643. builder.AddDataEdge(add5, 0, merge1, 0);
  644. builder.AddControlEdge(switch_false, const1);
  645. }
  646. {
  647. auto prefetch1 = builder.AddNode("batch_label_256/prefetch1", HCOMALLGATHER, 1, 1);
  648. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
  649. auto prefetch2 = builder.AddNode("batch_label_256/prefetch2", HCOMALLGATHER, 1, 1);
  650. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
  651. auto prefetch3 = builder.AddNode("batch_label_256/prefetch3", HCOMALLGATHER, 1, 1);
  652. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500}, {500}, batch_label_256);
  653. auto prefetch4 = builder.AddNode("batch_label_256/prefetch4", HCOMALLGATHER, 1, 1);
  654. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);
  655. auto prefetch5 = builder.AddNode("batch_label_256/prefetch5", HCOMALLGATHER, 1, 1);
  656. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024}, {1024}, batch_label_256);
  657. auto add1 = builder.AddNode("batch_label_256/add1", ADD, 2, 1);
  658. SetBatchLabel(add1, batch_label_256);
  659. auto add2 = builder.AddNode("batch_label_256/add2", ADD, 2, 1);
  660. SetBatchLabel(add2, batch_label_256);
  661. auto add3 = builder.AddNode("batch_label_256/add3", ADD, 2, 1);
  662. SetBatchLabel(add3, batch_label_256);
  663. auto add4 = builder.AddNode("batch_label_256/add4", ADD, 2, 1);
  664. SetBatchLabel(add4, batch_label_256);
  665. auto add5 = builder.AddNode("batch_label_256/add5", ADD, 2, 1);
  666. SetBatchLabel(add5, batch_label_256);
  667. auto const1 = builder.AddNode("batch_label_256/const1", CONSTANTOP, 0, 1);
  668. SetBatchLabel(const1, batch_label_128);
  669. builder.AddDataEdge(w1, 0, prefetch1, 0);
  670. builder.AddDataEdge(w2, 0, prefetch2, 0);
  671. builder.AddDataEdge(w3, 0, prefetch3, 0);
  672. builder.AddDataEdge(w4, 0, prefetch4, 0);
  673. builder.AddDataEdge(w5, 0, prefetch5, 0);
  674. builder.AddDataEdge(const1, 0, add1, 0);
  675. builder.AddDataEdge(prefetch1, 0, add1, 1);
  676. builder.AddDataEdge(add1, 0, add2, 0);
  677. builder.AddDataEdge(prefetch2, 0, add2, 1);
  678. builder.AddDataEdge(add2, 0, add3, 0);
  679. builder.AddDataEdge(prefetch3, 0, add3, 1);
  680. builder.AddDataEdge(add3, 0, add4, 0);
  681. builder.AddDataEdge(prefetch4, 0, add4, 1);
  682. builder.AddDataEdge(add4, 0, add5, 0);
  683. builder.AddDataEdge(prefetch5, 0, add5, 1);
  684. builder.AddDataEdge(add5, 0, merge1, 1);
  685. builder.AddControlEdge(switch_true, const1);
  686. }
  687. auto compute_graph = builder.GetGraph();
  688. return compute_graph;
  689. }
  690. ///
  691. /// GraphWithMultiOutputPrefetch: Prefetch has more than one output
  692. ///
  693. /// w1 w2 w3 w4 w5
  694. /// \ \ \ \ \.
  695. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  696. /// / \ / \ / \ / \ /
  697. /// / \ / \ / \ / \ /
  698. /// const1 ----- add1 add2 add3 add4 add5
  699. /// | \ | / |
  700. /// | \ | / |
  701. /// | \ | / |
  702. /// | \ | / |
  703. /// -------------- net_output ---------------
  704. ///
  705. /// Memory distribution:
  706. ///
  707. /// |___w1__|__w2__|__w3__|__|
  708. ///
  709. /// |_____w4_____|_____w5____|
  710. ///
  711. ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiOutputPrefetch() {
  712. auto builder = InnerGraphBuilder(graph_name_);
  713. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  714. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  715. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  716. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  717. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  718. const int64_t buffer_pool_id = 0;
  719. const int64_t buffer_pool_size = 5600;
  720. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 1, 1);
  721. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500});
  722. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 1, 1);
  723. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500});
  724. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 1, 1);
  725. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500});
  726. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 1, 1);
  727. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024});
  728. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  729. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  730. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  731. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  732. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  733. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  734. auto add4 = builder.AddNode("add4", ADD, 2, 1);
  735. auto add5 = builder.AddNode("add5", ADD, 2, 1);
  736. auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);
  737. builder.AddDataEdge(w1, 0, prefetch1, 0);
  738. builder.AddDataEdge(w2, 0, prefetch2, 0);
  739. builder.AddDataEdge(w3, 0, prefetch3, 0);
  740. builder.AddDataEdge(w4, 0, prefetch4, 0);
  741. builder.AddDataEdge(w5, 0, prefetch5, 0);
  742. builder.AddDataEdge(const1, 0, add1, 0);
  743. builder.AddDataEdge(prefetch1, 0, add1, 1);
  744. builder.AddDataEdge(prefetch1, 0, add2, 0);
  745. builder.AddDataEdge(prefetch2, 0, add2, 1);
  746. builder.AddDataEdge(prefetch2, 0, add3, 0);
  747. builder.AddDataEdge(prefetch3, 0, add3, 1);
  748. builder.AddDataEdge(prefetch3, 0, add4, 0);
  749. builder.AddDataEdge(prefetch4, 0, add4, 1);
  750. builder.AddDataEdge(prefetch4, 0, add5, 0);
  751. builder.AddDataEdge(prefetch5, 0, add5, 1);
  752. builder.AddDataEdge(add1, 0, net_output, 0);
  753. builder.AddDataEdge(add2, 0, net_output, 1);
  754. builder.AddDataEdge(add3, 0, net_output, 2);
  755. builder.AddDataEdge(add4, 0, net_output, 3);
  756. builder.AddDataEdge(add5, 0, net_output, 4);
  757. auto compute_graph = builder.GetGraph();
  758. return compute_graph;
  759. }
  760. ///
  761. /// GraphWithMultiOutputPrefetch: Prefetch has more than one output
  762. ///
  763. /// w1 w2 w3 w4 w5
  764. /// \ / \ / \ / \ / \.
  765. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  766. /// / \ / \ / \ / \ /
  767. /// / \ / \ / \ / \ /
  768. /// const1 ----- add1 add2 add3 add4 add5
  769. /// | \ | / |
  770. /// | \ | / |
  771. /// | \ | / |
  772. /// | \ | / |
  773. /// -------------- net_output ---------------
  774. ///
  775. /// Memory distribution:
  776. ///
  777. /// |___w1__|__w2__|__w3__|__|
  778. ///
  779. /// |_____w4_____|_____w5____|
  780. ///
  781. ComputeGraphPtr BufferPoolGraphBuilder::BuildGraphWithMultiInputOutputPrefetch() {
  782. auto builder = InnerGraphBuilder(graph_name_);
  783. auto w1 = builder.AddNode("w1", VARIABLE, 0, 1);
  784. auto w2 = builder.AddNode("w2", VARIABLE, 0, 1);
  785. auto w3 = builder.AddNode("w3", VARIABLE, 0, 1);
  786. auto w4 = builder.AddNode("w4", VARIABLE, 0, 1);
  787. auto w5 = builder.AddNode("w5", VARIABLE, 0, 1);
  788. const int64_t buffer_pool_id = 0;
  789. const int64_t buffer_pool_size = 5600;
  790. auto prefetch1 = builder.AddNode("prefetch1", HCOMALLGATHER, 2, 2);
  791. SetPrefetchNodeInfo(prefetch1, buffer_pool_id, buffer_pool_size, {500, 500});
  792. auto prefetch2 = builder.AddNode("prefetch2", HCOMALLGATHER, 2, 2);
  793. SetPrefetchNodeInfo(prefetch2, buffer_pool_id, buffer_pool_size, {500, 500});
  794. auto prefetch3 = builder.AddNode("prefetch3", HCOMALLGATHER, 2, 2);
  795. SetPrefetchNodeInfo(prefetch3, buffer_pool_id, buffer_pool_size, {500, 1024});
  796. auto prefetch4 = builder.AddNode("prefetch4", HCOMALLGATHER, 2, 2);
  797. SetPrefetchNodeInfo(prefetch4, buffer_pool_id, buffer_pool_size, {1024, 1024});
  798. auto prefetch5 = builder.AddNode("prefetch5", HCOMALLGATHER, 1, 1);
  799. SetPrefetchNodeInfo(prefetch5, buffer_pool_id, buffer_pool_size, {1024});
  800. auto const1 = builder.AddNode("const1", CONSTANTOP, 0, 1);
  801. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  802. auto add2 = builder.AddNode("add2", ADD, 2, 1);
  803. auto add3 = builder.AddNode("add3", ADD, 2, 1);
  804. auto add4 = builder.AddNode("add4", ADD, 2, 1);
  805. auto add5 = builder.AddNode("add5", ADD, 2, 1);
  806. auto net_output = builder.AddNode("net_output", NETOUTPUT, 5, 0);
  807. builder.AddDataEdge(w1, 0, prefetch1, 0);
  808. builder.AddDataEdge(w2, 0, prefetch1, 1);
  809. builder.AddDataEdge(w2, 0, prefetch2, 0);
  810. builder.AddDataEdge(w3, 0, prefetch2, 1);
  811. builder.AddDataEdge(w3, 0, prefetch3, 0);
  812. builder.AddDataEdge(w4, 0, prefetch3, 1);
  813. builder.AddDataEdge(w4, 0, prefetch4, 0);
  814. builder.AddDataEdge(w5, 0, prefetch4, 1);
  815. builder.AddDataEdge(w5, 0, prefetch5, 0);
  816. builder.AddDataEdge(const1, 0, add1, 0);
  817. builder.AddDataEdge(prefetch1, 0, add1, 1);
  818. builder.AddDataEdge(prefetch1, 1, add2, 0);
  819. builder.AddDataEdge(prefetch2, 0, add2, 1);
  820. builder.AddDataEdge(prefetch2, 1, add3, 0);
  821. builder.AddDataEdge(prefetch3, 0, add3, 1);
  822. builder.AddDataEdge(prefetch3, 1, add4, 0);
  823. builder.AddDataEdge(prefetch4, 0, add4, 1);
  824. builder.AddDataEdge(prefetch4, 1, add5, 0);
  825. builder.AddDataEdge(prefetch5, 0, add5, 1);
  826. builder.AddDataEdge(add1, 0, net_output, 0);
  827. builder.AddDataEdge(add2, 0, net_output, 1);
  828. builder.AddDataEdge(add3, 0, net_output, 2);
  829. builder.AddDataEdge(add4, 0, net_output, 3);
  830. builder.AddDataEdge(add5, 0, net_output, 4);
  831. auto compute_graph = builder.GetGraph();
  832. return compute_graph;
  833. }
  834. } // namespace ut
  835. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。