You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

block_mem_assigner.h 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_
  17. #define GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_
  #include <cstddef>
  #include <cstdint>
  #include <list>
  #include <map>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <unordered_set>
  #include <utility>
  #include <vector>

  #include "framework/common/ge_inner_error_codes.h"
  #include "framework/common/types.h"
  #include "framework/common/util.h"
  #include "graph/build/memory/mem_assigner.h"
  #include "graph/compute_graph.h"
  #include "graph/utils/graph_utils.h"
  31. namespace ge {
  /// Default value of NodeTypeIndex::life_time_end.
  constexpr size_t kMaxLifeTime = 0xffffffff;

  /// Default thread scope id; MemoryBlock::AddNodeTypeIndex does not record this value.
  constexpr int32_t kInvalidThreadScopeId = -1;

  /// Bit 32 (the first bit above kMemoryTypeMask).
  /// NOTE(review): presumably OR-ed onto a memory type to mark session-scope memory - confirm against the users.
  constexpr uint64_t kSessionScopeMemory = 0x100000000;

  /// Mask selecting the low 32 bits of a 64-bit memory type value.
  constexpr uint64_t kMemoryTypeMask = 0xffffffff;

  /// Reuse scope of a memory block. NOTE(review): exact semantics live in the implementation file - confirm.
  enum MemoryNoReuseScope { kReuse, kSessionNoReuse, kGraphNoReuse };

  /// Two-level map from int64_t keys to a size_t life time.
  /// NOTE(review): a later comment in this header suggests [node id] -> [stream id] -> [node id] - confirm.
  using DependStreamLife = std::map<int64_t, std::map<int64_t, size_t>>;

  /// Kind of memory a NodeTypeIndex refers to: an output of a node or one of its workspaces.
  enum OpMemoryType { kOutput, kWorkspace };
  39. struct NodeTypeIndex {
  40. NodeTypeIndex(ge::NodePtr node, OpMemoryType mem_type, uint32_t index, bool ref_input = false, size_t begin = 0,
  41. int32_t thread_scope_id = kInvalidThreadScopeId)
  42. : node(std::move(node)), mem_type(mem_type), index(index), ref_input(ref_input), life_time_begin(begin),
  43. thread_scope_id(thread_scope_id) {}
  44. ge::NodePtr node = nullptr;
  45. OpMemoryType mem_type = kOutput;
  46. uint32_t index = 0;
  47. bool ref_input = false;
  48. size_t life_time_begin = 0;
  49. size_t life_time_end = kMaxLifeTime;
  50. int32_t thread_scope_id = kInvalidThreadScopeId;
  51. const string GetMemType() const {
  52. if (mem_type == kOutput) {
  53. return "output";
  54. } else if (mem_type == kWorkspace) {
  55. return "workspace";
  56. }
  57. return "unknown";
  58. }
  59. size_t GetLifeBegin() const {
  60. if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
  61. return 0;
  62. }
  63. if ((life_time_begin > 0) && (life_time_begin < static_cast<size_t>(node->GetOpDesc()->GetId()))) {
  64. return life_time_begin;
  65. } else {
  66. return node->GetOpDesc()->GetId();
  67. }
  68. }
  69. std::string GetLifeBeginDesc() const {
  70. if (node == nullptr) {
  71. return "";
  72. }
  73. auto node_op_desc = node->GetOpDesc();
  74. if (node_op_desc != nullptr) {
  75. auto life_begin = GetLifeBegin();
  76. if (life_begin != static_cast<size_t>(node_op_desc->GetId())) {
  77. return std::to_string(life_begin) + "-" + std::to_string(node_op_desc->GetId());
  78. } else {
  79. return std::to_string(node_op_desc->GetId());
  80. }
  81. }
  82. return "";
  83. }
  84. };
  85. class MemoryBlock {
  86. public:
  87. explicit MemoryBlock(size_t block_size, int64_t stream_id = 0, bool reuse_mem = true,
  88. int64_t memory_type = RT_MEMORY_HBM)
  89. : ref_count_(0),
  90. stream_id_(stream_id),
  91. deleted_block_(false),
  92. reuse_mem_(reuse_mem),
  93. same_stream_(true),
  94. input_index_(0),
  95. continuous_block_(false),
  96. first_continuous_block_(false),
  97. last_continuous_block_(false),
  98. is_zero_copy_(false),
  99. memory_type_(memory_type),
  100. block_size_(block_size),
  101. head_offset_(0),
  102. tail_offset_(0),
  103. child_offset_(0) {}
  104. MemoryBlock(const MemoryBlock &) = delete;
  105. MemoryBlock &operator=(const MemoryBlock &) = delete;
  106. ~MemoryBlock() {
  107. node_type_index_list_.clear();
  108. symbol_list_.clear();
  109. }
  110. size_t Size() const { return block_size_; }
  111. void SetSize(size_t size) {
  112. if (size > block_size_) {
  113. block_size_ = size;
  114. }
  115. }
  116. size_t AlignSize() const;
  117. void SetHeadOffset(size_t offset);
  118. void SetTailOffset(size_t offset);
  119. size_t HeadOffset() const { return head_offset_; }
  120. size_t TailOffset() const { return tail_offset_; }
  121. void AddNodeTypeIndex(const NodeTypeIndex &node_type_index, size_t real_size, size_t no_align_size) {
  122. node_type_index_list_.emplace_back(node_type_index);
  123. real_size_list_.emplace_back(real_size);
  124. no_align_size_list_.emplace_back(no_align_size);
  125. if ((node_type_index.node != nullptr) && (node_type_index.node->GetOpDesc() != nullptr)) {
  126. auto stream_id = node_type_index.node->GetOpDesc()->GetStreamId();
  127. if (stream_id != stream_id_) {
  128. same_stream_ = false;
  129. }
  130. }
  131. if (node_type_index.thread_scope_id != kInvalidThreadScopeId) {
  132. thread_scope_id_.insert(node_type_index.thread_scope_id);
  133. }
  134. }
  135. void AddSymbol(const std::string &symbol) {
  136. symbol_list_.emplace_back(symbol);
  137. }
  138. const std::vector<NodeTypeIndex> &NodeTypeIndexList() const { return node_type_index_list_; }
  139. const std::vector<std::string> &SymbolList() const { return symbol_list_; }
  140. const std::vector<size_t> &RealSizeList() const { return real_size_list_; }
  141. const std::vector<MemoryBlock *> &ChildBlockList() const { return child_blocks_; }
  142. const std::vector<size_t> &NoAlignSizeList() const { return no_align_size_list_; }
  143. const std::set<int32_t> &ThreadScopeId() const { return thread_scope_id_; }
  144. void Resize();
  145. std::string String();
  146. bool IsSameBatchLabel();
  147. void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life);
  148. void AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &node_depend_stream_life);
  149. void SetLifeTimeEnd(size_t time);
  150. size_t GetLifeBegin();
  151. size_t GetLifeEnd() const;
  152. void AddDependLifeBegin(DependStreamLife &node_depend_stream_life);
  153. size_t GetDependLifeBegin(int64_t stream_id, DependStreamLife &node_depend_stream_life);
  154. bool CanReuse(int32_t thread_scope_id) const;
  155. int ref_count_;
  156. int64_t stream_id_;
  157. bool deleted_block_;
  158. bool reuse_mem_;
  159. bool same_stream_;
  160. uint32_t input_index_;
  161. bool continuous_block_;
  162. bool first_continuous_block_;
  163. bool last_continuous_block_;
  164. bool is_zero_copy_;
  165. std::map<int64_t, size_t> depend_stream_life_;
  166. int64_t memory_type_;
  167. std::string batch_label_;
  168. private:
  169. size_t block_size_;
  170. std::vector<size_t> real_size_list_;
  171. std::vector<size_t> no_align_size_list_;
  172. size_t head_offset_;
  173. size_t tail_offset_;
  174. size_t child_offset_;
  175. std::vector<NodeTypeIndex> node_type_index_list_;
  176. std::vector<std::string> symbol_list_;
  177. std::vector<MemoryBlock *> child_blocks_;
  178. std::set<int32_t> thread_scope_id_;
  179. };
  180. class BlockMemAssigner : public MemAssigner {
  181. public:
  182. BlockMemAssigner(ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol,
  183. const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors);
  184. BlockMemAssigner(const BlockMemAssigner &) = delete;
  185. BlockMemAssigner &operator=(const BlockMemAssigner &) = delete;
  186. ~BlockMemAssigner() override;
  187. Status Assign() override;
  188. const std::map<uint64_t, size_t> &GetMemOffsets() const { return mem_offsets_; }
  189. int64_t GetAtomicAddrCleanId() const { return atomic_addr_clean_id_; }
  190. std::vector<MemoryBlock *> GetMemoryBlocks() const { return memory_blocks_; }
  191. ///
  192. /// @ingroup domi
  193. /// @brief memory size fixed for reuse. get memory range
  194. /// @param [out] ranges return memory range
  195. /// @return Status result
  196. ///
  197. virtual Status GetMemoryRanges(std::vector<int64_t> &ranges) = 0;
  198. ///
  199. /// @ingroup domi
  200. /// @brief traverse all nodes' outputs and needed workspace mem, apply memory, consider reuse memory
  201. /// @param [in] ranges memory range provided
  202. /// @author
  203. ///
  204. void AssignMemoryWithReuse(std::vector<int64_t> &ranges);
  205. void SetOpMemOffset(bool is_zero_copy);
  206. std::string GetMaxBatchLabel() const { return max_batch_label_; }
  207. protected:
  208. ///
  209. /// @ingroup domi
  210. /// @brief traverse all memory size, resize, and calculate offset
  211. /// @param [in&out] memory_blocks memory size, resize and calculate memory address after offset
  212. ///
  213. void ResizeMemoryBlocks();
  214. void GetOutAndWorkSpaceMem(std::vector<int64_t> &all_memory_size);
  215. void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector<int64_t> &workspace_memory, int64_t &total_size);
  216. ///
  217. /// @ingroup GE
  218. /// @brief Determine whether it is the type of zero memory node.
  219. /// @param [in] node type.
  220. /// @return bool true: is zero memory node; false: is not zero memory node
  221. /// @author
  222. ///
  223. bool CheckIsZeroMemNodeType(const std::string &node_type) const;
  224. ///
  225. /// @ingroup GE
  226. /// @brief Check pre_reuse flag & post_reuse glag for each symbol
  227. /// @return void
  228. ///
  229. void InitReuseFlag();
  230. ///
  231. /// @ingroup GE
  232. /// @brief get pre_reuse flag
  233. /// @param [in] node
  234. /// @param [in] out_index
  235. /// @return bool
  236. ///
  237. bool IsPreReuse(const NodePtr &node, uint32_t out_index) const;
  238. ///
  239. /// @ingroup GE
  240. /// @brief get post_reuse flag
  241. /// @param [in] mem_block
  242. /// @return bool
  243. ///
  244. bool IsPostReuse(const MemoryBlock *mem_block) const;
  245. ///
  246. /// @ingroup GE
  247. /// @brief check if symbol of cur node_index_io has block
  248. /// @param [in] node_index_io
  249. /// @param [out] symbol
  250. /// @return bool
  251. ///
  252. bool IsSymbolExist(const NodeIndexIO &node_index_io, std::string &symbol);
  253. ///
  254. /// @ingroup GE
  255. /// @brief Print symbol
  256. /// @return void
  257. ///
  258. void PrintSymbolMap();
  259. ///
  260. /// @ingroup GE
  261. /// @brief Get the memory type corresponding to the current symbol.
  262. /// @param [in] node_index_io_list
  263. /// @param [out] memory_type
  264. /// @return void
  265. ///
  266. void GetSymbolMemType(std::list<NodeIndexIO> node_index_io_list, int64_t &memory_type);
  267. ///
  268. /// @ingroup GE
  269. /// @brief Update input tensor or output tensor of op to new memory type attr.
  270. /// @param [in] node_index_io_list
  271. /// @param [in] memory_type
  272. /// @return void
  273. ///
  274. void UpdateOpTensorMemType(std::list<NodeIndexIO> node_index_io_list, int64_t memory_type);
  275. std::map<uint64_t, size_t> mem_offsets_;
  276. ge::ComputeGraphPtr compute_graph_;
  277. std::vector<MemoryBlock *> memory_blocks_;
  278. std::vector<MemoryBlock *> blocks_store_;
  279. std::vector<NodeTypeIndex> zero_memory_list_;
  280. // ref mapping
  281. const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors_;
  282. const std::map<std::string, std::string> &anchor_to_symbol_;
  283. std::map<std::string, bool> pre_reuse_flag_;
  284. std::map<std::string, bool> post_reuse_flag_;
  285. std::map<std::string, size_t> symbol_size_;
  286. std::map<std::string, int64_t> symbol_to_mem_type_;
  287. private:
  288. ///
  289. /// @ingroup GE
  290. /// @brief Traversing the compute_graph_ to apply for output memory while considering reuse
  291. /// @param [in] n: node in compute_graph_
  292. /// @param [in] index: output node index
  293. /// @param [in] ranges: available memory specifications
  294. /// @param [in] is_op_reuse_mem: Whether the op reuses the memory, true: reuse; false: not reuse
  295. /// @param [in] continuous: Whether the op uses continuous memory
  296. /// @return MemoryBlock*
  297. /// @author
  298. ///
  299. MemoryBlock *ApplyOutMemory(const ge::NodePtr &n, uint32_t index, const std::vector<int64_t> &ranges,
  300. const bool is_op_reuse_mem, const bool continuous);
  301. Status AssignOutputMemoryWithReuse(const NodePtr &node, vector<int64_t> &ranges);
  302. ///
  303. /// @ingroup GE
  304. /// @brief Traversing the compute_graph_ to apply for memory while considering reuse
  305. /// @param [in] block_size applied memory block size
  306. /// @param [in] real_size actual memory size required
  307. /// @param [in] type output or workspace
  308. /// @param [in] n node in compute_graph_
  309. /// @param [in] out_index output node index
  310. /// @param [in] workspace_reuse_flag reuse flag for workspace
  311. /// @param [in] is_op_reuse_mem whether the op reuses memory
  312. /// @param [in] continuous whether the memory of op is continuous
  313. /// @param [in] memory_type device memory type
  314. /// @return MemoryBlock*
  315. /// @author
  316. ///
  317. MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, OpMemoryType mem_type,
  318. const ge::NodePtr &n, uint32_t out_index, const std::vector<bool> &workspace_reuse_flag,
  319. const bool is_op_reuse_mem, const bool continuous, uint64_t memory_type);
  320. ///
  321. /// @ingroup GE
  322. /// @brief check workspace_reuse_flag to judge if add workspace block wait reuse
  323. /// @param [in] workspace_reuse_flag mark out index if support resue
  324. /// @param [in] index out index
  325. /// @param [in] stream_id which stream op in
  326. /// @param [in] mem_block node workspace mem_block
  327. /// @param [in] memory_type workspace memory type
  328. /// @return void
  329. /// @author
  330. ///
  331. void CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id,
  332. MemoryBlock *mem_block, uint64_t memory_type);
  333. ///
  334. /// @ingroup GE
  335. /// @brief Release memory block to reusable list
  336. /// @param [in] to_release memory block to be released
  337. /// @param [in] reusable_memory reusable list
  338. /// @return void
  339. /// @author
  340. ///
  341. void ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory, bool same_stream = true);
  342. ///
  343. /// @ingroup GE
  344. /// @brief Release memory blocks to reusable list
  345. /// @param [in] to_releases memory blocks to be released
  346. /// @param [in] reusable_memory reusable list
  347. /// @return void
  348. /// @author
  349. ///
  350. void ReleaseMemorys(const vector<MemoryBlock *> &to_releases, vector<MemoryBlock *> &reusable_memory);
  351. ///
  352. /// @ingroup GE
  353. /// @brief Release memory block to reusable list
  354. /// @param [in] n node in compute_graph_
  355. /// @param [in] node_out_blocks output memory blocks for ops
  356. /// @param [in] reusable_memory reusable list
  357. /// @return void
  358. /// @author
  359. ///
  360. void ReleaseInputNodeOutMemory(const std::unordered_map<string, vector<MemoryBlock *>> &node_out_blocks,
  361. vector<MemoryBlock *> &reusable_memory, ge::NodePtr &n);
  362. ///
  363. /// @ingroup GE
  364. /// @brief Resize memory blocks for each batchs
  365. /// @return merge or not
  366. /// @author
  367. ///
  368. void ResizeDynamicBatchBlocks();
  369. void AssignContinuousBlocks();
  370. bool IsZeroCopyBlock(const NodePtr &node, bool continuous);
  371. bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name,
  372. uint32_t &peer_input_index, bool &no_need_assign_memory, bool &reset_zero_copy_flag);
  373. bool IsContinuousMemoryReuse(const NodePtr &n, const NodePtr &peer_node, uint32_t out_index);
  374. ///
  375. /// @ingroup GE
  376. /// @|+++++++++block1++++++++| |+++++++++block1++++++++|
  377. /// @|+++++++++block1++++++++||++block2++| |+++++++++block1++++++++||++block2++|
  378. /// @ |++block2++||++block3++| ==> |++block3++| |++block2++|
  379. /// @ |++block3++| |++block3++|
  380. /// @return void
  381. /// @author
  382. ///
  383. void ReuseBlocksByLifeTime(size_t range_size);
  384. bool IsContinuousOutput(const NodePtr &n);
  385. bool GetWorkSpaceMemoryType(const NodePtr &node, size_t index, uint64_t &memory_type,
  386. vector<bool> &workspace_reuse_flag);
  387. void ContinuousOutRefCheck(bool &isAllOutputRef, bool &isOutputHasRef, const NodePtr &n);
  388. Status ApplyContinuousMemory(const NodePtr &n, const vector<int64_t> &ranges, const bool is_op_reuse_mem);
  389. void MarkContinuousAllocedForOneInputFromVariable(const NodePtr &node);
  390. void CheckAndReleaseSuspendedBlock(const NodePtr &node, uint32_t idx, MemoryBlock *block);
  391. std::unordered_map<int64_t, std::unordered_map<int64_t, std::vector<MemoryBlock *>>> reusable_blocks_;
  392. std::unordered_map<int64_t, std::unordered_map<int64_t, std::vector<MemoryBlock *>>> stream_workspace_blocks_;
  393. std::unordered_map<std::string, std::vector<MemoryBlock *>> node_out_blocks_;
  394. std::unordered_map<std::string, MemoryBlock *> symbol_blocks_;
  395. std::unordered_map<std::string, std::unordered_map<uint32_t, MemoryBlock *>> node_continuous_input_blocks_;
  396. std::map<std::string, uint32_t> node_continuous_input_counts_;
  397. // reuse memory
  398. vector<string> op_no_reuse_mem_vec_;
  399. bool op_reuse_env_valid_ = false;
  400. std::string ge_disable_reuse_mem_env_ = "0";
  401. bool is_op_reuse_mem_ = true;
  402. size_t life_time_;
  403. int64_t atomic_addr_clean_id_ = 0;
  404. size_t theory_min_memory_size_ = 0;
  405. size_t theory_memory_size_ = 0;
  406. std::string max_batch_label_;
  407. size_t continuous_life_begin_ = 0;
  408. ///
  409. /// @ [stream1][nodeid]
  410. /// @[nodeid] [stream2][nodeid]
  411. /// @ [stream2][nodeid]
  412. ///
  413. DependStreamLife total_node_depend_stream_life_;
  414. bool root_unknown_shape_flag_ = false;
  415. };
  416. } // namespace ge
  417. #endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。