You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

buffer_pool_memory_pass.cc 28 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/buffer_pool_memory_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include "common/omg_util.h"
  20. #include "graph/utils/node_utils.h"
  21. #include "graph/utils/tensor_utils.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "common/math/math_util.h"
  24. namespace ge {
  namespace {
  // Number of data inputs a buffer pool node is required to have.
  constexpr size_t kBufferPoolNodeInSize = 1;
  // Number of data outputs a buffer pool node is required to have.
  constexpr size_t kBufferPoolNodeOutSize = 1;
  }  // namespace
  29. Status BufferPoolMemoryPass::Run(ComputeGraphPtr graph) {
  30. if (graph == nullptr) {
  31. GELOGE(PARAM_INVALID, "[Check][Graph]Graph is nullptr");
  32. REPORT_INNER_ERROR("E19999", "Input graph is nullptr");
  33. return PARAM_INVALID;
  34. }
  35. // The cache prefetching scheme is developed for very large models, which gets the weight data in advance
  36. // and allocates it to a special memory pool. When the large model is dynamic shape, it need to go through
  37. // the executor flow and is not allocated memory statically. This is another development point, so we will
  38. // skip the dynamic shape model processing here.
  39. if (graph->GetParentGraph() != nullptr || graph->GetGraphUnknownFlag()) {
  40. return SUCCESS;
  41. }
  42. if (!IsBufferPoolMemEnable(graph)) {
  43. GELOGD("[Check][Enable]Buffer pool memory is not enable, graph:%s.", graph->GetName().c_str());
  44. return SUCCESS;
  45. }
  46. Status ret = graph->TopologicalSorting();
  47. if (ret != SUCCESS) {
  48. GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
  49. REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
  50. return ret;
  51. }
  52. ret = CopyOutForMultiUsedOutput(graph);
  53. if (ret != SUCCESS) {
  54. GELOGE(FAILED, "[Copy][Output]Graph:%s.", graph->GetName().c_str());
  55. return FAILED;
  56. }
  57. ret = GetBufferPoolAndPeerCalcNodes(graph);
  58. if (ret != SUCCESS) {
  59. GELOGE(FAILED, "[Get][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
  60. return FAILED;
  61. }
  62. if (calc_nodes_.empty()) {
  63. GELOGE(FAILED, "[Check][BufferPoolNode]Graph:%s.", graph->GetName().c_str());
  64. REPORT_CALL_ERROR("E19999", "All Buffer pool nodes are isolated nodes in graph:%s.", graph->GetName().c_str());
  65. return FAILED;
  66. }
  67. ret = AllocateAllBufferPoolSpace();
  68. if (ret != SUCCESS) {
  69. GELOGE(FAILED, "[Alloc][BufferPoolMem]Graph:%s.", graph->GetName().c_str());
  70. return FAILED;
  71. }
  72. ret = SetResultOfMemoryAndEvent();
  73. if (ret != SUCCESS) {
  74. GELOGE(FAILED, "[Set][Result]Graph:%s.", graph->GetName().c_str());
  75. return FAILED;
  76. }
  77. ret = graph->TopologicalSorting();
  78. if (ret != SUCCESS) {
  79. GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
  80. REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
  81. return ret;
  82. }
  83. return SUCCESS;
  84. }
  85. void BufferPoolMemoryPass::ClearQueue(std::queue<std::pair<std::string, uint32_t>> &q) {
  86. while (!q.empty()) {
  87. q.pop();
  88. }
  89. }
  90. Status BufferPoolMemoryPass::IsBufferPoolMemEnable(const ComputeGraphPtr &graph) {
  91. for (NodePtr &node : graph->GetAllNodes()) {
  92. auto op_desc = node->GetOpDesc();
  93. if (op_desc == nullptr) {
  94. continue;
  95. }
  96. if (op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE)) {
  97. return true;
  98. }
  99. }
  100. return false;
  101. }
  102. Status BufferPoolMemoryPass::CheckBufferPoolSize(int64_t total_size, int64_t pool_id, int64_t buffer_pool_size,
  103. std::unordered_map<int64_t, int64_t> &calc_total_size) {
  104. auto iter = calc_total_size.find(pool_id);
  105. if (iter == calc_total_size.end()) {
  106. calc_total_size[pool_id] = total_size;
  107. } else {
  108. FMK_INT64_ADDCHECK(calc_total_size[pool_id], total_size);
  109. calc_total_size[pool_id] += total_size;
  110. }
  111. if (calc_total_size[pool_id] > buffer_pool_size) {
  112. GELOGE(INTERNAL_ERROR, "[Check][Size]The memory required at the same is greater than buffer pool size, "
  113. "pool id:%ld, pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]);
  114. REPORT_INNER_ERROR("E19999", "The memory required at the same is greater than buffer pool size, pool id:%ld,"
  115. " pool size:%ld, required size:%ld.", pool_id, buffer_pool_size, calc_total_size[pool_id]);
  116. return INTERNAL_ERROR;
  117. }
  118. return SUCCESS;
  119. }
  120. Status BufferPoolMemoryPass::TryToFixNodeOrder(NodePtr &pre_node, NodePtr &curr_node, bool &not_change) {
  121. auto pre_node_graph = pre_node->GetOwnerComputeGraph();
  122. auto curr_node_graph = curr_node->GetOwnerComputeGraph();
  123. std::string pre_node_stream_label;
  124. (void) AttrUtils::GetStr(pre_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, pre_node_stream_label);
  125. std::string curr_node_stream_label;
  126. (void) AttrUtils::GetStr(curr_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, curr_node_stream_label);
  127. not_change = true;
  128. if ((pre_node_graph == curr_node_graph) && (pre_node_stream_label == pre_node_stream_label)) {
  129. // Same subgraph, including simultaneously in the root graph.
  130. auto ret = ge::GraphUtils::AddEdge(pre_node->GetOutControlAnchor(), curr_node->GetInControlAnchor());
  131. if (ret != GRAPH_SUCCESS) {
  132. GELOGE(INTERNAL_ERROR, "[Add][Edge]Src:%s, dst:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
  133. REPORT_CALL_ERROR("E19999", "Failed to add ctrl edge from %s to %s.",
  134. pre_node->GetName().c_str(), curr_node->GetName().c_str());
  135. return INTERNAL_ERROR;
  136. }
  137. not_change = false;
  138. } else if (pre_node_graph->GetParentGraph() == curr_node_graph->GetParentGraph() &&
  139. pre_node_graph->GetParentNode() != nullptr && curr_node_graph->GetParentNode() != nullptr) {
  140. // Two nodes are located on different child graphs of different parent nodes.
  141. auto pre_node_parent_op_desc = pre_node_graph->GetParentNode()->GetOpDesc();
  142. auto curr_node_parent_op_desc = curr_node_graph->GetParentNode()->GetOpDesc();
  143. GE_CHECK_NOTNULL(pre_node_parent_op_desc);
  144. GE_CHECK_NOTNULL(curr_node_parent_op_desc);
  145. // The parent node dependency is correct to ensure that the child node dependency,
  146. // there is no need to add control edges.
  147. if (pre_node_parent_op_desc->GetId() > curr_node_parent_op_desc->GetId()) {
  148. GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
  149. pre_node->GetName().c_str(), curr_node->GetName().c_str());
  150. REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
  151. pre_node->GetName().c_str(), curr_node->GetName().c_str());
  152. return INTERNAL_ERROR;
  153. }
  154. GELOGI("[Check][Dependency]The two nodes are located in sub graphs of different parent nodes and meet the "
  155. "dependency relationship. pre:%s, curr:%s.", pre_node->GetName().c_str(), curr_node->GetName().c_str());
  156. } else {
  157. GELOGE(INTERNAL_ERROR, "[Check][Dependency]Invalid dependency, pre node:%s, curr node:%s.",
  158. pre_node->GetName().c_str(), curr_node->GetName().c_str());
  159. REPORT_INNER_ERROR("E19999", "Invalid dependency, pre node:%s, curr node:%s.",
  160. pre_node->GetName().c_str(), curr_node->GetName().c_str());
  161. return INTERNAL_ERROR;
  162. }
  163. return SUCCESS;
  164. }
  165. Status BufferPoolMemoryPass::InsertMemCpyNodeAfter(ComputeGraphPtr &graph, NodePtr &node) {
  166. auto out_anchor = node->GetOutDataAnchor(kBufferPoolNodeOutIndex);
  167. OpDescBuilder op_desc_builder(node->GetName() + "_memcpy_async", MEMCPYASYNC);
  168. auto mem_copy_op = op_desc_builder.AddInput("x", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex))
  169. .AddOutput("y", node->GetOpDesc()->GetOutputDesc(kBufferPoolNodeOutIndex))
  170. .Build();
  171. std::string batch_label;
  172. bool get_attr = AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, batch_label);
  173. if (get_attr && !batch_label.empty()) {
  174. (void) AttrUtils::SetStr(mem_copy_op, ATTR_NAME_STREAM_LABEL, batch_label);
  175. }
  176. auto peer_in_anchors = out_anchor->GetPeerInDataAnchors();
  177. std::vector<InDataAnchorPtr> in_anchors(peer_in_anchors.begin(), peer_in_anchors.end());
  178. if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(mem_copy_op)) != GRAPH_SUCCESS) {
  179. GELOGE(FAILED, "[Insert][Node] Node:%s.", node->GetName().c_str());
  180. REPORT_CALL_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
  181. return FAILED;
  182. }
  183. return SUCCESS;
  184. }
  185. Status BufferPoolMemoryPass::CopyOutForMultiUsedOutput(ComputeGraphPtr &graph) {
  186. bool changed = false;
  187. for (NodePtr &node : graph->GetAllNodes()) {
  188. auto op_desc = node->GetOpDesc();
  189. if (op_desc == nullptr) {
  190. continue;
  191. }
  192. bool use_buffer_pool = op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_ID) && op_desc->HasAttr(ATTR_NAME_BUFFER_POOL_SIZE);
  193. if (use_buffer_pool) {
  194. if ((node->GetInDataNodes().size() == kBufferPoolNodeInSize) &&
  195. (node->GetOutDataNodes().size() == kBufferPoolNodeOutSize)) {
  196. continue;
  197. } else if ((node->GetAllInDataAnchors().size() == kBufferPoolNodeInSize) &&
  198. (node->GetAllOutDataAnchors().size() == kBufferPoolNodeOutSize)) {
  199. // A prefetching output is used in multiple places. Copy one so that the prefetching node remains
  200. // single input and single output.
  201. if (InsertMemCpyNodeAfter(graph, node) != SUCCESS) {
  202. GELOGE(INTERNAL_ERROR, "[Insert][MemCpy]Node:%s.", node->GetName().c_str());
  203. REPORT_INNER_ERROR("E19999", "Failed to insert mem copy node after %s.", node->GetName().c_str());
  204. return INTERNAL_ERROR;
  205. }
  206. changed = true;
  207. GELOGI("[Insert][Node]Insert mem copy node after %s.", node->GetName().c_str());
  208. } else {
  209. GELOGE(PARAM_INVALID, "[Check][InputOutput]Only support single input and single output, "
  210. "node:%s.", node->GetName().c_str());
  211. REPORT_INNER_ERROR("E19999", "Only support single input and single output, node:%s.", node->GetName().c_str());
  212. return PARAM_INVALID;
  213. }
  214. }
  215. }
  216. if (changed) {
  217. Status ret = graph->TopologicalSorting();
  218. if (ret != SUCCESS) {
  219. GELOGE(ret, "[TopologicalSort][Graph]Graph name:%s.", graph->GetName().c_str());
  220. REPORT_CALL_ERROR("E19999", "Failed to topological sort for graph:%s.", graph->GetName().c_str());
  221. return ret;
  222. }
  223. }
  224. return SUCCESS;
  225. }
  226. Status BufferPoolMemoryPass::GetBufferPoolAndPeerCalcNodes(const ComputeGraphPtr &graph) {
  227. std::unordered_map<std::string, std::unordered_map<int64_t, std::set<NodePtr>>> unique_calc_nodes;
  228. for (const NodePtr &node : graph->GetAllNodes()) {
  229. auto in_data_nodes = node->GetInAllNodes();
  230. for (NodePtr &in_node : in_data_nodes) {
  231. int64_t buffer_pool_id = 0;
  232. int64_t buffer_pool_size = 0;
  233. bool get_attr = AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_ID, buffer_pool_id);
  234. get_attr = get_attr && (AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_BUFFER_POOL_SIZE, buffer_pool_size));
  235. if (get_attr) {
  236. std::string batch_label;
  237. (void) AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  238. peer_buffer_node_item_[batch_label][node].emplace_back(BufferPoolNodeItem(in_node, 0, 0));
  239. buffer_node_to_calc_[batch_label][in_node] = node;
  240. if (unique_calc_nodes[batch_label][buffer_pool_id].count(node) == 0) {
  241. calc_nodes_[batch_label][buffer_pool_id].emplace_back(node);
  242. unique_calc_nodes[batch_label][buffer_pool_id].insert(node);
  243. }
  244. GELOGI("[Get][BufferNode]Calc node:%s, pool node:%s.", node->GetName().c_str(), in_node->GetName().c_str());
  245. Status ret = SetBufferPoolSize(batch_label, buffer_pool_id, buffer_pool_size);
  246. if (ret != SUCCESS) {
  247. GELOGE(ret, "[Set][BufferPoolSize]Node:%s", in_node->GetName().c_str());
  248. REPORT_INNER_ERROR("E19999", "Failed to set buffer pool size, something wrong with the info of node:%s",
  249. in_node->GetName().c_str());
  250. return ret;
  251. }
  252. }
  253. }
  254. }
  255. return SUCCESS;
  256. }
  257. Status BufferPoolMemoryPass::SetBufferPoolSize(const std::string &batch_label, int64_t id, int64_t size) {
  258. auto iter = buffer_pool_size_[batch_label].find(id);
  259. if (iter != buffer_pool_size_[batch_label].end() && iter->second != size) {
  260. GELOGE(PARAM_INVALID, "[Check][BufferPoolSize]Get different size with the same id, "
  261. "id:%ld, original size:%ld, this size:%ld.", id, iter->second, size);
  262. REPORT_INNER_ERROR("E19999", "Get different size with the same id, "
  263. "id:%ld, original size:%ld, this size:%ld.", id, iter->second, size);
  264. return PARAM_INVALID;
  265. }
  266. buffer_pool_size_[batch_label][id] = size;
  267. return SUCCESS;
  268. }
  269. Status BufferPoolMemoryPass::AllocateAllBufferPoolSpace() {
  270. for (const auto &iter : calc_nodes_) {
  271. std::string batch_label = iter.first;
  272. Status ret = AllocateSpaceInBatch(calc_nodes_[batch_label],
  273. buffer_pool_size_[batch_label],
  274. buffer_node_to_calc_[batch_label],
  275. peer_buffer_node_item_[batch_label]);
  276. if (ret != SUCCESS) {
  277. GELOGE(ret, "[Alloc][InBatch]Batch_label:%s.", batch_label.c_str());
  278. REPORT_INNER_ERROR("E19999", "Failed to allocate space in batch, batch_label:%s.", batch_label.c_str());
  279. return ret;
  280. }
  281. GELOGI("[Alloc][InBatch]Alloc space in batch successfully, batch label:%s.", batch_label.c_str());
  282. }
  283. return SUCCESS;
  284. }
  285. Status BufferPoolMemoryPass::AllocateSpaceInBatch(
  286. const std::map<int64_t, std::vector<NodePtr>> &calc_nodes,
  287. const std::unordered_map<int64_t, int64_t> &buffer_pool_size_map,
  288. const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
  289. std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
  290. for (const auto &calc_node_in_pool : calc_nodes) {
  291. int64_t pool_id = calc_node_in_pool.first;
  292. int64_t buffer_pool_size = buffer_pool_size_map.at(pool_id);
  293. ClearQueue(mem_ctrl_event_);
  294. ClearQueue(stream_ctrl_event_);
  295. BufferPool buffer_pool(pool_id, buffer_pool_size, buffer_node_to_calc);
  296. Status ret = AllocateSpaceInBufferPool(buffer_pool,
  297. calc_node_in_pool.second,
  298. buffer_pool_nodes_item);
  299. if (ret != SUCCESS) {
  300. GELOGE(ret, "[Alloc][InBufferPool]Pool id:%ld, pool size:%ld.", pool_id, buffer_pool_size);
  301. REPORT_INNER_ERROR("E19999", "Failed to allocate space in buffer pool, id:%ld, pool size:%ld.",
  302. pool_id, buffer_pool_size);
  303. return ret;
  304. }
  305. GELOGI("[Alloc][InBufferPool]Alloc space in buffer pool successfully, pool id:%ld.", pool_id);
  306. }
  307. return SUCCESS;
  308. }
  309. Status BufferPoolMemoryPass::AllocateSpaceInBufferPool(
  310. const BufferPool &buffer_pool,
  311. const std::vector<NodePtr> &calc_nodes_in_pool,
  312. std::unordered_map<NodePtr, std::vector<BufferPoolNodeItem>> &buffer_pool_nodes_item) {
  313. int64_t pool_id = buffer_pool.pool_id;
  314. int64_t buffer_pool_size = buffer_pool.pool_size;
  315. int64_t next_start = 0;
  316. NodePtr pre_buffer_pool_node = nullptr;
  317. std::queue<BufferPoolNodeItem> node_mem_range_in_pool;
  318. node_mem_range_in_pool.push(BufferPoolMemoryPass::BufferPoolNodeItem(nullptr, 0, buffer_pool_size));
  319. for (auto &calc_node : calc_nodes_in_pool) {
  320. auto &peer_buffer_node_item = buffer_pool_nodes_item[calc_node];
  321. std::unordered_map<int64_t, int64_t> calc_total_size;
  322. size_t input_buffer_node_num = 0;
  323. for (auto &node_item : peer_buffer_node_item) {
  324. auto peer_buffer_node = node_item.node;
  325. GE_CHECK_NOTNULL(peer_buffer_node);
  326. int64_t total_size = 0;
  327. ++input_buffer_node_num;
  328. Status ret = GetMemorySize(peer_buffer_node, total_size);
  329. if (ret != SUCCESS) {
  330. GELOGE(ret, "[Get][MemSize]Node:%s, calc_node:%s.",
  331. peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
  332. REPORT_INNER_ERROR("E19999", "Failed to get memory size, node:%s, calc_node:%s.",
  333. peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
  334. return ret;
  335. }
  336. ret = CheckBufferPoolSize(total_size, pool_id, buffer_pool_size, calc_total_size);
  337. if (ret != SUCCESS) {
  338. GELOGE(ret, "[Check][BufferPoolSize]Capacity is not enough for all data, calc_node:%s.",
  339. calc_node->GetName().c_str());
  340. REPORT_INNER_ERROR("E19999", "Capacity is not enough for all data, calc_node:%s.",
  341. calc_node->GetName().c_str());
  342. return ret;
  343. }
  344. BufferPoolNodeItem buffer_pool_node_item(peer_buffer_node, calc_node, pre_buffer_pool_node, total_size,
  345. 0, 0, (input_buffer_node_num == peer_buffer_node_item.size()));
  346. ret = AllocateSpaceForBufferPoolNode(next_start, buffer_pool, buffer_pool_node_item, node_mem_range_in_pool);
  347. if (ret != SUCCESS) {
  348. GELOGE(ret, "[Alloc][ForNode]Pool node:%s, calc_node:%s.",
  349. peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
  350. REPORT_INNER_ERROR("E19999", "Failed to allocate space for buffer pool node:%s, calc_node:%s.",
  351. peer_buffer_node->GetName().c_str(), calc_node->GetName().c_str());
  352. return ret;
  353. }
  354. pre_buffer_pool_node = peer_buffer_node;
  355. }
  356. }
  357. return SUCCESS;
  358. }
  359. Status BufferPoolMemoryPass::AllocateSpaceForBufferPoolNode(int64_t &next_start,
  360. const BufferPool buffer_pool,
  361. BufferPoolNodeItem &buffer_pool_node_item,
  362. std::queue<BufferPoolNodeItem> &node_mem_range_in_pool) {
  363. // Get event id must be before FixTheTimingOfDependentNodes
  364. uint32_t logic_event = logic_event_num_;
  365. NodePtr buffer_node = buffer_pool_node_item.node;
  366. NodePtr calc_node = buffer_pool_node_item.out_calc_node;
  367. /// In the scenario where there are multiple PREFETCH operators in the inputs of the calculation operator,
  368. /// the addition of events is optimized to only add events after the last PREFETCH operator.
  369. /// w1 w2 w3 w4 w5
  370. /// | | | | |
  371. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5 xxx
  372. /// \ / \ / \ /
  373. /// \ / \ / \ /
  374. /// \ / \ / \ /
  375. /// node1 node2 node3
  376. /// | | |
  377. /// | | |
  378. /// --------------- other nodes ------------
  379. ///
  380. /// The event id of the PREFETCH operator to the calculation operator needs to be generated before
  381. /// FixTheTimingOfDependentNodes, because FixTheTimingOfDependentNodes may add a new id to stream_ctrl_event_,
  382. /// and this id cannot be reused until the next PREFETCH operator in the sequence.
  383. if (buffer_pool_node_item.is_last_input) {
  384. logic_event = GenerateEventId(buffer_node->GetName(), stream_ctrl_event_);
  385. node_event_multiplexing_[buffer_node].push_back(string("SendTo;" + calc_node->GetName() +
  386. ";" + std::to_string(logic_event)));
  387. mem_ctrl_event_.push(std::make_pair(calc_node->GetName(), logic_event));
  388. }
  389. NodePtr dependent_calc_node = GetOffsetAndDependency(next_start, buffer_pool_node_item.total_size,
  390. buffer_pool.pool_size,
  391. buffer_pool.buffer_node_to_calc,
  392. node_mem_range_in_pool);
  393. if (dependent_calc_node != nullptr) {
  394. Status ret = FixTheTimingOfDependentNodes(dependent_calc_node, buffer_node);
  395. if (ret != SUCCESS) {
  396. GELOGE(ret, "[Fix][Timing]Pool_id:%ld, pool node:%s, dependent node:%s.",
  397. buffer_pool.pool_id, buffer_node->GetName().c_str(), dependent_calc_node->GetName().c_str());
  398. REPORT_INNER_ERROR("E19999", "Failed to fix timing, pool_id:%ld, pool node:%s, dependent node:%s.",
  399. buffer_pool.pool_id, buffer_node->GetName().c_str(),
  400. dependent_calc_node->GetName().c_str());
  401. return ret;
  402. }
  403. }
  404. buffer_pool_node_item.offset_start = next_start;
  405. buffer_node_logical_offset_[buffer_node].push_back(buffer_pool_node_item.total_size);
  406. buffer_node_logical_offset_[buffer_node].push_back(next_start);
  407. FMK_INT64_ADDCHECK(next_start, buffer_pool_node_item.total_size);
  408. next_start += buffer_pool_node_item.total_size;
  409. buffer_pool_node_item.offset_end = next_start;
  410. node_mem_range_in_pool.push(buffer_pool_node_item);
  411. if (buffer_pool_node_item.pre_buffer_pool_node != nullptr) {
  412. bool not_change = true;
  413. auto ret = TryToFixNodeOrder(buffer_pool_node_item.pre_buffer_pool_node, buffer_node, not_change);
  414. if (ret != SUCCESS) {
  415. GELOGE(ret, "[Fix][BufferPoolNodeOrder]Pre node:%s, curr node:%s.",
  416. buffer_pool_node_item.pre_buffer_pool_node->GetName().c_str(), buffer_node->GetName().c_str());
  417. return ret;
  418. }
  419. }
  420. GELOGI("[Alloc][ForNode]Buffer pool node %s send to %s, offset start:%ld, send event id:%u.",
  421. buffer_node->GetName().c_str(), calc_node->GetName().c_str(),
  422. buffer_pool_node_item.offset_start, logic_event);
  423. return SUCCESS;
  424. }
  425. /// When generating the event ID, determine whether the name of the queue head node is the same as the name of
  426. /// the operator, in order to handle such scenarios:
  427. /// w1 w2 w3 w4 w5
  428. /// | | | | |
  429. /// prefetch1 prefetch2 prefetch3 prefetch4 prefetch5
  430. /// | | | | |
  431. /// node1 node2 node3 node4 node5
  432. ///
  433. /// Memory distribution:
  434. ///
  435. /// |____w1_____|__|
  436. ///
  437. /// |____w2_____|__|
  438. ///
  439. /// |____w3_____|__|
  440. ///
  441. /// |______w4______|
  442. ///
  443. /// |______w5______|
  444. ///
  445. /// In this scenario, prefetch2 depends on node1. If the dependency is handled by adding an event of node1 to prefetch2,
  446. /// the id sent by prefetch2 will be the same as the id it receives.Although Runtime supports this through WaitReset,
  447. /// we consider this a dangerous operation and avoid it.
  448. uint32_t BufferPoolMemoryPass::GenerateEventId(const std::string &node_name,
  449. std::queue<std::pair<std::string, uint32_t>> &event_queue) {
  450. uint32_t logic_event = logic_event_num_;
  451. if (!event_queue.empty()) {
  452. auto item = event_queue.front();
  453. if (item.first != node_name) {
  454. logic_event = item.second;
  455. event_queue.pop();
  456. return logic_event;
  457. }
  458. }
  459. ++logic_event_num_;
  460. return logic_event;
  461. }
  462. NodePtr BufferPoolMemoryPass::GetOffsetAndDependency(int64_t &next_start,
  463. int64_t total_mem_size,
  464. int64_t buffer_pool_size,
  465. const std::unordered_map<NodePtr, NodePtr> &buffer_node_to_calc,
  466. std::queue<BufferPoolMemoryPass::BufferPoolNodeItem> &nodes_in_buffer) {
  467. // The buffer pool can no longer fit this Tensor and needs to turn back.
  468. if (next_start + total_mem_size > buffer_pool_size) {
  469. next_start = 0;
  470. if (!nodes_in_buffer.empty()) {
  471. // Take up the rest of the space at the end,
  472. nodes_in_buffer.back().offset_end = buffer_pool_size;
  473. // Pop the first tensor memory in the previous round of the previous round.
  474. nodes_in_buffer.pop();
  475. }
  476. while (!nodes_in_buffer.empty()) {
  477. auto node_item = nodes_in_buffer.front();
  478. // Go to the begin of previous round.
  479. if (node_item.offset_start == 0) {
  480. break;
  481. }
  482. nodes_in_buffer.pop();
  483. }
  484. }
  485. while (!nodes_in_buffer.empty()) {
  486. auto node_item = nodes_in_buffer.front();
  487. if (next_start + total_mem_size <= node_item.offset_end) {
  488. auto pool_node = node_item.node;
  489. if (pool_node == nullptr) {
  490. return nullptr;
  491. }
  492. auto output_calc = buffer_node_to_calc.find(pool_node);
  493. if (output_calc != buffer_node_to_calc.end()) {
  494. return output_calc->second;
  495. }
  496. return nullptr;
  497. }
  498. nodes_in_buffer.pop();
  499. }
  500. return nullptr;
  501. }
  502. Status BufferPoolMemoryPass::FixTheTimingOfDependentNodes(NodePtr &dependent_calc_node, NodePtr &curr_pool_node) {
  503. // The previous process ensures that all pointers are not null.
  504. bool not_change = false;
  505. Status ret = TryToFixNodeOrder(dependent_calc_node, curr_pool_node, not_change);
  506. if (ret != SUCCESS) {
  507. GELOGE(ret, "[Fix][NodeOrder]Src:%s, dst:%s.",
  508. dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str());
  509. return ret;
  510. }
  511. if (not_change) {
  512. return SUCCESS;
  513. }
  514. uint32_t logic_event = GenerateEventId(dependent_calc_node->GetName(), mem_ctrl_event_);
  515. node_event_multiplexing_[curr_pool_node].push_back(string("RecvFrom;" + dependent_calc_node->GetName() +
  516. ";" + std::to_string(logic_event)));
  517. stream_ctrl_event_.push(std::make_pair(curr_pool_node->GetName(), logic_event));
  518. GELOGI("[Fix][Timing]Add ctrl edge for buffer pool memory from %s to %s, buffer pool node recv event:%u.",
  519. dependent_calc_node->GetName().c_str(), curr_pool_node->GetName().c_str(), logic_event);
  520. return SUCCESS;
  521. }
  522. Status BufferPoolMemoryPass::SetResultOfMemoryAndEvent() {
  523. for (auto &iter : node_event_multiplexing_) {
  524. auto node = iter.first;
  525. GE_CHECK_NOTNULL(node);
  526. auto op_desc = node->GetOpDesc();
  527. GE_CHECK_NOTNULL(op_desc);
  528. bool ret = AttrUtils::SetListStr(op_desc, ATTR_NAME_EVENT_MULTIPLEXING, iter.second);
  529. if (!ret) {
  530. GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
  531. REPORT_CALL_ERROR("E19999", "Failed to set event reuse info, node:%s.", node->GetName().c_str());
  532. return INTERNAL_ERROR;
  533. }
  534. auto offset_iter = buffer_node_logical_offset_.find(node);
  535. if (offset_iter == buffer_node_logical_offset_.end()) {
  536. GELOGE(INTERNAL_ERROR, "[Get][LogicalOffset]Node:%s.", node->GetName().c_str());
  537. REPORT_INNER_ERROR("E19999", "Failed to get logical offset and size, node:%s.", node->GetName().c_str());
  538. return INTERNAL_ERROR;
  539. }
  540. ret = AttrUtils::SetListInt(op_desc, ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET, offset_iter->second);
  541. if (!ret) {
  542. GELOGE(INTERNAL_ERROR, "[Set][Attr]Node:%s.", node->GetName().c_str());
  543. REPORT_CALL_ERROR("E19999", "Failed to set node memory offset and size, node:%s.", node->GetName().c_str());
  544. return INTERNAL_ERROR;
  545. }
  546. }
  547. return SUCCESS;
  548. }
  549. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: