You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

buffer_pool_graph_builder.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_
  17. #define GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_
  18. #include <string>
  19. #include <vector>
  20. #include "graph/compute_graph.h"
  21. #include "graph/graph.h"
  22. #include "graph/node.h"
  23. namespace ge {
  24. namespace ut {
  25. class BufferPoolGraphBuilder {
  26. public:
  27. explicit BufferPoolGraphBuilder(const std::string &name = "BufferPoolGraph");
  28. ~BufferPoolGraphBuilder() {}
  29. class InnerGraphBuilder {
  30. public:
  31. explicit InnerGraphBuilder(const std::string &name);
  32. ~InnerGraphBuilder() {}
  33. NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt,
  34. Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT,
  35. std::vector<int64_t> shape = {1, 1, 224, 224});
  36. void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx);
  37. void AddControlEdge(NodePtr &src_node, NodePtr &dst_node);
  38. ComputeGraphPtr GetGraph() {
  39. graph_->TopologicalSorting();
  40. return graph_;
  41. }
  42. private:
  43. ComputeGraphPtr graph_;
  44. };
  45. ///
  46. /// Normal graph
  47. ///
  48. /// w1 w2 w3 w4 w5
  49. /// \ \ \ \ \.
  50. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  51. /// \ \ \ \ \.
  52. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  53. ///
  54. ///
  55. /// Memory distribution:
  56. ///
  57. /// |___w1__|__w2__|__w3__|__|
  58. ///
  59. /// |_____w4_____|_____w5____|
  60. ///
  61. ComputeGraphPtr BuildNormalGraph();
  62. ///
  63. /// Normal graph with multi buffer pool
  64. ///
  65. /// w1 w2 w3 w4 w5
  66. /// \ \ \ \ \.
  67. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  68. /// (pool0) (pool1) (pool0) (pool0) (pool1)
  69. /// \ \ \ \ \.
  70. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  71. ///
  72. ///
  73. /// Memory distribution:
  74. ///
  75. /// |___w1__|__w3__|_________|
  76. /// |_____w4_____|___________|
  77. ///
  78. /// |___w2__|_____w5___|_____|
  79. ///
  80. ComputeGraphPtr BuildNormalGraphWithMultiBufferPool();
  81. ///
  82. /// SerialGraph: Buffer pool size only can contain one prefetch node
  83. ///
  84. /// w1 w2 w3 w4 w5
  85. /// \ \ \ \ \.
  86. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  87. /// \ \ \ \ \.
  88. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ----- net_output
  89. ///
  90. ///
  91. /// Memory distribution:
  92. ///
  93. /// |____w1_____|__|
  94. ///
  95. /// |____w2_____|__|
  96. ///
  97. /// |____w3_____|__|
  98. ///
  99. /// |______w4______|
  100. ///
  101. /// |______w5______|
  102. ///
  103. ComputeGraphPtr BuildSerialGraph();
  104. ///
  105. /// GraphWithMultiPrefetch: Calc node with more prefetch node
  106. ///
  107. /// w1 w2 w3 w4 w5
  108. /// \ \ \ \ \.
  109. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 const1
  110. /// \ / \ / \ /
  111. /// \ / \ / \ /
  112. /// \ / \ / \ /
  113. /// add1 ------ c ------- add2 ----- c ----- add3
  114. /// | | |
  115. /// | | |
  116. /// --------------- net_output ------------
  117. ///
  118. /// Memory distribution:
  119. ///
  120. /// |___w1__|__w2__|__w3__|__|
  121. ///
  122. /// |_____w4_____|_____w5____|
  123. ///
  124. ComputeGraphPtr BuildGraphWithMultiPrefetch();
  125. ///
  126. /// GraphWithSubgraph: Calc node in different subgraph
  127. ///
  128. ///
  129. /// call_node1(with Subgraph1) --------------- call_node2 (with Subgraph2) --------------- net_output
  130. ///
  131. ///
  132. /// Subgraph1: Subgraph2:
  133. ///
  134. /// w1 w2 w3 w4 w5
  135. /// \ \ \ \ \.
  136. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  137. /// \ \ \ \ \.
  138. /// const1 ----- add1 ----- add2 ----- add3 ---- subgraph1_out data1 ---- add4 ----- add5 ---- subgraph2_out
  139. ///
  140. ///
  141. /// Memory distribution:
  142. ///
  143. /// |___w1__|__w2__|__w3__|__|
  144. ///
  145. /// |_____w4_____|_____w5____|
  146. ///
  147. ComputeGraphPtr BuildGraphWithSubgraph();
  148. ///
  149. /// SubgraphWithInnerDependency: Calc node in different subgraph with inner dependency
  150. ///
  151. ///
  152. /// call_node1(with Subgraph1) --------------------- call_node2 (with Subgraph2) ---------- net_output
  153. ///
  154. ///
  155. /// Subgraph1: Subgraph2:
  156. ///
  157. /// w1 w2 w3 w4 w5
  158. /// \ \ \ \ \.
  159. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  160. /// \ \ \ \ \.
  161. /// const1 ----- add1 ----- add2 ----- subgraph1_out data1 ---- add3 ---- add4 ----- add5 ---- subgraph2_out
  162. ///
  163. ///
  164. /// Memory distribution:
  165. ///
  166. /// |___w1__|__w2__|__w3__|__|
  167. ///
  168. /// |_____w4_____|_____w5____|
  169. ///
  170. ComputeGraphPtr BuildSubgraphWithInnerDependency();
  171. ///
  172. /// BuildGraphWithMultiBatch: Different batch label
  173. ///
  174. ///
  175. /// batch_label_128
  176. ///
  177. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
  178. /// / / / / / / \.
  179. /// /c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 \.
  180. /// const1 switch_false / / / / / \.
  181. /// \ / / / / / / \.
  182. /// switch1 w1 w2 w3 w4 w5 merge1 -- net_output
  183. /// / \ \ \ \ \ \ /
  184. /// const2 switch_true \ \ \ \ \ /
  185. /// \c prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 /
  186. /// \ \ \ \ \ \ /
  187. /// const1 ----- add1 ----- add2 ----- add3 ----- add4 ----- add5 ---
  188. ///
  189. /// batch_label_256
  190. ///
  191. ///
  192. /// Memory distribution:
  193. ///
  194. /// |___w1__|__w2__|__w3__|__|
  195. ///
  196. /// |_____w4_____|_____w5____|
  197. ///
  198. ComputeGraphPtr BuildGraphWithMultiBatch();
  199. ///
  200. /// GraphWithMultiOutputPrefetch: Prefetch has more than one output
  201. ///
  202. /// w1 w2 w3 w4 w5
  203. /// \ \ \ \ \.
  204. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  205. /// / \ / \ / \ / \ /
  206. /// / \ / \ / \ / \ /
  207. /// const1 ----- add1 add2 add3 add4 add5
  208. /// | \ | / |
  209. /// | \ | / |
  210. /// | \ | / |
  211. /// | \ | / |
  212. /// -------------- net_output ---------------
  213. ///
  214. /// Memory distribution:
  215. ///
  216. /// |___w1__|__w2__|__w3__|__|
  217. ///
  218. /// |_____w4_____|_____w5____|
  219. ///
  220. ComputeGraphPtr BuildGraphWithMultiOutputPrefetch();
  221. ///
  222. /// GraphWithMultiOutputPrefetch: Prefetch has more than one output
  223. ///
  224. /// w1 w2 w3 w4 w5
  225. /// \ / \ / \ / \ / \.
  226. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  227. /// / \ / \ / \ / \ /
  228. /// / \ / \ / \ / \ /
  229. /// const1 ----- add1 add2 add3 add4 add5
  230. /// | \ | / |
  231. /// | \ | / |
  232. /// | \ | / |
  233. /// | \ | / |
  234. /// -------------- net_output ---------------
  235. ///
  236. /// Memory distribution:
  237. ///
  238. /// |___w1__|__w2__|__w3__|__|
  239. ///
  240. /// |_____w4_____|_____w5____|
  241. ///
  242. ComputeGraphPtr BuildGraphWithMultiInputOutputPrefetch();
  243. void SetBufferPool(NodePtr &node, int64_t pool_id, int64_t pool_size, const std::string &batch_label = "");
  244. void SetBatchLabel(NodePtr &node, const std::string &batch_label = "");
  245. void SetOutputMemSize(NodePtr &node, const std::vector<int64_t> &mem_size = {1024});
  246. void SetWorkSpaceMemSize(NodePtr &node, const std::vector<int64_t> &ws_bytes = {1024});
  247. void SetPrefetchNodeInfo(NodePtr &node, int64_t pool_id, int64_t pool_size,
  248. const std::vector<int64_t> &mem_size = {1024},
  249. const std::vector<int64_t> &ws_bytes = {1024},
  250. const std::string &batch_label = "");
  251. private:
  252. std::string graph_name_;
  253. };
  254. } // namespace ut
  255. } // namespace ge
  256. #endif // GRAPH_UTILS_BUFFER_POOL_GRAPH_BUILDER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：