block_mem_assigner.cc
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/build/memory/block_mem_assigner.h"
#include <algorithm>
#include <sstream>

#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/buffer.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_context.h"
#include "graph/node.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/optimize/common/params.h"
#include "omg/omg_inner_types.h"
#include "runtime/mem.h"

namespace {
const char *const kAttrNameWorkspaceReuseFlag = "workspace_reuse_flag";
const char *const kL2FusionDynamicConvergeOp = "l2fusion_dynamic_converge_op";
const char *const kOpNoReuseMem = "no_reuse_mem_flag";
const char *const kDisableReuseMemory = "ge.exec.disableReuseMemory";
const char *const OP_NO_REUSE_MEM = "OP_NO_REUSE_MEM";
const int kReuseMaxCount = 10;
const int kReuseMaxOpNum = 10;
const int kReuseMaxCharNum = 2000;
}  // namespace

namespace ge {
using std::map;
using std::pair;
using std::string;
using std::stringstream;
using std::unordered_map;
using std::unordered_set;
using std::vector;
void MemoryBlock::SetHeadOffset(size_t offset) {
  head_offset_ = offset;
  size_t child_offset = head_offset_;
  for (auto block : child_blocks_) {
    if (block != nullptr) {
      block->SetHeadOffset(child_offset);
      child_offset += block->Size();
    }
  }
}

void MemoryBlock::SetTailOffset(size_t offset) {
  tail_offset_ = offset;
  size_t child_offset = head_offset_;
  for (auto block : child_blocks_) {
    if (block != nullptr) {
      child_offset += block->Size();
      block->SetTailOffset(child_offset - 1);
    }
  }
}

void MemoryBlock::Resize() {
  size_t child_block_size = 0;
  for (auto block : child_blocks_) {
    if (block != nullptr) {
      block->Resize();
      child_block_size += block->Size();
    }
  }
  auto iter = std::max_element(real_size_list_.begin(), real_size_list_.end());
  if (iter == real_size_list_.end()) {
    GELOGW("real_size_list_ is empty");
    return;
  } else {
    size_t block_size = (child_block_size > *iter) ? child_block_size : *iter;
    if ((block_size > 0) && (block_size % MEM_ALIGN_SIZE != 0)) {
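      // Round up to the next multiple of MEM_ALIGN_SIZE. Illustrative example,
      // assuming MEM_ALIGN_SIZE is 512: 600 -> (600 + 511) / 512 * 512 = 1024.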
      block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE;
    }
    block_size_ = block_size;
    if (last_continuous_block_) {
      block_size_ += MEM_ALIGN_SIZE;
    }
  }
}

bool MemoryBlock::IsSameLabel(std::string &first_batch_label) {
  if (node_type_index_list_.empty()) {
    return false;
  }
  auto node_op_desc = node_type_index_list_[0].node->GetOpDesc();
  if (node_op_desc == nullptr) {
    return false;
  }
  // Not every op has ATTR_NAME_BATCH_LABEL, so the return value does not need
  // checking; only the out parameter matters.
  (void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, first_batch_label);
  if (first_batch_label.empty()) {
    return false;
  }
  bool all_same_label = true;
  for (size_t index = 1; index < node_type_index_list_.size(); ++index) {
    if (node_type_index_list_[index].node == nullptr) {
      continue;
    }
    std::string batch_label;
    auto index_op_desc = node_type_index_list_[index].node->GetOpDesc();
    GE_IF_BOOL_EXEC(index_op_desc == nullptr, continue);
    (void)ge::AttrUtils::GetStr(index_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
    if (first_batch_label != batch_label) {
      all_same_label = false;
      break;
    }
  }
  return all_same_label;
}

bool CanNotLifeReuse(MemoryBlock *block) {
  if (block == nullptr || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_ ||
      block->GetLifeEnd() == kMaxLifeTime) {
    return true;
  }
  return false;
}

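// Life-time reuse: when one block's first use begins only after another
// block's last use on the same stream, the later block can live inside the
// earlier one's memory. Illustrative timeline (op ids are hypothetical):
//   A: size 2048, live over ops [3, 7]
//   B: size  512, live over ops [9, 12]
// B fits inside A and becomes a child block of A, so B costs no extra memory.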
void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block) {
  if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) {
    return;
  }
  MemoryBlock *parent = nullptr;
  MemoryBlock *child = nullptr;
  // Merge the smaller block into the larger one.
  if ((block->GetLifeBegin() > GetLifeEnd()) && (block->stream_id_ == stream_id_)) {
    if ((child_offset_ + block->block_size_) <= block_size_) {
      parent = this;
      child = block;
    } else if ((block->child_offset_ + block_size_) <= block->block_size_) {
      parent = block;
      child = this;
    }
  }
  if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) {
    parent->child_blocks_.emplace_back(child);
    parent->child_offset_ += child->block_size_;
    child->deleted_block_ = true;
    GELOGI(
        "Add block[%p size:%zu, stream id:%ld life time[begin:%zu, end:%zu]] to"
        " block[%p size:%zu, stream id:%ld, life time[begin:%zu, end:%zu]]",
        child, child->block_size_, child->stream_id_, child->GetLifeBegin(), child->GetLifeEnd(), parent,
        parent->block_size_, parent->stream_id_, parent->GetLifeBegin(), parent->GetLifeEnd());
  }
}

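// Life times are expressed as topological op ids: GetLifeBegin() is the id of
// the first op that uses the block, GetLifeEnd() is the recorded id of the
// last op that uses it (kMaxLifeTime while the block is still live).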
size_t MemoryBlock::GetLifeBegin() {
  size_t life_time = 0;
  if (!node_type_index_list_.empty()) {
    if (node_type_index_list_.front().node != nullptr) {
      auto node_op_desc = node_type_index_list_.front().node->GetOpDesc();
      if (node_op_desc != nullptr) {
        life_time = node_op_desc->GetId();
      }
    }
  }
  return life_time;
}

size_t MemoryBlock::GetLifeEnd() {
  if (!node_type_index_list_.empty()) {
    return node_type_index_list_.back().life_time_end;
  }
  return kMaxLifeTime;
}

void MemoryBlock::SetLifeTimeEnd(size_t time) {
  if (!node_type_index_list_.empty()) {
    node_type_index_list_.back().life_time_end = time;
  }
}

void SetLastUsedInputMemAttr(NodePtr &node, int input_index) {
  if (node == nullptr) {
    return;
  }
  auto node_op_desc = node->GetOpDesc();
  if (node_op_desc != nullptr) {
    auto input_desc = node_op_desc->GetInputDesc(input_index);
    if (!ge::AttrUtils::SetInt(input_desc, ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE, true)) {
      GELOGW("Set %s input[%d] ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE to true failed.", node_op_desc->GetName().c_str(),
             input_index);
      return;
    }
    GELOGD("Set %s input[%d] ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE to true success.", node_op_desc->GetName().c_str(),
           input_index);
    if (node_op_desc->UpdateInputDesc(input_index, input_desc) != GRAPH_SUCCESS) {
      GELOGW("Update %s input[%d] desc failed.", node_op_desc->GetName().c_str(), input_index);
    }
  }
}

Status GetNoAlignSize(const ge::OpDesc &desc, uint32_t index, size_t &size) {
  // Calculate the tensor's real size.
  auto output_op_desc = desc.GetOutputDescPtr(index);
  if (output_op_desc == nullptr) {
    GELOGI("GetNoAlignSize failed. OpName: %s, OpType: %s, index: %u", desc.GetName().c_str(), desc.GetType().c_str(),
           index);
    return FAILED;
  }
  int64_t tensor_size = 0;
  GeShape shape = output_op_desc->GetShape();
  Format format = output_op_desc->GetFormat();
  DataType data_type = output_op_desc->GetDataType();
  graphStatus graph_status = TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size);
  if (graph_status != GRAPH_SUCCESS) {
    GELOGE(graph_status, "CalcTensorMemSize failed!");
    return FAILED;
  }
  size = static_cast<size_t>(tensor_size);
  return SUCCESS;
}

string ToString(ge::NodeTypeIndex &x) {
  stringstream ss;
  ss << "[" << x.node->GetName() << "(" << x.node->GetType() << "), ";
  if (x.mem_type == kOutput) {
    ss << "Output, ";
  } else {
    ss << "Workspace, ";
  }
  ss << x.index << "]";
  return ss.str();
}

string MemoryBlock::String() {
  stringstream ss;
  ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << " ";
  ss << "real_size_list: " << ToString(real_size_list_) << " ";
  ss << "ref_count: " << ref_count_ << " ";
  ss << "members: ";
  for (auto x : NodeTypeIndexList()) {
    ss << "__node: " << ToString(x) << " ";
  }
  for (const auto &symbol : SymbolList()) {
    ss << "__symbol: " << symbol << " ";
  }
  return ss.str();
}

BlockMemAssigner::BlockMemAssigner(ge::ComputeGraphPtr compute_graph)
    : mem_offset_(0), compute_graph_(std::move(compute_graph)), life_time_(0) {}

BlockMemAssigner::~BlockMemAssigner() {
  for (MemoryBlock *memory_block : memory_blocks_) {
    GE_DELETE_NEW_SINGLE(memory_block);
  }
}

void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) {
  if (GraphUtils::GetRefMapping(compute_graph_, symbol_to_anchors_, anchor_to_symbol_) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Get ref-mapping for graph %s failed.", compute_graph_->GetName().c_str());
    return;
  }
  vector<int64_t> temp;
  for (const NodePtr &n : compute_graph_->GetAllNodes()) {
    auto node_op_desc = n->GetOpDesc();
    GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
    if (node_op_desc->GetType() == ATOMICADDRCLEAN) {
      atomic_addr_clean_id_ = node_op_desc->GetId();
    }
    for (auto &out_anchor : n->GetAllOutDataAnchors()) {
      GeTensorDesc output_desc = node_op_desc->GetOutputDesc(out_anchor->GetIdx());
      bool reuse_input = false;
      GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(output_desc, reuse_input) != SUCCESS,
                      GELOGI("Get reuse_input failed"));
      if (!reuse_input) {
        int64_t size = 0;
        GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed"));
        if (anchor_to_symbol_.empty()) {
          all_memory_size.emplace_back(size);
        } else {
          auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString());
          if (iter1 == anchor_to_symbol_.end()) {
            continue;
          }
          std::string symbol = iter1->second;
          auto iter2 = symbol_size_.find(symbol);
          if (iter2 == symbol_size_.end()) {
            symbol_size_[symbol] = size;
          } else if (size > static_cast<int64_t>(iter2->second)) {
            iter2->second = size;
          }
        }
      }
    }
    temp.clear();
    GetNodeWorkSpaceSize(n, temp);
    all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end());
  }
  GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_);
  for (auto &pair : symbol_size_) {
    all_memory_size.emplace_back(pair.second);
  }
  sort(all_memory_size.begin(), all_memory_size.end());
  GELOGI("All memory size: %s", ToString(all_memory_size).c_str());
  for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) {
    if (*iter == 0) {
      iter = all_memory_size.erase(iter);
    } else {
      ++iter;
    }
  }
  InitReuseFlag();
  PrintSymbolMap();
}

///
/// @ingroup domi
/// @brief decide memory size based on actual input memory size
/// @param [in] size actual memory size in need
/// @param [in] ranges memory size provided
/// @return size_t memory size to apply
///
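// Example (ranges is sorted ascending by GetOutAndWorkSpaceMem; values below
// are illustrative): size = 600, ranges = {256, 1024, 4096} -> returns 1024.
// If no range is large enough, a warning is logged and 0 is returned.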
size_t GetBlockSize(size_t size, const vector<int64_t> &ranges) {
  for (int64_t x : ranges) {
    auto x_temp = static_cast<size_t>(x);
    if (size <= x_temp) {
      return x_temp;
    }
  }
  GELOGW("Memory needed size:%zu is beyond the biggest block in memory ranges.", size);
  return 0;
}

bool IsDirectOutputNode(const NodePtr &node, int idx) {
  if ((node != nullptr) && (node->GetOpDesc() != nullptr) && (node->GetOpDesc()->GetType() == NETOUTPUT)) {
    GELOGI("This is netoutput node, the input node mem can not be reused");
    return true;
  }
  return false;
}

void AddReusableBlockCount(const MemoryBlock &mem_block, map<string, uint64_t> &reusable_block_counts) {
  string key = std::to_string(mem_block.Size());
  key += "_" + std::to_string(mem_block.stream_id_);
  auto it = reusable_block_counts.find(key);
  if (it != reusable_block_counts.end()) {
    it->second++;
  } else {
    reusable_block_counts[key] = 1;
  }
}

void ReduceReusableBlockCount(const MemoryBlock &mem_block, map<string, uint64_t> &reusable_block_counts) {
  string key = std::to_string(mem_block.Size());
  key += "_" + std::to_string(mem_block.stream_id_);
  auto it = reusable_block_counts.find(key);
  if (it != reusable_block_counts.end()) {
    if (it->second > 0) {
      it->second--;
    }
  }
}

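// Reuse-by-size policy: a free block is reused when its size matches exactly;
// once more than kReuseMaxCount free blocks of one size/stream have piled up,
// a larger free block may also serve a smaller request. A continuous-input
// allocation never reuses another continuous block, and mixed cases are only
// allowed when real_size does not conflict with the block's recorded maximum.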
bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const MemoryBlock &reusable_block,
                    size_t block_size, size_t real_size, bool continuous, int64_t atomic_addr_clean_id) {
  bool can_reuse = false;
  // If the node comes before the atomic_addr_clean node, continuous memory can't be reused.
  if (!reusable_block.NodeTypeIndexList().empty()) {
    auto node = reusable_block.NodeTypeIndexList()[0].node;
    if (node != nullptr) {
      auto op_desc = node->GetOpDesc();
      if (op_desc != nullptr) {
        if ((op_desc->GetId() < atomic_addr_clean_id) && continuous) {
          return false;
        }
      }
    }
  }
  // Continuous memory case: only the maximum real_size can be reused, and one
  // block holds at most one continuous allocation.
  if (continuous || reusable_block.continuous_block_) {
    auto it =
        std::max_element(std::begin(reusable_block.NoAlignSizeList()), std::end(reusable_block.NoAlignSizeList()));
    if (it != std::end(reusable_block.NoAlignSizeList())) {
      GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) ||
                          (reusable_block.continuous_block_ && (real_size > *it)),
                      GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d",
                             real_size, continuous, *it, reusable_block.continuous_block_);
                      return false;);
    }
  }
  if (reusable_block.Size() == block_size) {
    can_reuse = true;
  } else {
    string key = std::to_string(reusable_block.Size());
    key += "_" + std::to_string(reusable_block.stream_id_);
    auto it = reusable_block_counts.find(key);
    GE_IF_BOOL_EXEC(
        (it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && (reusable_block.Size() > block_size),
        can_reuse = true;
        GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", reusable_block.Size(), block_size););
  }
  return can_reuse;
}

bool CanReuseByStream(const std::unordered_set<int64_t> &reuse_stream, MemoryBlock &reusable_block) {
  bool can_reuse = false;
  if (reuse_stream.find(reusable_block.stream_id_) != reuse_stream.cend()) {
    can_reuse = true;
  }
  return can_reuse;
}

bool BlockMemAssigner::IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name,
                                                   uint32_t &peer_input_index) {
  if (n == nullptr || n->GetAllOutDataAnchors().size() <= 0) {
    return false;
  }
  if (static_cast<size_t>(out_index) < n->GetAllOutDataAnchors().size()) {
    auto out_anchor = n->GetOutDataAnchor(out_index);
    GE_IF_BOOL_EXEC(out_anchor == nullptr,
                    GELOGE(FAILED, "Node[%s] output[%u] anchor is null.", n->GetName().c_str(), out_index);
                    return false;);
    for (auto const &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
      GE_IF_BOOL_EXEC(peer_in_anchor == nullptr,
                      GELOGE(FAILED, "Node[%s] output[%u] peer_in_anchor 0 is null.", n->GetName().c_str(), out_index);
                      return false;);
      auto peer_node = peer_in_anchor->GetOwnerNode();
      GE_IF_BOOL_EXEC(peer_node == nullptr,
                      GELOGE(FAILED, "Node[%s] output[%u] node is null.", n->GetName().c_str(), out_index);
                      return false;);
      // Get the continuous-input type of the node; the default is false.
      bool is_input_continuous = false;
      auto peer_in_node_desc = peer_node->GetOpDesc();
      GE_IF_BOOL_EXEC(peer_in_node_desc == nullptr,
                      GELOGE(FAILED, "Node[%s] output[%u] nodedesc is null.", n->GetName().c_str(), out_index);
                      return false;);
      // If GetBool fails, is_input_continuous stays false.
      (void)ge::AttrUtils::GetBool(peer_in_node_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous);
      if (is_input_continuous) {
        if (n->GetOwnerComputeGraph() != nullptr) {
          string graph_name = n->GetOwnerComputeGraph()->GetName();
          GELOGI("%s name[%s] output[%u] node[%s] set input[%d] continuous, input size[%u].", graph_name.c_str(),
                 n->GetName().c_str(), out_index, peer_in_node_desc->GetName().c_str(), peer_in_anchor->GetIdx(),
                 peer_node->GetAllInDataAnchorsSize());
          // Only set the attribute once.
          if (node_continuous_input_blocks_[peer_in_node_desc->GetName()].size() == 0) {
            (void)ge::AttrUtils::SetBool(peer_in_node_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true);
            node_continuous_input_counts_[peer_in_node_desc->GetName()] = peer_node->GetAllInDataAnchorsSize();
          }
          peer_input_index = peer_in_anchor->GetIdx();
          peer_name = peer_in_node_desc->GetName();
          return true;
        }
      }
    }
  }
  return false;
}

///
/// @ingroup GE
/// @brief Check pre_reuse flag & post_reuse flag for each symbol
/// @return void
///
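// pre_reuse: whether this symbol's output may take over an earlier free block;
// post_reuse: whether its block may be handed on after the symbol's last use.
// Data/Const-like and directly-output tensors never pre-reuse; Enter and
// NextIteration-like loop ops never post-reuse.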
void BlockMemAssigner::InitReuseFlag() {
  static const std::set<std::string> kPreReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, ge::ANN_DATA_TYPE,
                                                       ge::NETOUTPUT, ge::PROPOSAL, ge::ZEROSLIKE,
                                                       ge::CONSTANT, ge::CONSTANTOP};
  static const std::set<std::string> kPostReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, ge::ENTER,
                                                        ge::REFENTER, ge::NEXTITERATION, ge::REFNEXTITERATION};
  for (auto &pair : symbol_to_anchors_) {
    std::string symbol = pair.first;
    bool pre_reuse_flag = true;
    bool post_reuse_flag = true;
    for (auto &node_index_io : pair.second) {
      if (node_index_io.io_type_ == kIn) {
        continue;
      }
      OutDataAnchorPtr out_anchor = node_index_io.node_->GetOutDataAnchor(node_index_io.index_);
      if (out_anchor == nullptr) {
        continue;
      }
      bool out_flg = false;
      if (node_index_io.node_->GetOutDataNodes().empty()) {
        out_flg = true;
      }
      for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) {
        if (IsDirectOutputNode(in_anchor->GetOwnerNode(), in_anchor->GetIdx())) {
          out_flg = true;
          break;
        }
      }
      std::string type = out_anchor->GetOwnerNode()->GetType();
      pre_reuse_flag = pre_reuse_flag && !out_flg && (kPreReuseTypes.count(type) == 0);
      post_reuse_flag = post_reuse_flag && (kPostReuseTypes.count(type) == 0);
      if (!pre_reuse_flag && !post_reuse_flag) {
        break;
      }
    }
    pre_reuse_flag_[symbol] = pre_reuse_flag;
    post_reuse_flag_[symbol] = post_reuse_flag;
  }
}

///
/// @ingroup GE
/// @brief get pre_reuse flag
/// @param [in] node
/// @param [in] out_index
/// @return bool
///
bool BlockMemAssigner::IsPreReuse(const NodePtr &node, uint32_t out_index) const {
  OutDataAnchorPtr out_data_anchor = nullptr;
  if (static_cast<size_t>(out_index) < node->GetAllOutDataAnchors().size()) {
    out_data_anchor = node->GetOutDataAnchor(out_index);
  }
  if (out_data_anchor == nullptr) {
    return false;
  }
  NodeIndexIO cur_node_index_io(out_data_anchor->GetOwnerNode(), out_data_anchor->GetIdx(), kOut);
  auto iter1 = anchor_to_symbol_.find(cur_node_index_io.ToString());
  if (iter1 == anchor_to_symbol_.end()) {
    return false;
  }
  std::string symbol = iter1->second;
  auto iter2 = pre_reuse_flag_.find(symbol);
  if (iter2 == pre_reuse_flag_.end()) {
    return false;
  }
  return iter2->second;
}

///
/// @ingroup GE
/// @brief get post_reuse flag
/// @param [in] mem_block
/// @return bool
///
bool BlockMemAssigner::IsPostReuse(const MemoryBlock *mem_block) const {
  if (mem_block == nullptr) {
    return false;
  }
  for (auto &symbol : mem_block->SymbolList()) {
    auto iter = post_reuse_flag_.find(symbol);
    if (iter == post_reuse_flag_.end()) {
      continue;
    }
    if (!iter->second) {
      return false;
    }
  }
  return true;
}

///
/// @ingroup GE
/// @brief check if symbol of cur node_index_io has block
/// @param [in] node_index_io
/// @return bool
///
bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) {
  auto iter = anchor_to_symbol_.find(node_index_io.ToString());
  if (iter == anchor_to_symbol_.end()) {
    return false;
  }
  std::string symbol = iter->second;
  return symbol_blocks_.find(symbol) != symbol_blocks_.end();
}

///
/// @ingroup GE
/// @brief Print symbol
/// @return void
///
void BlockMemAssigner::PrintSymbolMap() {
  for (auto &pair : symbol_to_anchors_) {
    GELOGD("symbol=%s, max_size=%zu, pre_reuse=%s, post_reuse=%s", pair.first.c_str(), symbol_size_[pair.first],
           pre_reuse_flag_[pair.first] ? "true" : "false", post_reuse_flag_[pair.first] ? "true" : "false");
    for (auto &node_index_io : pair.second) {
      GELOGD("anchor:%s", node_index_io.ToString().c_str());
    }
  }
}

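// Memory application flow: first try to take over an existing free block that
// passes the post-reuse, size and stream checks; only otherwise allocate a
// fresh MemoryBlock. Data and NetOutput outputs get zero-copy blocks, which
// ResizeMemoryBlocks later skips when assigning offsets.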
MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size,
                                           MemoryType mem_type, const NodePtr &n, uint32_t out_index,
                                           const vector<bool> &workspace_reuse_flag, const bool is_op_reuse_mem,
                                           const bool continuous) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null.");
  auto node_op_desc = n->GetOpDesc();
  GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr);
  bool is_reuse_memory = false;
  string ge_disable_reuse_mem_env = "0";
  (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env);
  if (ge_disable_reuse_mem_env != "1") {
    bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]);
    is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) &&
                      reuse_mem_flag && is_op_reuse_mem && (IsPreReuse(n, out_index));
    auto stream_id = node_op_desc->GetStreamId();
    auto map_iter = reusable_streams_map_.find(stream_id);
    if (is_reuse_memory && map_iter != reusable_streams_map_.end()) {
      for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) {
        MemoryBlock *reusable_block = *it;
        if (!IsPostReuse(reusable_block)) {
          reusable_block->reuse_mem_ = false;
          GELOGI("Unreusable block.");
          continue;
        }
        // A node can reuse blocks of its own stream and of preceding streams.
        auto id = GetAtomicAddrCleanId();
        if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous, id) &&
            CanReuseByStream(map_iter->second, *reusable_block)) {
          GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_,
                 stream_id);
          reusable_block->AddNodeTypeIndex({n, mem_type, out_index, false}, real_size, no_align_size);
          if (mem_type == kOutput) {
            auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString());
            if (iter != anchor_to_symbol_.end()) {
              reusable_block->AddSymbol(iter->second);
            }
          }
          reusable_block->continuous_block_ = continuous;
          reusable_block->ref_count_++;
          ReduceReusableBlockCount(*reusable_block, reusable_block_counts_);
          reusable_blocks_.erase(it);
          return reusable_block;
        }
      }
    }
  }
  auto block = new (std::nothrow) MemoryBlock(block_size, node_op_desc->GetStreamId(), is_reuse_memory);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed.");
  // Data and NetOutput need a zero-copy block.
  if ((node_op_desc->GetType() == DATA_TYPE && !continuous) || (node_op_desc->GetType() == NETOUTPUT)) {
    block->is_zero_copy_ = true;
  }
  block->Init(real_size, mem_type, n, out_index, no_align_size);
  block->stream_id_ = node_op_desc->GetStreamId();
  block->ref_count_++;
  block->continuous_block_ = continuous;
  if (mem_type == kOutput) {
    auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString());
    if (iter != anchor_to_symbol_.end()) {
      block->AddSymbol(iter->second);
    }
  }
  memory_blocks_.emplace_back(block);
  return block;
}

MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, const vector<int64_t> &ranges,
                                              const bool is_op_reuse_mem, const bool continuous) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null.");
  auto node_op_desc = n->GetOpDesc();
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node_op_desc == nullptr, return nullptr, "node_op_desc is null.");
  MemoryBlock *block = nullptr;
  NodeIndexIO node_index_io(n, index, kOut);
  int64_t size = 0;
  auto output_op_desc = node_op_desc->GetOutputDescPtr(index);
  if (output_op_desc != nullptr) {
    GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed"));
  }
  size_t no_align_size = 0;
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetNoAlignSize(*node_op_desc, index, no_align_size) != SUCCESS, return nullptr,
                                 "Get no align size failed");
  if (IsSymbolExist(node_index_io)) {
    std::string symbol = anchor_to_symbol_[node_index_io.ToString()];
    block = symbol_blocks_[symbol];
    block->AddNodeTypeIndex({n, kOutput, index, true}, size, no_align_size);
    block->ref_count_++;
  } else {
    int64_t max_size = size;
    auto iter1 = anchor_to_symbol_.find(node_index_io.ToString());
    if (iter1 != anchor_to_symbol_.end()) {
      auto iter2 = symbol_size_.find(iter1->second);
      if (iter2 != symbol_size_.end()) {
        max_size = iter2->second;
      }
    }
    auto block_size = GetBlockSize(max_size, ranges);
    vector<bool> workspace_reuse_flag;
    block = ApplyMemory(block_size, size, no_align_size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem,
                        continuous);
  }
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr.");
  int out_count_reuse_input = block->ref_count_;
  int out_count = 0;
  GE_IF_BOOL_EXEC(index >= n->GetAllOutDataAnchors().size(), GELOGE(FAILED, "index is out of range."); return nullptr);
  auto out_data_anchor = n->GetOutDataAnchor(index);
  GE_IF_BOOL_EXEC(out_data_anchor == nullptr, GELOGE(FAILED, "Out data anchor is nullptr."); return nullptr);
  for (const auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
    auto owner_node = in_anchor->GetOwnerNode();
    auto op_desc = owner_node->GetOpDesc();
    GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
    Params *instance = Params::Instance();
    GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(instance == nullptr, return nullptr, "Params instance is nullptr.");
    if (!((instance->GetTarget() == TARGET_TYPE_TINY) && (op_desc->GetType() == NETOUTPUT))) {
      out_count++;
    }
  }
  bool reuse_input = false;
  for (const auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
    auto owner_node = in_anchor->GetOwnerNode();
    GE_IF_BOOL_EXEC(owner_node == nullptr, continue);
    auto op_desc = owner_node->GetOpDesc();
    GE_IF_BOOL_EXEC(op_desc == nullptr, continue);
    for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) {
      bool dst_reuse_input = false;
      uint32_t dst_reuse_input_index = 0;
      auto owner_node_op_desc = op_desc->GetOutputDescPtr(i);
      GE_IF_BOOL_EXEC(owner_node_op_desc == nullptr, continue);
      GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(*owner_node_op_desc, dst_reuse_input) != SUCCESS,
                      GELOGI("Get dst_reuse_input failed"));
      GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*owner_node_op_desc, dst_reuse_input_index) != SUCCESS,
                      GELOGI("Get dst_reuse_input_index failed"));
      if (dst_reuse_input && (dst_reuse_input_index == static_cast<uint32_t>(in_anchor->GetIdx()))) {
        block->AddNodeTypeIndex({owner_node, kOutput, i, true}, block->Size(), block->Size());
        out_count_reuse_input += 1;
        reuse_input = true;
      }
    }
  }
  block->ref_count_ = reuse_input ? out_count_reuse_input + out_count - 1 : out_count;
  return block;
}

bool IsOutputBlock(const ge::InDataAnchorPtr &in_data_anchor) {
  auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, GELOGE(FAILED, "Peer out anchor is nullptr."); return false);
  auto src = peer_out_anchor->GetOwnerNode();
  int32_t index = peer_out_anchor->GetIdx();
  auto iter = domi::GetContext().out_nodes_map.find(src->GetName());
  if (iter != domi::GetContext().out_nodes_map.end()) {
    for (auto id : iter->second) {
      if (index == id) {
        return true;
      }
    }
  }
  return false;
}

// Atomic output memory will be reassigned later, so it is not allocated here.
bool IsAtomicOutputMemory(const ge::NodePtr &node, uint32_t output_index, bool is_atomic,
                          bool out_node_set_continuous_input) {
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    return false;
  }
  vector<int64_t> atomic_output_index;
  // If GetListInt fails, atomic_output_index stays empty.
  (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index);
  if (!out_node_set_continuous_input && is_atomic) {
    for (auto &index : atomic_output_index) {
      if (static_cast<uint32_t>(index) == output_index) {
        if (node->GetOwnerComputeGraph() != nullptr) {
          string graph_name = node->GetOwnerComputeGraph()->GetName();
          GELOGD("[IMAS]Atomic no assign %s name[%s] output[%ld] streamid[%ld].", graph_name.c_str(),
                 op_desc->GetName().c_str(), index, op_desc->GetStreamId());
        }
        return true;
      }
    }
  }
  return false;
}

void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock *> &reusable_memory) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null.");
  GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory");
  GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory");
  --to_release->ref_count_;
  if (to_release->ref_count_ == 0) {
    to_release->SetLifeTimeEnd(life_time_);
    reusable_memory.emplace_back(to_release);
    AddReusableBlockCount(*to_release, reusable_block_counts_);
  }
}

void BlockMemAssigner::ReleaseMemorys(const vector<MemoryBlock *> &to_releases,
                                      vector<MemoryBlock *> &reusable_memory) {
  for (auto mem_block : to_releases) {
    ReleaseMemory(mem_block, reusable_memory);
  }
}

void BlockMemAssigner::ReleaseInputNodeOutMemory(const unordered_map<string, vector<MemoryBlock *>> &node_out_blocks,
                                                 vector<MemoryBlock *> &reusable_memory, NodePtr &node) {
  for (const auto &in_anchor : node->GetAllInDataAnchors()) {
    if ((in_anchor->GetPeerOutAnchor() == nullptr) ||
        (in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc() == nullptr) || (node->GetOpDesc() == nullptr)) {
      return;
    }
    GE_IF_BOOL_EXEC(IsOutputBlock(in_anchor), continue);
    auto node_name = in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName();
    GE_IF_BOOL_EXEC((in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetType() == CONSTANT) ||
                        (in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetType() == FASTRCNNPREDICTIONS) ||
                        (in_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetType() == CONSTANTOP),
                    continue);
    auto it = node_out_blocks.find(node_name);
    if (it == node_out_blocks.end()) {
      continue;
    }
    for (auto block : it->second) {
      const vector<NodeTypeIndex> &node_type_indexs = block->NodeTypeIndexList();
      if (node_type_indexs.empty()) {
        continue;
      }
      GELOGD("node_type_indexs: %d, %s", node_type_indexs.back().index,
             node_type_indexs.back().node->GetName().c_str());
      if ((node_type_indexs.back().node == in_anchor->GetPeerOutAnchor()->GetOwnerNode()) &&
          (node_type_indexs.back().index == static_cast<uint32_t>(in_anchor->GetPeerOutAnchor()->GetIdx())) &&
          (node->GetOpDesc()->GetStreamId() == block->stream_id_)) {
        ReleaseMemory(block, reusable_memory);
        if (block->ref_count_ == 0) {
          SetLastUsedInputMemAttr(node, in_anchor->GetIdx());
        }
      }
    }
  }
}

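// Example: "Conv2D,Relu,,Conv2D" -> {"Conv2D", "Relu"}; empty fields are
// skipped and duplicates are dropped.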
void SplitStringByComma(const string &str, vector<string> &sub_str_vec) {
  std::string tmp_string = str + ",";
  std::string::size_type start_pos = 0;
  std::string::size_type cur_pos = tmp_string.find(',', 0);
  while (cur_pos != std::string::npos) {
    std::string sub_str = tmp_string.substr(start_pos, cur_pos - start_pos);
    if (!sub_str.empty()) {
      vector<string>::iterator ret = std::find(sub_str_vec.begin(), sub_str_vec.end(), sub_str);
      if (ret == sub_str_vec.end()) {
        sub_str_vec.push_back(sub_str);
      }
    }
    start_pos = cur_pos + 1;
    cur_pos = tmp_string.find(',', start_pos);
  }
}

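// OP_NO_REUSE_MEM is a comma-separated list of op names and/or op types whose
// memory must not be reused, capped at kReuseMaxOpNum entries and
// kReuseMaxCharNum characters, e.g. OP_NO_REUSE_MEM="Conv2D,my_node_name"
// (the value shown is illustrative).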
void CheckAndGetOpReuseEnv(const string &env, vector<string> &env_vec, bool &op_reuse_env_valid) {
  string env_str;
  env_str = string(env);
  if (env_str.size() > kReuseMaxCharNum) {
    GELOGE(FAILED, "The OP_NO_REUSE_MEM has more than %d characters.", kReuseMaxCharNum);
    return;
  }
  SplitStringByComma(env_str, env_vec);
  if (env_vec.size() > kReuseMaxOpNum) {
    GELOGE(FAILED, "The OP_NO_REUSE_MEM has more than %d nodes.", kReuseMaxOpNum);
    return;
  }
  op_reuse_env_valid = true;
  return;
}

Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector<int64_t> &ranges) {
  auto op_desc = node->GetOpDesc();
  int64_t stream_id = op_desc->GetStreamId();
  vector<int64_t> memorys_type;
  bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memorys_type);
  GELOGI("Assign memory node[%s], output size[%zu], output memory type size[%zu]", op_desc->GetName().c_str(),
         op_desc->GetOutputsSize(), memorys_type.size());
  if (has_mem_type_attr && (memorys_type.size() != op_desc->GetOutputsSize())) {
    GELOGE(INTERNAL_ERROR, "fusion: node[%s], output memory size err[outputsize:%zu, memorysize:%zu]",
           op_desc->GetName().c_str(), op_desc->GetOutputsSize(), memorys_type.size());
    return INTERNAL_ERROR;
  }
  is_op_reuse_mem_ = true;
  if (op_reuse_env_valid_ == true) {
    vector<string>::iterator it_name =
        std::find(op_no_reuse_mem_vec_.begin(), op_no_reuse_mem_vec_.end(), op_desc->GetName());
    vector<string>::iterator it_type =
        std::find(op_no_reuse_mem_vec_.begin(), op_no_reuse_mem_vec_.end(), op_desc->GetType());
    GE_IF_BOOL_EXEC(it_name != op_no_reuse_mem_vec_.end() || it_type != op_no_reuse_mem_vec_.end(),
                    is_op_reuse_mem_ = false;);
  }
  bool is_atomic = false;
  // If GetBool fails, is_atomic stays false.
  (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic);
  // Allocate memory for the current node and release same-size node memory in the workspace.
  GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1",
                  ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_);)
  for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) {
    int64_t size = 0;
    auto output_op_desc = op_desc->GetOutputDescPtr(i);
    if (output_op_desc != nullptr) {
      GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed"));
    }
    // fusion: a size of another memory type does not mean HBM memory is allocated.
    bool l1_flag = has_mem_type_attr && memorys_type[i] == RT_MEMORY_L1;
    if (l1_flag) {
      GELOGI("fusion: node[%s], output[%s], output memory type [%ld]", op_desc->GetName().c_str(),
             op_desc->GetOutputNameByIndex(i).c_str(), memorys_type[i]);
      size = 0;
    }
    std::string peer_name;
    uint32_t peer_input_index = 0;
    bool out_node_set_continuous_input = false;
    bool no_need_assign_memory = ((size == 0) || CheckIsZeroMemNodeType(node->GetType()));
    if (!no_need_assign_memory) {
      out_node_set_continuous_input = IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index);
      no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input);
    }
    if (no_need_assign_memory) {
      zero_memory_list_.emplace_back(node, kOutput, i, false);
      continue;
    }
    // Atomic memory can't be reused.
    if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) {
      is_op_reuse_mem_ = false;
    }
    MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, is_op_reuse_mem_, out_node_set_continuous_input);
    if (mem_block != nullptr) {
      node_out_blocks_[node->GetName()].emplace_back(mem_block);
      if (out_node_set_continuous_input) {
        node_continuous_input_blocks_[peer_name][peer_input_index] = mem_block;
      }
      NodeIndexIO node_index_io(node, i, kOut);
      auto iter = anchor_to_symbol_.find(node_index_io.ToString());
      if (iter == anchor_to_symbol_.end()) {
        continue;
      }
      symbol_blocks_[iter->second] = mem_block;
    }
  }
  return SUCCESS;
}

///
/// @ingroup domi
/// @brief traverse all nodes' outputs and workspaces in need, apply memory blocks considering memory reuse
/// @param [in/out] ranges memory size provided
/// @return void
///
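// Per node, in topological order: assign/reuse output blocks, then workspace
// blocks, then release this node's input blocks back to the reusable pool.
// Afterwards dynamic-batch branches are merged, life-time reuse is applied,
// continuous blocks are reordered and all block offsets are resized.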
void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) {
  // Init reusable streams map
  InitReusableStreamMap();
  (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env_);
  GEEVENT("Reuse memory %s", ge_disable_reuse_mem_env_ == "1" ? "close" : "open");
  string op_no_reuse_mem_str;
  const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM);
  GE_IF_BOOL_EXEC(op_no_reuse_mem != nullptr, op_no_reuse_mem_str = string(op_no_reuse_mem);
                  CheckAndGetOpReuseEnv(op_no_reuse_mem_str, op_no_reuse_mem_vec_, op_reuse_env_valid_););
  for (NodePtr &n : compute_graph_->GetAllNodes()) {
    auto node_op_desc = n->GetOpDesc();
    GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
    life_time_ = node_op_desc->GetId();
    int64_t stream_id = node_op_desc->GetStreamId();
    if (AssignOutputMemoryWithReuse(n, ranges) != SUCCESS) {
      return;
    }
    stream_workspace_blocks_[stream_id].clear();
    vector<int64_t> temp;
    GetNodeWorkSpaceSize(n, temp);
    vector<int64_t> workspace_bytes;
    vector<int64_t> workspace_memory_type;
    bool has_workspace_mem_type_attr =
        ge::AttrUtils::GetListInt(node_op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_memory_type);
    vector<bool> workspace_reuse_flag;
    GE_IF_BOOL_EXEC(!ge::AttrUtils::GetListBool(node_op_desc, kAttrNameWorkspaceReuseFlag, workspace_reuse_flag),
                    GELOGD("OP %s get workspace_reuse_flag attr failed", node_op_desc->GetName().c_str()));
    GELOGI("Assign memory node[%s], size [temp:%zu, memory type size:%zu]", node_op_desc->GetName().c_str(),
           temp.size(), workspace_memory_type.size());
    if (has_workspace_mem_type_attr && (temp.size() != workspace_memory_type.size())) {
      GELOGE(INTERNAL_ERROR, "fusion: node[%s], workspace_memory size err![v_temp:%zu, workspace:%zu]",
             n->GetName().c_str(), temp.size(), workspace_memory_type.size());
      return;
    }
    for (size_t i = 0; i < temp.size(); i++) {
      // fusion: a size of another memory type does not mean HBM memory is allocated.
      bool workspace_skip_flag = false;
      if (has_workspace_mem_type_attr && workspace_memory_type[i] == RT_MEMORY_L1) {
        GELOGI(
            "fusion: node[%s]workspace index[%zu] is not hbm type, add to zero_memory_list, workspace memory type [%ld]",
            node_op_desc->GetName().c_str(), i, workspace_memory_type[i]);
        workspace_skip_flag = true;
      }
      if (temp[i] == 0 || workspace_skip_flag) {
        zero_memory_list_.emplace_back(n, kWorkspace, static_cast<uint32_t>(i), false);
        continue;
      }
      MemoryBlock *mem_block = ApplyMemory(GetBlockSize(static_cast<size_t>(temp[i]), ranges),
                                           static_cast<size_t>(temp[i]), static_cast<size_t>(temp[i]), kWorkspace, n,
                                           static_cast<uint32_t>(i), workspace_reuse_flag, is_op_reuse_mem_, false);
      GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block.");
      CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block);
    }
    ReleaseInputNodeOutMemory(node_out_blocks_, reusable_blocks_, n);
  }
  GELOGD("Assigned memory blocks:");
  for (auto mem_block : memory_blocks_) {
    GELOGD("%s", mem_block->String().c_str());
    (void)mem_block;  // Fix warning
  }
  bool merge_dynamic_batch = false;
  GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks();)
  GE_IF_BOOL_EXEC(!merge_dynamic_batch, ReuseBlocksByLifeTime();)
  AssignContinuousBlocks();
  ResizeMemoryBlocks();
  GELOGD("Memory blocks after resize:");
  for (auto mem_block : memory_blocks_) {
    GELOGD("%s", mem_block->String().c_str());
    (void)mem_block;  // Fix warning
  }
}

void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id,
                                           MemoryBlock *mem_block) {
  bool reuse_mem_flag = !((workspace_reuse_flag.size() > index) && !workspace_reuse_flag[index]);
  if (reuse_mem_flag) {
    stream_workspace_blocks_[stream_id].emplace_back(mem_block);
  }
}

void BlockMemAssigner::GetNodeWorkSpaceSize(const NodePtr &node, vector<int64_t> &workspace_memory) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node->GetOpDesc() == nullptr, return, "Op desc is null.");
  vector<int64_t> workspace_byte_nums = node->GetOpDesc()->GetWorkspaceBytes();
  GELOGD("GetNodeWorkSpaceSize: node[%s] size:%zu", node->GetOpDesc()->GetName().c_str(), workspace_byte_nums.size());
  for (int64_t byte_size : workspace_byte_nums) {
    workspace_memory.emplace_back(byte_size);
    GELOGD("GetNodeWorkSpaceSize: push back size:%ld", byte_size);
  }
}

// Descending order by maximum real size.
static bool CompareBlockMaxSize(MemoryBlock *left, MemoryBlock *right) {
  if (left == nullptr || right == nullptr) {
    return false;
  }
  auto left_max_size = std::max_element(left->RealSizeList().begin(), left->RealSizeList().end());
  if (left_max_size != left->RealSizeList().end()) {
    auto right_max_size = std::max_element(right->RealSizeList().begin(), right->RealSizeList().end());
    if (right_max_size == right->RealSizeList().end() || (*left_max_size > *right_max_size)) {
      return true;
    }
  }
  return false;
}

void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &src) {
  for (size_t i = 0; i < dest.size(); ++i) {
    if (i >= src.size()) {
      return;
    }
    if (dest[i] != nullptr && src[i] != nullptr) {
      for (auto &symbol : src[i]->SymbolList()) {
        dest[i]->AddSymbol(symbol);
      }
      for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) {
        dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j],
                                  src[i]->NoAlignSizeList()[j]);
        src[i]->deleted_block_ = true;
      }
    }
  }
}

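// Dynamic-batch merging: blocks whose nodes all carry the same
// ATTR_NAME_BATCH_LABEL belong to one batch branch. Since only one branch runs
// per execution, the other branches' blocks are folded into the branch with
// the most blocks, so all branches share a single memory plan.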
bool BlockMemAssigner::MergeDynamicBatchBlocks() {
  bool merged = false;
  std::map<std::string, std::vector<MemoryBlock *>> dynamic_batch_blocks;
  for (auto block : memory_blocks_) {
    if (block == nullptr) {
      continue;
    }
    std::string batch_label;
    if (block->IsSameLabel(batch_label)) {
      dynamic_batch_blocks[batch_label].emplace_back(block);
    }
  }
  auto it = dynamic_batch_blocks.begin();
  auto it_max = it;
  // Find the batch with the most blocks.
  for (; it != dynamic_batch_blocks.end(); ++it) {
    if (it->second.size() > it_max->second.size()) {
      it_max = it;
    }
    std::sort(it->second.begin(), it->second.end(), CompareBlockMaxSize);
  }
  if (it_max != dynamic_batch_blocks.end()) {
    GELOGD("MergeDynamicBatch %s block counts %zu", it_max->first.c_str(), it_max->second.size());
  }
  for (it = dynamic_batch_blocks.begin(); it != dynamic_batch_blocks.end(); ++it) {
    if (it != it_max) {
      GELOGD("MergeDynamicBatch from %s to %s", it->first.c_str(), it_max->first.c_str());
      MergeBlocks(it_max->second, it->second);
      merged = true;
    }
  }
  return merged;
}

// Ascending order by continuous input index.
static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) {
  if (left == nullptr || right == nullptr) {
    return false;
  }
  if (left->input_index_ < right->input_index_) {
    return true;
  }
  return false;
}

///
/// @ingroup domi
/// @brief order blocks by continuous input index
/// @param [in] org_blocks blocks to be processed
/// @param [in] block_map blocks that need to be continuous
/// @param [out] dest_blocks blocks after continuous ordering
/// @param [in/out] continuous_blocks blocks being ordered
///
void ReAssignContinuousBlocks(const std::vector<MemoryBlock *> &org_blocks,
                              const std::map<MemoryBlock *, uint32_t> block_map,
                              std::vector<MemoryBlock *> &dest_blocks, std::vector<MemoryBlock *> &continuous_blocks) {
  for (auto &memory_block : org_blocks) {
    if (memory_block == nullptr || memory_block->deleted_block_) {
      continue;
    }
    if (block_map.find(memory_block) != block_map.end()) {
      continue;
    }
    dest_blocks.emplace_back(memory_block);
  }
  // Append the continuous blocks, ordered by input index.
  std::sort(continuous_blocks.begin(), continuous_blocks.end(), CompareBlockIndex);
  size_t count = 0;
  for (auto &memory_block : continuous_blocks) {
    GE_IF_BOOL_EXEC(memory_block == nullptr, continue);
    GELOGI("Block continuous input index:%u", memory_block->input_index_);
    count++;
    if (count == 1) {
      memory_block->first_continuous_block_ = true;
    }
    if (count == continuous_blocks.size()) {
      memory_block->last_continuous_block_ = true;
    }
    dest_blocks.emplace_back(memory_block);
  }
}

void BlockMemAssigner::AssignContinuousBlocks() {
  for (auto &block_map : node_continuous_input_blocks_) {
    std::vector<MemoryBlock *> dest_memory_blocks;
    std::map<MemoryBlock *, uint32_t> continuous_block_map;
    std::vector<MemoryBlock *> continuous_blocks;
    auto it = node_continuous_input_counts_.find(block_map.first);
    GE_IF_BOOL_EXEC(it == node_continuous_input_counts_.end(), continue);
    GELOGI("Node:%s continuous input block count:%zu input count:%u", block_map.first.c_str(), block_map.second.size(),
           it->second);
    GE_IF_BOOL_EXEC(it->second != block_map.second.size(), continue);
    for (auto &it : block_map.second) {
      if (it.second != nullptr) {
        continuous_block_map[it.second] = it.first;
        it.second->input_index_ = it.first;
        continuous_blocks.emplace_back(it.second);
      }
    }
    if (continuous_block_map.size() != continuous_blocks.size()) {
      GELOGW("Node:%s continuous input map size:%zu vector size:%zu", block_map.first.c_str(),
             continuous_block_map.size(), continuous_blocks.size());
      continue;
    }
    ReAssignContinuousBlocks(memory_blocks_, continuous_block_map, dest_memory_blocks, continuous_blocks);
    memory_blocks_.swap(dest_memory_blocks);
  }
}

void BlockMemAssigner::ReuseBlocksByLifeTime() {
  for (size_t i = 0; i < memory_blocks_.size(); ++i) {
    auto parent = memory_blocks_[i];
    if (parent == nullptr || parent->deleted_block_) {
      continue;
    }
    if (parent->reuse_mem_ && !IsPostReuse(parent)) {
      parent->reuse_mem_ = false;
    }
    for (size_t j = i + 1; j < memory_blocks_.size(); ++j) {
      parent->AddLifeReuseBlock(memory_blocks_[j]);
    }
  }
}
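// Illustrative note (added): each block offers every later block to itself as
// a life-time-reuse child; AddLifeReuseBlock is expected to accept a child
// only when the two live ranges do not overlap, so e.g. a block whose life
// time ends at node 10 can host a child whose life time starts at node 11.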
///
/// @ingroup domi_omg
/// @brief traverse memory blocks, resize them and calculate each block's offset
/// @param [in&out] memory_blocks_ memory blocks, offsets are calculated during traversal
///
void BlockMemAssigner::ResizeMemoryBlocks() {
  for (auto &memory_block : memory_blocks_) {
    if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) {
      continue;
    }
    if (memory_block->first_continuous_block_) {
      mem_offset_ += MEM_ALIGN_SIZE;
    }
    memory_block->Resize();
    memory_block->SetHeadOffset(mem_offset_);
    mem_offset_ += memory_block->Size();
    memory_block->SetTailOffset(mem_offset_ - 1);
  }
  GELOGI("mem_offset_ excluding zero copy memory is %zu.", mem_offset_);
}
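// Worked example (added, assuming MEM_ALIGN_SIZE == 512): with two live
// blocks of Size() 1024 and 2048 where the second one starts a continuous
// range, the offsets become:
//   block0: head 0,    tail 1023   (mem_offset_ -> 1024)
//   pad:    mem_offset_ -> 1536    (first_continuous_block_ alignment)
//   block1: head 1536, tail 3583   (mem_offset_ -> 3584)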
///
/// @ingroup domi
/// @brief given a NodeTypeIndex, set the assigned offset in the op's OpDesc
/// @param [in] node_type <node, memory type, index>
/// @param [in] block memory block that carries the offset
/// @param [in] real_size memory size in need
/// @param [in] no_align_size memory size in need before alignment
/// @param [in] child_block whether the block is a life-time-reuse child block
///
void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, size_t real_size, size_t no_align_size,
                   bool child_block) {
  ge::OpDescPtr op_desc = node_type.node->GetOpDesc();
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null.");
  string graph_name = node_type.node->GetOwnerComputeGraph()->GetName();
  vector<int64_t> memorys_type;
  int64_t offset = block->HeadOffset();
  size_t end = node_type.life_time_end;
  bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memorys_type);
  if (node_type.mem_type == kOutput) {
    vector<int64_t> output_list = op_desc->GetOutputOffset();
    for (auto i = static_cast<uint32_t>(output_list.size()); i < node_type.index + 1; i++) {
      output_list.emplace_back(kInvalidOffset);
    }
    if (output_list.empty()) {
      GELOGW("Empty output");
      return;
    }
    if ((op_desc->GetType() == DATA) || (op_desc->GetType() == AIPP_DATA_TYPE) || (op_desc->GetType() == MULTISHAPE) ||
        (op_desc->GetType() == NETOUTPUT)) {
      if ((output_list[node_type.index] == kInvalidOffset) || (output_list[node_type.index] < offset)) {
        output_list.at(node_type.index) = offset;
      }
    } else {
      // fusion: keep the original offset from op_desc for other memory types
      bool set_out_offset = (!has_mem_type_attr) ||
                            (memorys_type.size() > node_type.index && memorys_type[node_type.index] != RT_MEMORY_L1);
      if (set_out_offset) {
        output_list.at(node_type.index) = offset;
      }
    }
    op_desc->SetOutputOffset(output_list);
  } else if (node_type.mem_type == kWorkspace) {
    vector<int64_t> workspace_list;
    workspace_list = op_desc->GetWorkspace();
    for (auto i = static_cast<uint32_t>(workspace_list.size()); i < node_type.index + 1; i++) {
      workspace_list.emplace_back(kInvalidOffset);
    }
    vector<int64_t> workspace_mem_type;
    bool has_workspace_mem_type = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_mem_type);
    // fusion: keep the original offset from op_desc for other memory types
    bool set_workspace_offset = (!has_workspace_mem_type) || (workspace_mem_type.size() > node_type.index &&
                                                              workspace_mem_type[node_type.index] != RT_MEMORY_L1);
    if (set_workspace_offset) {
      workspace_list.at(node_type.index) = offset;
    }
    op_desc->SetWorkspace(workspace_list);
  }
  GELOGI(
      "[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]"
      " noalignsize[%zu] life time begin[%ld] life time end[%zu] child[%d] isref[%d].",
      graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset,
      op_desc->GetStreamId(), block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block,
      node_type.ref_input);
}
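// Illustrative sketch (added): for an op whose output 1 lands in a block with
// HeadOffset() 1536, the output-offset list is first padded with
// kInvalidOffset up to index 1 and then patched in place:
//   before: op_desc->GetOutputOffset() == {0}
//   after:  op_desc->GetOutputOffset() == {0, 1536}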
void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) {
  if (block == nullptr) {
    return;
  }
  size_t index = 0;
  size_t real_size = 0;
  size_t no_align_size = 0;
  auto real_size_list_size = block->RealSizeList().size();
  for (const NodeTypeIndex &node_type_index : block->NodeTypeIndexList()) {
    if (index < real_size_list_size) {
      real_size = block->RealSizeList()[index];
      no_align_size = block->NoAlignSizeList()[index];
    }
    SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block);
    index++;
  }
}
void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) {
  for (MemoryBlock *memory_block : memory_blocks_) {
    if (memory_block == nullptr || memory_block->deleted_block_) {
      continue;
    }
    // only handle the kind of block requested by is_zero_copy
    if (is_zero_copy != memory_block->is_zero_copy_) {
      continue;
    }
    SetBlockOpMemOffset(memory_block, false);
    for (MemoryBlock *child_block : memory_block->ChildBlockList()) {
      SetBlockOpMemOffset(child_block, true);
    }
  }
  if (!is_zero_copy) {
    for (const NodeTypeIndex &node_type_index : zero_memory_list_) {
      MemoryBlock block(0, 0);
      SetOffsetSize(node_type_index, &block, 0, 0, false);
    }
  }
}
Status BlockMemAssigner::Assign() {
  vector<int64_t> ranges;
  if (GetMemoryRanges(ranges) != SUCCESS) {
    GELOGE(FAILED, "GetMemoryRanges Fail!");
    return FAILED;
  }
  GE_IF_BOOL_EXEC(ranges.empty(), return SUCCESS);
  AssignMemoryWithReuse(ranges);
  SetOpMemOffset(false);
  return SUCCESS;
}
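// Illustrative flow (added): Assign() is the top-level entry of this
// assigner. A caller that owns a built compute graph would roughly do:
//
//   BlockMemAssigner assigner(compute_graph);  // constructor args assumed
//   if (assigner.Assign() != SUCCESS) {
//     // memory ranges could not be derived; abort the model build
//   }
//
// i.e. derive reuse ranges, assign blocks with reuse, then write the
// non-zero-copy offsets back into each op.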
void BlockMemAssigner::InitReusableStreamMap() {
  // save each stream's id with its first and last node.
  map<int64_t, pair<NodePtr, NodePtr>> stream_head_tail_node_map;
  // save each stream's id with its directly dependent child streams.
  map<int64_t, unordered_set<int64_t>> stream_dependency_map;
  // save each stream's id with the memory it occupies.
  unordered_map<int64_t, int64_t> stream_mem_map;
  // Find each stream's first and last node.
  FindHeadAndTailNodesForStream(stream_head_tail_node_map, stream_mem_map);
  // If stream B's first node consumes the output of stream A's last node, then B depends on A.
  FindDependentStream(stream_head_tail_node_map, stream_dependency_map);
  // If a stream has more than one child stream, select the child whose occupied memory is closest.
  for (const auto &iter : stream_dependency_map) {
    if (iter.second.empty()) {
      continue;
    }
    int64_t target_size = stream_mem_map[iter.first];
    int64_t min_size_gap = LONG_MAX;
    int64_t target_reuse_stream_id = 0;
    for (auto id : iter.second) {
      if (labs(stream_mem_map[id] - target_size) < min_size_gap) {
        target_reuse_stream_id = id;
        min_size_gap = labs(stream_mem_map[id] - target_size);
      }
    }
    // If b can reuse a, then b can also reuse all the blocks that a can reuse.
    reusable_streams_map_[target_reuse_stream_id].insert(reusable_streams_map_[iter.first].begin(),
                                                         reusable_streams_map_[iter.first].end());
  }
}
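// Worked example (added): suppose stream 0 occupies 100 and its dependent
// children are streams {1, 2} occupying {40, 90}; |90 - 100| = 10 is less
// than |40 - 100| = 60, so stream 2 is selected and inherits every stream id
// that stream 0 may reuse.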
void BlockMemAssigner::FindHeadAndTailNodesForStream(map<int64_t, pair<NodePtr, NodePtr>> &stream_head_tail_node_map,
                                                     unordered_map<int64_t, int64_t> &stream_mem_map) {
  for (const auto &n : compute_graph_->GetAllNodes()) {
    GE_IF_BOOL_EXEC(n->GetOpDesc() == nullptr, GELOGW("Op desc is nullptr"); continue);
    auto stream_id = n->GetOpDesc()->GetStreamId();
    // traverse to find each stream's first and last node.
    if (stream_head_tail_node_map.find(stream_id) == stream_head_tail_node_map.end()) {
      stream_head_tail_node_map[stream_id] = std::make_pair(n, n);
      reusable_streams_map_[stream_id].insert(stream_id);  // a node can reuse blocks from the same stream.
    } else {
      stream_head_tail_node_map[stream_id].second = n;
    }
    // Accumulate the output sizes of the node into its stream.
    for (size_t i = 0; i < n->GetOpDesc()->GetOutputsSize(); i++) {
      int64_t size = 0;
      if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast<uint32_t>(i)), size) != SUCCESS) {
        GELOGW("Get output size failed!");
        continue;
      }
      stream_mem_map[stream_id] += size;
    }
    // Accumulate the workspace sizes of the node into its stream.
    for (auto size : n->GetOpDesc()->GetWorkspaceBytes()) {
      stream_mem_map[stream_id] += size;
    }
  }
}
void BlockMemAssigner::FindDependentStream(map<int64_t, pair<NodePtr, NodePtr>> &stream_head_tail_node_map,
                                           map<int64_t, unordered_set<int64_t>> &stream_dependency_map) {
  for (const auto &it1 : stream_head_tail_node_map) {
    for (const auto &it2 : stream_head_tail_node_map) {
      if (it1 == it2) {
        continue;
      }
      NodePtr pre_node = it1.second.second;
      NodePtr post_node = it2.second.first;
      std::vector<NodePtr> out_nodes;
      // directly linked out nodes
      for (const auto &out_node : pre_node->GetOutNodes()) {
        if ((out_node->GetOpDesc() == nullptr) || (post_node->GetOpDesc() == nullptr) ||
            (pre_node->GetOpDesc() == nullptr)) {
          continue;
        }
        out_nodes.emplace_back(out_node);
      }
      FindDependentStreamBetweenGraphs(pre_node, out_nodes);
      for (auto &out_node : out_nodes) {
        if (out_node->GetOpDesc()->GetId() == post_node->GetOpDesc()->GetId()) {
          stream_dependency_map[pre_node->GetOpDesc()->GetStreamId()].insert(post_node->GetOpDesc()->GetStreamId());
        }
      }
    }
  }
}
///
/// @ingroup GE
/// @brief Find dependent links between parent_graph and sub_graph
/// @param [in] pre_node
/// @param [out] out_nodes
/// @return void
///
void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector<NodePtr> &out_nodes) {
  if ((pre_node == nullptr) || (pre_node->GetOpDesc() == nullptr)) {
    return;
  }
  // FunctionOp & subgraph input
  std::vector<std::string> subgraph_names = pre_node->GetOpDesc()->GetSubgraphInstanceNames();
  for (auto &subgraph_name : subgraph_names) {
    ComputeGraphPtr subgraph = compute_graph_->GetSubgraph(subgraph_name);
    if (subgraph == nullptr) {
      continue;
    }
    for (auto &node : subgraph->GetDirectNode()) {
      OpDescPtr op_desc = node->GetOpDesc();
      if (op_desc == nullptr) {
        continue;
      }
      if (op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) {
        out_nodes.emplace_back(node);
      }
    }
  }
  // subgraph output & parent node output
  if (NodeUtils::IsSubgraphOutput(pre_node)) {
    NodePtr parent_node = pre_node->GetOwnerComputeGraph()->GetParentNode();
    if (parent_node == nullptr) {
      return;
    }
    for (const auto &out_node : parent_node->GetOutNodes()) {
      if (out_node->GetOpDesc() == nullptr) {
        continue;
      }
      out_nodes.emplace_back(out_node);
    }
  }
}
bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const {
  return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) ||
         (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) ||
         (node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT) ||
         (node_type == HVDCALLBACKBROADCAST) || (node_type == HVDCALLBACKALLREDUCE);
}
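// Note (added): these node types presumably own no feature-map memory of
// their own (variables/constants live in variable memory, HCOM/Horovod
// collectives operate on their inputs in place), so their outputs go to
// zero_memory_list_ and receive offsets via SetOpMemOffset() with an empty
// MemoryBlock(0, 0).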
}  // namespace ge

The Graph Engine (GE) is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph issued by ME as input, performs a series of deep graph-optimization operations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. GE is invoked automatically during model training and inference, transparently to the user. GE consists of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.