You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 36 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/manager/graph_var_manager.h"
#include <limits>
#include <utility>
#include "common/l2_cache_optimize.h"
#include "common/types.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "ge/ge_api_types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/manager/graph_mem_allocator.h"
#include "graph/manager/trans_var_data_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/type_utils.h"
  28. using std::map;
  29. using std::string;
  30. using std::vector;
  31. namespace ge {
  32. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  33. VarResource::~VarResource() {
  34. var_offset_set_.clear();
  35. var_addr_mgr_map_.clear();
  36. cur_var_tensor_desc_map_.clear();
  37. var_broad_cast_info_.clear();
  38. }
  39. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  40. rtMemType_t &memory_type) {
  41. if (dev_ptr == nullptr) {
  42. GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!");
  43. return FAILED;
  44. }
  45. std::string var_key = VarKey(var_name, tensor_desc);
  46. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  47. auto iter = var_addr_mgr_map_.find(var_key);
  48. if (iter == var_addr_mgr_map_.end()) {
  49. GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str());
  50. return FAILED;
  51. }
  52. *dev_ptr = iter->second.address;
  53. memory_type = iter->second.memory_type;
  54. return SUCCESS;
  55. }
  56. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  57. var_addr_mgr_map = var_addr_mgr_map_;
  58. }
  59. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  60. rtMemType_t memory_type) {
  61. std::string var_key = VarKey(var_name, tensor_desc);
  62. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  63. if (var_addr_mgr_map_.count(var_key) == 0) {
  64. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  65. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  66. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  67. VarAddrMgr var_addr_mgr;
  68. var_addr_mgr.address = dev_ptr;
  69. var_addr_mgr.tensor_desc = tensor_desc;
  70. var_addr_mgr_map_[var_key] = var_addr_mgr;
  71. }
  72. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  73. }
  74. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  75. rtMemType_t memory_type) {
  76. std::string var_key = VarKey(var_name, tensor_desc);
  77. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  78. if (var_addr_mgr_map_.count(var_key) == 0) {
  79. uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() +
  80. static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  81. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  82. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  83. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  84. VarAddrMgr var_addr_mgr;
  85. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  86. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  87. var_addr_mgr.tensor_desc = tensor_desc;
  88. var_addr_mgr.memory_type = memory_type;
  89. var_addr_mgr_map_[var_key] = var_addr_mgr;
  90. var_offset_set_.insert(logic_address);
  91. return SUCCESS;
  92. }
  93. GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str());
  94. return FAILED;
  95. }
  96. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  97. std::string var_key = VarKey(var_name, tensor_desc);
  98. return var_addr_mgr_map_.count(var_key) != 0;
  99. }
  100. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  101. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  102. std::string var_key(var_name);
  103. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  104. .append("_")
  105. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  106. return var_key;
  107. }
  108. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  109. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  110. return FAILED;
  111. }
  112. tensor_desc = cur_var_tensor_desc_map_[var_name];
  113. return SUCCESS;
  114. }
  115. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  116. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  117. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  118. return SUCCESS;
  119. }
  120. if (op_desc == nullptr) {
  121. GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!");
  122. return FAILED;
  123. }
  124. ge::GeTensorDesc curr_desc;
  125. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  126. if (ret != SUCCESS) {
  127. GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!");
  128. return FAILED;
  129. }
  130. std::string key = VarKey(var_name, curr_desc);
  131. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  132. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  133. cur_var_tensor_desc_map_[var_name] = curr_desc;
  134. auto iter = var_addr_mgr_map_.find(key);
  135. if (iter == var_addr_mgr_map_.end()) {
  136. GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str());
  137. return FAILED;
  138. }
  139. auto val = iter->second;
  140. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  141. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  142. var_addr_mgr_map_.erase(iter);
  143. key = VarKey(var_name, curr_desc);
  144. var_addr_mgr_map_[key] = val;
  145. return SUCCESS;
  146. }
  147. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  148. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  149. }
  150. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  151. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  152. return FAILED;
  153. }
  154. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  155. return SUCCESS;
  156. }
  157. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  158. const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) {
  159. if (var_op_desc == nullptr) {
  160. GELOGE(FAILED, "[SyncVarData2BroadCast] var opdesc is null!");
  161. return FAILED;
  162. }
  163. GE_CHECK_NOTNULL(base_ptr);
  164. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  165. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  166. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  167. ge::GeTensorDesc var_tensor_desc = var_op_desc->GetOutputDesc(0);
  168. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  169. var_broadcast_info.input_size, session_id_);
  170. }
  171. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  172. const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) {
  173. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  174. GE_CHECK_NOTNULL(var_op_desc);
  175. string var_is_broadcast;
  176. bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast);
  177. if (!is_broadcast) {
  178. return SUCCESS;
  179. }
  180. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  181. // subgraph base_ptr could be nullptr, task it as base 0
  182. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  183. ge::GeTensorDesc var_tensor_desc = var_op_desc->GetOutputDesc(0);
  184. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  185. var_tensor_desc, session_id_);
  186. }
  187. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  188. const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) {
  189. GE_CHECK_NOTNULL(var_op_desc);
  190. string var_is_broadcast;
  191. bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast);
  192. if (!is_broadcast) {
  193. return SUCCESS;
  194. }
  195. return SyncVarData2BroadCast(graph_id, var_name, var_op_desc, base_ptr);
  196. }
  197. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_set_.count(offset) > 0; }
  198. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  199. auto iter = var_to_trans_road_.find(var_name);
  200. if (iter == var_to_trans_road_.end()) {
  201. return nullptr;
  202. } else {
  203. return &(iter->second);
  204. }
  205. }
  206. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  207. auto iter = var_names_to_changed_graph_id_.find(var_name);
  208. if (iter == var_names_to_changed_graph_id_.end()) {
  209. return FAILED;
  210. } else {
  211. graph_id = iter->second;
  212. return SUCCESS;
  213. }
  214. }
  215. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  216. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  217. if (iter == var_names_to_allocated_graph_id_.end()) {
  218. return FAILED;
  219. } else {
  220. graph_id = iter->second;
  221. return SUCCESS;
  222. }
  223. }
  224. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  225. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  226. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  227. return SUCCESS;
  228. }
  229. var_names_to_allocated_graph_id_[var_name] = graph_id;
  230. return SUCCESS;
  231. }
  232. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  233. Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) {
  234. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  235. uint64_t real_size = size;
  236. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  237. if (total_size_ < var_mem_size_) {
  238. GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_);
  239. return PARAM_INVALID;
  240. }
  241. uint64_t free_size = total_size_ - var_mem_size_;
  242. if (free_size < (size + kSessionMemAlignSize * 2)) {
  243. GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
  244. size + kSessionMemAlignSize * 2 + var_mem_size_, total_size_);
  245. return PARAM_INVALID;
  246. }
  247. mem_offset = var_mem_size_;
  248. // offset for next, align 512 BYTE
  249. size = size + kSessionMemAlignSize;
  250. var_mem_size_ = var_mem_size_ + size;
  251. // align 512 BYTE
  252. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  253. GELOGI(
  254. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  255. "offset to [%zu] size[%lu] realsize[%lu].",
  256. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  257. return SUCCESS;
  258. }
  259. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  260. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  261. VarManager::VarManager(uint64_t session_id)
  262. : version_(SessionVersion::OTHER_VERSION),
  263. session_id_(session_id),
  264. device_id_(0),
  265. job_id_(0),
  266. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  267. var_mem_max_size_(kMemoryVarManagerMallocSize),
  268. var_mem_logic_base_(kMemoryVarLogicBase),
  269. use_max_mem_size_(kUseMaxMemorySize) {}
  270. VarManager *VarManager::Instance(uint64_t session_id) {
  271. GELOGD("VarManager::Instance, session id = %lu", session_id);
  272. return VarManagerPool::Instance().GetVarManager(session_id);
  273. }
  274. void VarManager::Destory() {
  275. std::lock_guard<std::recursive_mutex> lock(mutex_);
  276. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  277. version_ = SessionVersion::OTHER_VERSION;
  278. device_id_ = 0;
  279. session_id_ = 0;
  280. for (auto &memory_resource : mem_resource_map_) {
  281. if (memory_resource.second != nullptr) {
  282. delete memory_resource.second;
  283. memory_resource.second = nullptr;
  284. }
  285. }
  286. mem_resource_map_.clear();
  287. }
  288. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  289. const uint64_t &job_id) {
  290. std::lock_guard<std::recursive_mutex> lock(mutex_);
  291. GELOGI("VarManager::Init, session id = %lu.", session_id);
  292. version_ = version;
  293. device_id_ = device_id;
  294. session_id_ = session_id;
  295. job_id_ = job_id;
  296. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  297. if (var_resource_ == nullptr) {
  298. GELOGW("VarManager has not been init.");
  299. return ge::INTERNAL_ERROR;
  300. }
  301. return SUCCESS;
  302. }
  303. const uint64_t &VarManager::SessionId() const {
  304. std::lock_guard<std::recursive_mutex> lock(mutex_);
  305. return session_id_;
  306. }
  307. const uint32_t &VarManager::DeviceId() const {
  308. std::lock_guard<std::recursive_mutex> lock(mutex_);
  309. return device_id_;
  310. }
  311. const uint64_t &VarManager::JobId() const {
  312. std::lock_guard<std::recursive_mutex> lock(mutex_);
  313. return job_id_;
  314. }
  315. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  316. rtMemType_t memory_type) {
  317. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  318. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  319. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  320. std::lock_guard<std::recursive_mutex> lock(mutex_);
  321. if (var_resource_ == nullptr) {
  322. GELOGW("VarManager has not been init.");
  323. return ge::INTERNAL_ERROR;
  324. }
  325. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  326. return ge::SUCCESS;
  327. }
  328. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  329. rtMemType_t memory_type) {
  330. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  331. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  332. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  333. std::lock_guard<std::recursive_mutex> lock(mutex_);
  334. if (var_resource_ == nullptr) {
  335. GELOGW("VarManager has not been init.");
  336. return ge::INTERNAL_ERROR;
  337. }
  338. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  339. return ge::SUCCESS;
  340. }
  341. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  342. rtMemType_t &memory_type) {
  343. std::lock_guard<std::recursive_mutex> lock(mutex_);
  344. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  345. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  346. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  347. if (var_resource_ == nullptr) {
  348. GELOGW("VarManager has not been init.");
  349. return ge::INTERNAL_ERROR;
  350. }
  351. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  352. if (ret != SUCCESS) {
  353. GELOGW("GetVarAddr fail.");
  354. return ge::INTERNAL_ERROR;
  355. }
  356. return SUCCESS;
  357. }
  358. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  359. std::lock_guard<std::recursive_mutex> lock(mutex_);
  360. rtMemType_t memory_type = RT_MEMORY_HBM;
  361. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  362. }
  363. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  364. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  365. }
  366. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  367. std::lock_guard<std::recursive_mutex> lock(mutex_);
  368. MemResource *mem_resource = nullptr;
  369. auto iter = mem_resource_map_.find(memory_type);
  370. if (iter == mem_resource_map_.end()) {
  371. return 0;
  372. } else {
  373. mem_resource = iter->second;
  374. }
  375. if (mem_resource == nullptr) {
  376. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  377. return 0;
  378. }
  379. return mem_resource->GetVarMemSize();
  380. }
  381. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  382. std::lock_guard<std::recursive_mutex> lock(mutex_);
  383. MemResource *mem_resource = nullptr;
  384. auto iter = mem_resource_map_.find(memory_type);
  385. if (iter == mem_resource_map_.end()) {
  386. mem_resource = new (std::nothrow) MemResource();
  387. if (mem_resource == nullptr) {
  388. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  389. return ge::INTERNAL_ERROR;
  390. } else {
  391. mem_resource_map_[memory_type] = mem_resource;
  392. }
  393. } else {
  394. mem_resource = iter->second;
  395. }
  396. if (mem_resource == nullptr) {
  397. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  398. return FAILED;
  399. }
  400. mem_resource->UpdateVarMemSize(mem_size);
  401. return SUCCESS;
  402. }
  403. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  404. rtMemType_t memory_type) {
  405. std::lock_guard<std::recursive_mutex> lock(mutex_);
  406. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  407. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  408. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  409. int64_t tensor_desc_size = 0;
  410. size_t mem_offset = 0;
  411. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  412. if (result != ge::SUCCESS) {
  413. GELOGE(result, "get size from TensorDesc failed");
  414. return result;
  415. }
  416. MemResource *mem_resource = nullptr;
  417. auto it = mem_resource_map_.find(memory_type);
  418. if (it == mem_resource_map_.end()) {
  419. mem_resource = new (std::nothrow) MemResource();
  420. if (mem_resource == nullptr) {
  421. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  422. return ge::INTERNAL_ERROR;
  423. } else {
  424. mem_resource_map_[memory_type] = mem_resource;
  425. }
  426. } else {
  427. mem_resource = it->second;
  428. }
  429. if (mem_resource == nullptr) {
  430. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type);
  431. return ge::INTERNAL_ERROR;
  432. }
  433. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  434. if (result != SUCCESS) {
  435. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  436. return ge::INTERNAL_ERROR;
  437. }
  438. if (var_resource_ == nullptr) {
  439. GELOGW("VarManager has not been init.");
  440. return ge::INTERNAL_ERROR;
  441. }
  442. result = var_resource_->SaveVarAddr(
  443. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  444. if (result != SUCCESS) {
  445. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  446. return ge::INTERNAL_ERROR;
  447. }
  448. result = var_resource_->GetVarAddr(
  449. var_name, tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  450. if (result != SUCCESS) {
  451. GELOGE(ge::INTERNAL_ERROR, "GetVarAddr by offset failed.");
  452. return ge::INTERNAL_ERROR;
  453. }
  454. ge::GeTensorDesc cur_tensor_desc;
  455. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  456. if (result != SUCCESS) {
  457. var_resource_->SetVarAddr(var_name, tensor_desc,
  458. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  459. return SUCCESS;
  460. }
  461. if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  462. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  463. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) {
  464. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  465. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  466. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  467. tensor_desc.GetShape().GetDims().size(),
  468. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  469. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  470. cur_tensor_desc.GetShape().GetDims().size());
  471. var_resource_->SetVarAddr(var_name, tensor_desc,
  472. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  473. }
  474. return SUCCESS;
  475. }
  476. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  477. std::lock_guard<std::recursive_mutex> lock(mutex_);
  478. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  479. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  480. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  481. if (var_resource_ == nullptr) {
  482. GELOGW("VarManager has not been init.");
  483. return false;
  484. }
  485. return var_resource_->IsVarExist(var_name, tensor_desc);
  486. }
  487. bool VarManager::IsVarExist(const std::string &var_name) {
  488. std::lock_guard<std::recursive_mutex> lock(mutex_);
  489. if (var_resource_ == nullptr) {
  490. GELOGW("VarManager has not been init.");
  491. return false;
  492. }
  493. return var_resource_->IsVarExist(var_name);
  494. }
  495. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc,
  496. uint8_t *base_ptr) {
  497. std::lock_guard<std::recursive_mutex> lock(mutex_);
  498. if (var_resource_ == nullptr) {
  499. GELOGW("VarManager has not been init.");
  500. return ge::INTERNAL_ERROR;
  501. }
  502. return var_resource_->SyncVarData(graph_id, var_name, std::move(var_op_desc), base_ptr);
  503. }
  504. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  505. std::lock_guard<std::recursive_mutex> lock(mutex_);
  506. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  507. if (var_resource_ == nullptr) {
  508. GELOGW("VarManager has not been init.");
  509. return ge::INTERNAL_ERROR;
  510. }
  511. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  512. }
  513. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  514. std::lock_guard<std::recursive_mutex> lock(mutex_);
  515. GELOGI(
  516. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  517. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  518. "output_size = %lu",
  519. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  520. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  521. broad_cast_info.output_size);
  522. if (var_resource_ == nullptr) {
  523. GELOGW("VarManager has not been init.");
  524. return ge::INTERNAL_ERROR;
  525. }
  526. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  527. return SUCCESS;
  528. }
  529. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  530. std::lock_guard<std::recursive_mutex> lock(mutex_);
  531. if (var_resource_ == nullptr) {
  532. GELOGW("VarManager has not been init.");
  533. return ge::INTERNAL_ERROR;
  534. }
  535. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  536. }
  537. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  538. std::lock_guard<std::recursive_mutex> lock(mutex_);
  539. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  540. if (var_resource_ == nullptr) {
  541. GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init.");
  542. return ge::INTERNAL_ERROR;
  543. }
  544. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  545. }
  546. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  547. ge::ConstOpDescPtr var_op_desc, uint8_t *base_ptr) {
  548. std::lock_guard<std::recursive_mutex> lock(mutex_);
  549. if (var_resource_ == nullptr) {
  550. GELOGW("VarManager has not been init.");
  551. return ge::INTERNAL_ERROR;
  552. }
  553. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, std::move(var_op_desc), base_ptr);
  554. }
  555. bool VarManager::IsVarAddr(const int64_t &offset) {
  556. std::lock_guard<std::recursive_mutex> lock(mutex_);
  557. if (var_resource_ == nullptr) {
  558. GELOGW("VarManager has not been init.");
  559. return false;
  560. }
  561. return var_resource_->IsVarAddr(offset);
  562. }
  563. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  564. std::lock_guard<std::recursive_mutex> lock(mutex_);
  565. uint8_t *var_mem_base = nullptr;
  566. string memory_key = std::to_string(session_id_);
  567. // malloc variable memory
  568. size_t var_memory_size = memory_size;
  569. // align 512 BYTE
  570. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  571. const string purpose("variables and constant op memory in training network.");
  572. var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size);
  573. if (var_mem_base == nullptr) {
  574. GELOGE(ge::INTERNAL_ERROR,
  575. "VarManager::MallocVarMemory failed "
  576. "session_id = %s",
  577. memory_key.c_str());
  578. return ge::INTERNAL_ERROR;
  579. }
  580. return SUCCESS;
  581. }
  582. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  583. std::lock_guard<std::recursive_mutex> lock(mutex_);
  584. string memory_key = std::to_string(session_id_);
  585. return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
  586. }
  587. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  588. std::lock_guard<std::recursive_mutex> lock(mutex_);
  589. string mem_key = std::to_string(session_id_);
  590. uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
  591. if (mem_base == nullptr) {
  592. return nullptr;
  593. }
  594. uint8_t *mem_addr =
  595. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  596. return mem_addr;
  597. }
  598. ge::Status VarManager::FreeVarMemory() {
  599. std::lock_guard<std::recursive_mutex> lock(mutex_);
  600. string memory_key = std::to_string(SessionId());
  601. return MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key);
  602. }
  603. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  604. std::lock_guard<std::recursive_mutex> lock(mutex_);
  605. if (var_resource_ == nullptr) {
  606. GELOGW("VarManager has not been init.");
  607. return ge::INTERNAL_ERROR;
  608. }
  609. return var_resource_->SetTransRoad(var_name, trans_road);
  610. }
  611. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  612. std::lock_guard<std::recursive_mutex> lock(mutex_);
  613. if (var_resource_ == nullptr) {
  614. GELOGW("VarManager has not been init.");
  615. return nullptr;
  616. }
  617. return var_resource_->GetTransRoad(var_name);
  618. }
  619. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  620. std::lock_guard<std::recursive_mutex> lock(mutex_);
  621. if (var_resource_ == nullptr) {
  622. GELOGW("VarManager has not been init.");
  623. return INTERNAL_ERROR;
  624. }
  625. return var_resource_->SetChangedGraphId(var_name, graph_id);
  626. }
  627. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  628. std::lock_guard<std::recursive_mutex> lock(mutex_);
  629. if (var_resource_ == nullptr) {
  630. GELOGW("VarManager has not been init.");
  631. return INTERNAL_ERROR;
  632. }
  633. return var_resource_->GetChangedGraphId(var_name, graph_id);
  634. }
  635. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  636. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  637. if (it == options.end()) {
  638. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  639. } else {
  640. string graph_memory_manager_malloc_max_size = it->second;
  641. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  642. if (ret != SUCCESS) {
  643. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed.");
  644. return ge::GE_GRAPH_OPTIONS_INVALID;
  645. }
  646. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  647. }
  648. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  649. if (it == options.end()) {
  650. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  651. } else {
  652. string memory_var_manager_malloc_size = it->second;
  653. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  654. if (ret != SUCCESS) {
  655. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed.");
  656. return ge::GE_GRAPH_OPTIONS_INVALID;
  657. }
  658. }
  659. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  660. if (var_mem_logic_base_ > kMaxMemorySize) {
  661. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.",
  662. var_mem_logic_base_, kMaxMemorySize);
  663. return ge::GE_GRAPH_OPTIONS_INVALID;
  664. }
  665. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  666. if (use_max_mem_size_ > kMaxMemorySize) {
  667. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.",
  668. use_max_mem_size_, kMaxMemorySize);
  669. return ge::GE_GRAPH_OPTIONS_INVALID;
  670. }
  671. GELOGI("Set memory malloc size successfully");
  672. return SUCCESS;
  673. }
  674. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  675. if (memory_size.empty()) {
  676. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty.");
  677. return GE_GRAPH_OPTIONS_INVALID;
  678. }
  679. // split string by '*'
  680. vector<string> splits;
  681. std::istringstream str(memory_size);
  682. string str_split;
  683. while (getline(str, str_split, '*')) {
  684. splits.emplace_back(str_split);
  685. }
  686. result = 1;
  687. for (string split : splits) {
  688. // Trim
  689. auto it = split.find_first_not_of(" ");
  690. if (it != string::npos) {
  691. split.erase(0, it);
  692. }
  693. it = split.find_last_not_of(" ");
  694. if (it != string::npos) {
  695. split.erase(it + 1);
  696. }
  697. for (char c : split) {
  698. if (!isdigit(c)) {
  699. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit.");
  700. return GE_GRAPH_OPTIONS_INVALID;
  701. }
  702. }
  703. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  704. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  705. GELOGE(FAILED, "Input memory size is out of range.");
  706. return FAILED);
  707. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  708. GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize);
  709. return FAILED;
  710. }
  711. result *= static_cast<size_t>(num);
  712. }
  713. return SUCCESS;
  714. }
  715. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  716. std::lock_guard<std::recursive_mutex> lock(mutex_);
  717. if (var_resource_ == nullptr) {
  718. GELOGW("VarManager has not been init.");
  719. return;
  720. }
  721. var_resource_->RemoveChangedGraphId(var_name);
  722. }
  723. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  724. std::lock_guard<std::recursive_mutex> lock(mutex_);
  725. if (var_resource_ == nullptr) {
  726. GELOGW("VarManager has not been init.");
  727. return INTERNAL_ERROR;
  728. }
  729. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  730. }
  731. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  732. std::lock_guard<std::recursive_mutex> lock(mutex_);
  733. if (var_resource_ == nullptr) {
  734. GELOGW("VarManager has not been init.");
  735. return INTERNAL_ERROR;
  736. }
  737. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  738. }
  739. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  740. std::lock_guard<std::recursive_mutex> lock(mutex_);
  741. if (var_resource_ == nullptr) {
  742. GELOGW("VarManager has not been init.");
  743. return;
  744. }
  745. var_resource_->RemoveAllocatedGraphId(var_name);
  746. }
  747. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  748. std::lock_guard<std::recursive_mutex> lock(mutex_);
  749. if (var_resource_ == nullptr) {
  750. GELOGW("VarManager has not been inited.");
  751. return INTERNAL_ERROR;
  752. }
  753. auto new_variable_desc = var_resource_->GetAllVarDesc();
  754. if (new_variable_desc.size() == 0) {
  755. GELOGW("VarManager don't have variables.");
  756. return INTERNAL_ERROR;
  757. }
  758. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  759. auto trans_road = var_resource_->GetTransRoad(iter->first);
  760. if (trans_road == nullptr || trans_road->empty()) {
  761. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  762. all_variables[iter->first] = iter->second;
  763. continue;
  764. }
  765. // get origin trans info : the first trans node info
  766. auto origin_trans_node_info = trans_road->at(0);
  767. all_variables[iter->first] = origin_trans_node_info.input;
  768. }
  769. return SUCCESS;
  770. }
  771. VarManagerPool::~VarManagerPool() { Destory(); }
  772. VarManagerPool &VarManagerPool::Instance() {
  773. static VarManagerPool var_manager_pool;
  774. return var_manager_pool;
  775. }
  776. void VarManagerPool::Destory() noexcept {
  777. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  778. for (auto &it : var_manager_map_) {
  779. VarManager *var_manager = it.second;
  780. if (var_manager != nullptr) {
  781. var_manager->Destory();
  782. delete var_manager;
  783. var_manager = nullptr;
  784. }
  785. }
  786. var_manager_map_.clear();
  787. }
  788. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  789. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  790. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  791. auto it = var_manager_map_.find(session_id);
  792. if (it != var_manager_map_.end()) {
  793. GELOGD("VarManagerPool::GetVarManager");
  794. return it->second;
  795. }
  796. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  797. if (var_manager == nullptr) {
  798. GELOGE(INTERNAL_ERROR,
  799. "VarManager::Instance find session by "
  800. "session_id[%lu] failed.",
  801. session_id);
  802. static VarManager new_var_manager(0);
  803. return &new_var_manager;
  804. }
  805. var_manager_map_[session_id] = var_manager;
  806. return var_manager;
  807. }
  808. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  809. VarManager *var_manager = nullptr;
  810. {
  811. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  812. auto it = var_manager_map_.find(session_id);
  813. if (it != var_manager_map_.end()) {
  814. var_manager = it->second;
  815. var_manager_map_.erase(it);
  816. }
  817. }
  818. if (var_manager != nullptr) {
  819. var_manager->Destory();
  820. delete var_manager;
  821. var_manager = nullptr;
  822. }
  823. }
  824. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：