You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 41 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/graph_mem_manager.h"
  19. #include "graph/manager/trans_var_data_utils.h"
  20. #include "graph/utils/type_utils.h"
  21. using std::map;
  22. using std::string;
  23. using std::vector;
  24. namespace ge {
  25. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  26. VarResource::~VarResource() {
  27. var_offset_map_.clear();
  28. var_addr_mgr_map_.clear();
  29. cur_var_tensor_desc_map_.clear();
  30. var_broad_cast_info_.clear();
  31. }
  32. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  33. rtMemType_t &memory_type) {
  34. if (dev_ptr == nullptr) {
  35. REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, "
  36. "check invalid", var_name.c_str(), session_id_);
  37. GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!");
  38. return FAILED;
  39. }
  40. std::string var_key = VarKey(var_name, tensor_desc);
  41. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  42. auto iter = var_addr_mgr_map_.find(var_key);
  43. if (iter == var_addr_mgr_map_.end()) {
  44. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  45. "check invalid", var_key.c_str(), var_name.c_str(),
  46. session_id_);
  47. GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str());
  48. return FAILED;
  49. }
  50. *dev_ptr = iter->second.address;
  51. memory_type = iter->second.memory_type;
  52. return SUCCESS;
  53. }
  54. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  55. var_addr_mgr_map = var_addr_mgr_map_;
  56. }
  57. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  58. rtMemType_t memory_type) {
  59. std::string var_key = VarKey(var_name, tensor_desc);
  60. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  61. if (var_addr_mgr_map_.count(var_key) == 0) {
  62. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  63. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  64. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  65. VarAddrMgr var_addr_mgr;
  66. var_addr_mgr.address = dev_ptr;
  67. var_addr_mgr.tensor_desc = tensor_desc;
  68. var_addr_mgr_map_[var_key] = var_addr_mgr;
  69. }
  70. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  71. }
  72. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  73. rtMemType_t memory_type) {
  74. std::string var_key = VarKey(var_name, tensor_desc);
  75. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  76. if (var_addr_mgr_map_.count(var_key) == 0) {
  77. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  78. if (memory_type != RT_MEMORY_RDMA_HBM) {
  79. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  80. }
  81. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  82. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  83. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  84. VarAddrMgr var_addr_mgr;
  85. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  86. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  87. var_addr_mgr.tensor_desc = tensor_desc;
  88. var_addr_mgr.memory_type = memory_type;
  89. var_addr_mgr_map_[var_key] = var_addr_mgr;
  90. var_offset_map_[logic_address] = memory_type;
  91. return SUCCESS;
  92. }
  93. REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  94. "check invalid", var_key.c_str(), var_name.c_str(),
  95. session_id_);
  96. GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str());
  97. return FAILED;
  98. }
  99. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  100. std::string var_key = VarKey(var_name, tensor_desc);
  101. return var_addr_mgr_map_.count(var_key) != 0;
  102. }
  103. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  104. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  105. std::string var_key(var_name);
  106. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  107. .append("_")
  108. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  109. return var_key;
  110. }
  111. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  112. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  113. return FAILED;
  114. }
  115. tensor_desc = cur_var_tensor_desc_map_[var_name];
  116. return SUCCESS;
  117. }
  118. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  119. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  120. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  121. return SUCCESS;
  122. }
  123. if (op_desc == nullptr) {
  124. REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid",
  125. var_name.c_str(), session_id_);
  126. GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!");
  127. return FAILED;
  128. }
  129. ge::GeTensorDesc curr_desc;
  130. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  131. if (ret != SUCCESS) {
  132. GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!");
  133. return FAILED;
  134. }
  135. std::string key = VarKey(var_name, curr_desc);
  136. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  137. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  138. cur_var_tensor_desc_map_[var_name] = curr_desc;
  139. auto iter = var_addr_mgr_map_.find(key);
  140. if (iter == var_addr_mgr_map_.end()) {
  141. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), "
  142. "check invalid", key.c_str(), var_name.c_str(),
  143. session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  144. GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str());
  145. return FAILED;
  146. }
  147. auto val = iter->second;
  148. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  149. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  150. var_addr_mgr_map_.erase(iter);
  151. key = VarKey(var_name, curr_desc);
  152. var_addr_mgr_map_[key] = val;
  153. return SUCCESS;
  154. }
  155. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  156. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  157. }
  158. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  159. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  160. return FAILED;
  161. }
  162. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  163. return SUCCESS;
  164. }
  165. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  166. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  167. GE_CHECK_NOTNULL(base_ptr);
  168. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  169. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  170. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  171. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  172. var_broadcast_info.input_size, session_id_);
  173. }
  174. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  175. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  176. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  177. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  178. // subgraph base_ptr could be nullptr, task it as base 0
  179. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  180. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  181. var_tensor_desc, session_id_);
  182. }
  183. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  184. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  185. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  186. }
  187. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  188. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  189. if (var_offset_map_.count(offset) > 0) {
  190. return var_offset_map_[offset];
  191. }
  192. return RT_MEMORY_RESERVED;
  193. }
  194. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  195. auto iter = var_to_trans_road_.find(var_name);
  196. if (iter == var_to_trans_road_.end()) {
  197. return nullptr;
  198. } else {
  199. return &(iter->second);
  200. }
  201. }
  202. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  203. auto iter = var_names_to_changed_graph_id_.find(var_name);
  204. if (iter == var_names_to_changed_graph_id_.end()) {
  205. return FAILED;
  206. } else {
  207. graph_id = iter->second;
  208. return SUCCESS;
  209. }
  210. }
  211. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  212. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  213. if (iter == var_names_to_allocated_graph_id_.end()) {
  214. return FAILED;
  215. } else {
  216. graph_id = iter->second;
  217. return SUCCESS;
  218. }
  219. }
  220. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  221. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  222. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  223. return SUCCESS;
  224. }
  225. var_names_to_allocated_graph_id_[var_name] = graph_id;
  226. return SUCCESS;
  227. }
  228. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  229. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  230. switch (mem_type) {
  231. case RT_MEMORY_HBM:
  232. return new (std::nothrow) HbmMemResource();
  233. case RT_MEMORY_RDMA_HBM:
  234. return new (std::nothrow) RdmaMemResource();
  235. default:
  236. return nullptr;
  237. }
  238. }
  239. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  240. size_t &mem_offset) {
  241. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  242. uint64_t real_size = size;
  243. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  244. if (total_size_ < var_mem_size_) {
  245. REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid"
  246. "", total_size_, var_mem_size_, size, var_name.c_str());
  247. GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_);
  248. return PARAM_INVALID;
  249. }
  250. uint64_t free_size = total_size_ - var_mem_size_;
  251. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  252. REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid",
  253. free_size, size, var_name.c_str());
  254. GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]",
  255. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  256. return PARAM_INVALID;
  257. }
  258. mem_offset = var_mem_size_;
  259. // offset for next, align 512 BYTE
  260. size = size + kSessionMemAlignSize;
  261. var_mem_size_ = var_mem_size_ + size;
  262. // align 512 BYTE
  263. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  264. GELOGI(
  265. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  266. "offset to [%zu] size[%lu] realsize[%lu].",
  267. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  268. return SUCCESS;
  269. }
  270. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  271. uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
  272. if (buffer == nullptr) {
  273. REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s",
  274. size, var_name.c_str());
  275. GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %lu", var_name.c_str(), size);
  276. return MEMALLOC_FAILED;
  277. }
  278. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  279. var_mem_size_ += size;
  280. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  281. session_id, var_name.c_str(), 0, buffer, size);
  282. return SUCCESS;
  283. }
  284. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  285. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  286. VarManager::VarManager(uint64_t session_id)
  287. : version_(SessionVersion::OTHER_VERSION),
  288. session_id_(session_id),
  289. device_id_(0),
  290. job_id_(0),
  291. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  292. var_mem_max_size_(kMemoryVarManagerMallocSize),
  293. var_mem_logic_base_(kMemoryVarLogicBase),
  294. use_max_mem_size_(kUseMaxMemorySize) {}
  295. VarManager *VarManager::Instance(uint64_t session_id) {
  296. GELOGD("VarManager::Instance, session id = %lu", session_id);
  297. return VarManagerPool::Instance().GetVarManager(session_id);
  298. }
  299. void VarManager::Destory() {
  300. std::lock_guard<std::recursive_mutex> lock(mutex_);
  301. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  302. version_ = SessionVersion::OTHER_VERSION;
  303. device_id_ = 0;
  304. session_id_ = 0;
  305. for (auto &memory_resource : mem_resource_map_) {
  306. if (memory_resource.second != nullptr) {
  307. delete memory_resource.second;
  308. memory_resource.second = nullptr;
  309. }
  310. }
  311. mem_resource_map_.clear();
  312. }
  313. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  314. const uint64_t &job_id) {
  315. std::lock_guard<std::recursive_mutex> lock(mutex_);
  316. GELOGI("VarManager::Init, session id = %lu.", session_id);
  317. if (var_resource_ == nullptr) {
  318. version_ = version;
  319. device_id_ = device_id;
  320. session_id_ = session_id;
  321. job_id_ = job_id;
  322. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  323. if (var_resource_ == nullptr) {
  324. GELOGW("VarManager init failed session id = %lu.", session_id);
  325. return ge::INTERNAL_ERROR;
  326. }
  327. } else {
  328. GELOGW("VarManager::has been inited, session id = %lu.", session_id);
  329. }
  330. return SUCCESS;
  331. }
  332. const uint64_t &VarManager::SessionId() const {
  333. std::lock_guard<std::recursive_mutex> lock(mutex_);
  334. return session_id_;
  335. }
  336. const uint32_t &VarManager::DeviceId() const {
  337. std::lock_guard<std::recursive_mutex> lock(mutex_);
  338. return device_id_;
  339. }
  340. const uint64_t &VarManager::JobId() const {
  341. std::lock_guard<std::recursive_mutex> lock(mutex_);
  342. return job_id_;
  343. }
  344. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  345. rtMemType_t memory_type) {
  346. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  347. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  348. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  349. std::lock_guard<std::recursive_mutex> lock(mutex_);
  350. if (var_resource_ == nullptr) {
  351. GELOGW("VarManager has not been init.");
  352. return ge::INTERNAL_ERROR;
  353. }
  354. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  355. return ge::SUCCESS;
  356. }
  357. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  358. rtMemType_t memory_type) {
  359. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  360. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  361. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  362. std::lock_guard<std::recursive_mutex> lock(mutex_);
  363. if (var_resource_ == nullptr) {
  364. GELOGW("VarManager has not been init.");
  365. return ge::INTERNAL_ERROR;
  366. }
  367. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  368. return ge::SUCCESS;
  369. }
  370. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  371. rtMemType_t &memory_type) {
  372. std::lock_guard<std::recursive_mutex> lock(mutex_);
  373. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  374. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  375. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  376. if (var_resource_ == nullptr) {
  377. GELOGW("VarManager has not been init.");
  378. return ge::INTERNAL_ERROR;
  379. }
  380. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  381. if (ret != SUCCESS) {
  382. GELOGW("GetVarAddr fail.");
  383. return ge::INTERNAL_ERROR;
  384. }
  385. return SUCCESS;
  386. }
  387. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  388. std::lock_guard<std::recursive_mutex> lock(mutex_);
  389. rtMemType_t memory_type = RT_MEMORY_HBM;
  390. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  391. }
  392. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  393. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  394. }
  395. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  396. std::lock_guard<std::recursive_mutex> lock(mutex_);
  397. MemResource *mem_resource = nullptr;
  398. auto iter = mem_resource_map_.find(memory_type);
  399. if (iter == mem_resource_map_.end()) {
  400. return 0;
  401. } else {
  402. mem_resource = iter->second;
  403. }
  404. if (mem_resource == nullptr) {
  405. REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu",
  406. memory_type, session_id_);
  407. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  408. return 0;
  409. }
  410. return mem_resource->GetVarMemSize();
  411. }
  412. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  413. std::lock_guard<std::recursive_mutex> lock(mutex_);
  414. MemResource *mem_resource = nullptr;
  415. auto iter = mem_resource_map_.find(memory_type);
  416. if (iter == mem_resource_map_.end()) {
  417. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  418. if (mem_resource == nullptr) {
  419. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  420. memory_type, session_id_);
  421. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  422. return ge::INTERNAL_ERROR;
  423. } else {
  424. mem_resource_map_[memory_type] = mem_resource;
  425. }
  426. } else {
  427. mem_resource = iter->second;
  428. }
  429. if (mem_resource == nullptr) {
  430. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  431. memory_type, session_id_);
  432. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid.");
  433. return FAILED;
  434. }
  435. mem_resource->UpdateVarMemSize(mem_size);
  436. return SUCCESS;
  437. }
  438. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  439. rtMemType_t memory_type) {
  440. std::lock_guard<std::recursive_mutex> lock(mutex_);
  441. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  442. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  443. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  444. int64_t tensor_desc_size = 0;
  445. size_t mem_offset = 0;
  446. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  447. if (result != ge::SUCCESS) {
  448. REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu",
  449. var_name.c_str(), memory_type, session_id_);
  450. GELOGE(result, "get size from TensorDesc failed");
  451. return result;
  452. }
  453. MemResource *mem_resource = nullptr;
  454. auto it = mem_resource_map_.find(memory_type);
  455. if (it == mem_resource_map_.end()) {
  456. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  457. if (mem_resource == nullptr) {
  458. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  459. memory_type, session_id_);
  460. GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
  461. return ge::INTERNAL_ERROR;
  462. } else {
  463. mem_resource_map_[memory_type] = mem_resource;
  464. }
  465. } else {
  466. mem_resource = it->second;
  467. }
  468. if (mem_resource == nullptr) {
  469. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  470. memory_type, session_id_);
  471. GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type);
  472. return ge::INTERNAL_ERROR;
  473. }
  474. if (var_resource_ == nullptr) {
  475. REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, "
  476. "check invalid", memory_type, session_id_);
  477. GELOGW("VarManager has not been init.");
  478. return ge::INTERNAL_ERROR;
  479. }
  480. ge::GeTensorDesc cur_tensor_desc;
  481. int64_t cur_tensor_desc_size = 0;
  482. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  483. // reuse old format variable memory
  484. if (result == SUCCESS) {
  485. result = var_resource_->GetVarAddr(
  486. var_name, cur_tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  487. if (result == SUCCESS) {
  488. result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size);
  489. GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size,
  490. cur_tensor_desc_size, mem_offset);
  491. }
  492. }
  493. bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size);
  494. if (can_not_reuse_old_memory) {
  495. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  496. if (result != SUCCESS) {
  497. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  498. return ge::INTERNAL_ERROR;
  499. }
  500. result = var_resource_->SaveVarAddr(
  501. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  502. if (result != SUCCESS) {
  503. GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed.");
  504. return ge::INTERNAL_ERROR;
  505. }
  506. }
  507. // old not exist only save new tensor
  508. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  509. if (result != SUCCESS) {
  510. var_resource_->SetVarAddr(var_name, tensor_desc,
  511. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  512. return SUCCESS;
  513. }
  514. bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  515. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  516. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims();
  517. if (format_changed) {
  518. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  519. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  520. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  521. tensor_desc.GetShape().GetDims().size(),
  522. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  523. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  524. cur_tensor_desc.GetShape().GetDims().size());
  525. var_resource_->SetVarAddr(var_name, tensor_desc,
  526. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  527. }
  528. return SUCCESS;
  529. }
  530. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  531. std::lock_guard<std::recursive_mutex> lock(mutex_);
  532. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  533. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  534. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  535. if (var_resource_ == nullptr) {
  536. GELOGW("VarManager has not been init.");
  537. return false;
  538. }
  539. return var_resource_->IsVarExist(var_name, tensor_desc);
  540. }
  541. bool VarManager::IsVarExist(const std::string &var_name) {
  542. std::lock_guard<std::recursive_mutex> lock(mutex_);
  543. if (var_resource_ == nullptr) {
  544. GELOGW("VarManager has not been init.");
  545. return false;
  546. }
  547. return var_resource_->IsVarExist(var_name);
  548. }
  549. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  550. uint8_t *base_ptr) {
  551. std::lock_guard<std::recursive_mutex> lock(mutex_);
  552. if (var_resource_ == nullptr) {
  553. GELOGW("VarManager has not been init.");
  554. return ge::INTERNAL_ERROR;
  555. }
  556. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  557. }
  558. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  559. std::lock_guard<std::recursive_mutex> lock(mutex_);
  560. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  561. if (var_resource_ == nullptr) {
  562. GELOGW("VarManager has not been init.");
  563. return ge::INTERNAL_ERROR;
  564. }
  565. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  566. }
  567. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  568. std::lock_guard<std::recursive_mutex> lock(mutex_);
  569. GELOGI(
  570. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  571. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  572. "output_size = %lu",
  573. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  574. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  575. broad_cast_info.output_size);
  576. if (var_resource_ == nullptr) {
  577. GELOGW("VarManager has not been init.");
  578. return ge::INTERNAL_ERROR;
  579. }
  580. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  581. return SUCCESS;
  582. }
  583. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  584. std::lock_guard<std::recursive_mutex> lock(mutex_);
  585. if (var_resource_ == nullptr) {
  586. GELOGW("VarManager has not been init.");
  587. return ge::INTERNAL_ERROR;
  588. }
  589. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  590. }
  591. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  592. std::lock_guard<std::recursive_mutex> lock(mutex_);
  593. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  594. if (var_resource_ == nullptr) {
  595. REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid",
  596. op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  597. session_id_);
  598. GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init.");
  599. return ge::INTERNAL_ERROR;
  600. }
  601. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  602. }
  603. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  604. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  605. std::lock_guard<std::recursive_mutex> lock(mutex_);
  606. if (var_resource_ == nullptr) {
  607. GELOGW("VarManager has not been init.");
  608. return ge::INTERNAL_ERROR;
  609. }
  610. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  611. }
  612. bool VarManager::IsVarAddr(const int64_t &offset) {
  613. std::lock_guard<std::recursive_mutex> lock(mutex_);
  614. if (var_resource_ == nullptr) {
  615. GELOGD("VarManager has not been init.");
  616. return false;
  617. }
  618. return var_resource_->IsVarAddr(offset);
  619. }
  620. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  621. std::lock_guard<std::recursive_mutex> lock(mutex_);
  622. if (var_resource_ == nullptr) {
  623. GELOGW("VarManager has not been init.");
  624. return RT_MEMORY_RESERVED;
  625. }
  626. return var_resource_->GetVarMemType(offset);
  627. }
  628. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  629. std::lock_guard<std::recursive_mutex> lock(mutex_);
  630. uint8_t *var_mem_base = nullptr;
  631. string memory_key = std::to_string(session_id_);
  632. // malloc variable memory
  633. size_t var_memory_size = memory_size;
  634. // align 512 BYTE
  635. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  636. const string purpose("variables and constant op memory in training network.");
  637. var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size);
  638. if (var_mem_base == nullptr) {
  639. GELOGE(ge::INTERNAL_ERROR,
  640. "VarManager::MallocVarMemory failed "
  641. "session_id = %s",
  642. memory_key.c_str());
  643. return ge::INTERNAL_ERROR;
  644. }
  645. return SUCCESS;
  646. }
  647. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  648. std::lock_guard<std::recursive_mutex> lock(mutex_);
  649. if (memory_type == RT_MEMORY_RDMA_HBM) {
  650. return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
  651. }
  652. string memory_key = std::to_string(session_id_);
  653. return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key);
  654. }
  655. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  656. std::lock_guard<std::recursive_mutex> lock(mutex_);
  657. if (memory_type == RT_MEMORY_RDMA_HBM) {
  658. return logic_addr;
  659. }
  660. string mem_key = std::to_string(session_id_);
  661. uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key);
  662. if (mem_base == nullptr) {
  663. return nullptr;
  664. }
  665. uint8_t *mem_addr =
  666. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  667. return mem_addr;
  668. }
  669. ge::Status VarManager::FreeVarMemory() {
  670. std::lock_guard<std::recursive_mutex> lock(mutex_);
  671. string memory_key = std::to_string(SessionId());
  672. return MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(memory_key);
  673. }
  674. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  675. std::lock_guard<std::recursive_mutex> lock(mutex_);
  676. if (var_resource_ == nullptr) {
  677. GELOGW("VarManager has not been init.");
  678. return ge::INTERNAL_ERROR;
  679. }
  680. return var_resource_->SetTransRoad(var_name, trans_road);
  681. }
  682. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  683. std::lock_guard<std::recursive_mutex> lock(mutex_);
  684. if (var_resource_ == nullptr) {
  685. GELOGW("VarManager has not been init.");
  686. return nullptr;
  687. }
  688. return var_resource_->GetTransRoad(var_name);
  689. }
  690. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  691. std::lock_guard<std::recursive_mutex> lock(mutex_);
  692. if (var_resource_ == nullptr) {
  693. GELOGW("VarManager has not been init.");
  694. return INTERNAL_ERROR;
  695. }
  696. return var_resource_->SetChangedGraphId(var_name, graph_id);
  697. }
  698. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  699. std::lock_guard<std::recursive_mutex> lock(mutex_);
  700. if (var_resource_ == nullptr) {
  701. GELOGW("VarManager has not been init.");
  702. return INTERNAL_ERROR;
  703. }
  704. return var_resource_->GetChangedGraphId(var_name, graph_id);
  705. }
  706. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  707. auto it = options.find(GRAPH_MEMORY_MAX_SIZE);
  708. if (it == options.end()) {
  709. graph_mem_max_size_ = kGraphMemoryManagerMallocMaxSize;
  710. } else {
  711. string graph_memory_manager_malloc_max_size = it->second;
  712. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  713. if (ret != SUCCESS) {
  714. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed.");
  715. return ge::GE_GRAPH_OPTIONS_INVALID;
  716. }
  717. GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_);
  718. }
  719. it = options.find(VARIABLE_MEMORY_MAX_SIZE);
  720. if (it == options.end()) {
  721. var_mem_max_size_ = kMemoryVarManagerMallocSize;
  722. } else {
  723. string memory_var_manager_malloc_size = it->second;
  724. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  725. if (ret != SUCCESS) {
  726. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed.");
  727. return ge::GE_GRAPH_OPTIONS_INVALID;
  728. }
  729. }
  730. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  731. if (var_mem_logic_base_ > kMaxMemorySize) {
  732. REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  733. var_mem_logic_base_, kMaxMemorySize, session_id_);
  734. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.",
  735. var_mem_logic_base_, kMaxMemorySize);
  736. return ge::GE_GRAPH_OPTIONS_INVALID;
  737. }
  738. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  739. if (use_max_mem_size_ > kMaxMemorySize) {
  740. REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  741. use_max_mem_size_, kMaxMemorySize, session_id_);
  742. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.",
  743. use_max_mem_size_, kMaxMemorySize);
  744. return ge::GE_GRAPH_OPTIONS_INVALID;
  745. }
  746. GELOGI("Set memory malloc size successfully");
  747. return SUCCESS;
  748. }
  749. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  750. if (memory_size.empty()) {
  751. REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid",
  752. session_id_);
  753. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty.");
  754. return GE_GRAPH_OPTIONS_INVALID;
  755. }
  756. // split string by '*'
  757. vector<string> splits;
  758. std::istringstream str(memory_size);
  759. string str_split;
  760. while (getline(str, str_split, '*')) {
  761. splits.emplace_back(str_split);
  762. }
  763. result = 1;
  764. for (string split : splits) {
  765. // Trim
  766. auto it = split.find_first_not_of(" ");
  767. if (it != string::npos) {
  768. split.erase(0, it);
  769. }
  770. it = split.find_last_not_of(" ");
  771. if (it != string::npos) {
  772. split.erase(it + 1);
  773. }
  774. for (char c : split) {
  775. if (!isdigit(c)) {
  776. REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid",
  777. memory_size.c_str(), session_id_);
  778. GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit.");
  779. return GE_GRAPH_OPTIONS_INVALID;
  780. }
  781. }
  782. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  783. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  784. REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, "
  785. "check invalid", memory_size.c_str(),
  786. session_id_);
  787. GELOGE(FAILED, "Input memory size is out of range.");
  788. return FAILED);
  789. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  790. REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, "
  791. "check invalid", memory_size.c_str(), kMaxMemorySize,
  792. session_id_);
  793. GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize);
  794. return FAILED;
  795. }
  796. result *= static_cast<size_t>(num);
  797. }
  798. return SUCCESS;
  799. }
  800. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  801. std::lock_guard<std::recursive_mutex> lock(mutex_);
  802. if (var_resource_ == nullptr) {
  803. GELOGW("VarManager has not been init.");
  804. return;
  805. }
  806. var_resource_->RemoveChangedGraphId(var_name);
  807. }
  808. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  809. std::lock_guard<std::recursive_mutex> lock(mutex_);
  810. if (var_resource_ == nullptr) {
  811. GELOGW("VarManager has not been init.");
  812. return INTERNAL_ERROR;
  813. }
  814. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  815. }
  816. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  817. std::lock_guard<std::recursive_mutex> lock(mutex_);
  818. if (var_resource_ == nullptr) {
  819. GELOGW("VarManager has not been init.");
  820. return INTERNAL_ERROR;
  821. }
  822. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  823. }
  824. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  825. std::lock_guard<std::recursive_mutex> lock(mutex_);
  826. if (var_resource_ == nullptr) {
  827. GELOGW("VarManager has not been init.");
  828. return;
  829. }
  830. var_resource_->RemoveAllocatedGraphId(var_name);
  831. }
  832. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  833. std::lock_guard<std::recursive_mutex> lock(mutex_);
  834. if (var_resource_ == nullptr) {
  835. GELOGW("VarManager has not been inited.");
  836. return INTERNAL_ERROR;
  837. }
  838. auto new_variable_desc = var_resource_->GetAllVarDesc();
  839. if (new_variable_desc.size() == 0) {
  840. GELOGW("VarManager don't have variables.");
  841. return INTERNAL_ERROR;
  842. }
  843. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  844. auto trans_road = var_resource_->GetTransRoad(iter->first);
  845. if (trans_road == nullptr || trans_road->empty()) {
  846. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  847. all_variables[iter->first] = iter->second;
  848. continue;
  849. }
  850. // get origin trans info : the first trans node info
  851. auto origin_trans_node_info = trans_road->at(0);
  852. all_variables[iter->first] = origin_trans_node_info.input;
  853. }
  854. return SUCCESS;
  855. }
  856. VarManagerPool::~VarManagerPool() { Destory(); }
  857. VarManagerPool &VarManagerPool::Instance() {
  858. static VarManagerPool var_manager_pool;
  859. return var_manager_pool;
  860. }
  861. void VarManagerPool::Destory() noexcept {
  862. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  863. for (auto &it : var_manager_map_) {
  864. VarManager *var_manager = it.second;
  865. if (var_manager != nullptr) {
  866. var_manager->Destory();
  867. delete var_manager;
  868. var_manager = nullptr;
  869. }
  870. }
  871. var_manager_map_.clear();
  872. }
  873. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  874. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  875. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  876. auto it = var_manager_map_.find(session_id);
  877. if (it != var_manager_map_.end()) {
  878. GELOGD("VarManagerPool::GetVarManager");
  879. return it->second;
  880. }
  881. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  882. if (var_manager == nullptr) {
  883. REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id);
  884. GELOGE(INTERNAL_ERROR,
  885. "VarManager::Instance find session by "
  886. "session_id[%lu] failed.",
  887. session_id);
  888. static VarManager new_var_manager(0);
  889. return &new_var_manager;
  890. }
  891. var_manager_map_[session_id] = var_manager;
  892. return var_manager;
  893. }
  894. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  895. VarManager *var_manager = nullptr;
  896. {
  897. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  898. auto it = var_manager_map_.find(session_id);
  899. if (it != var_manager_map_.end()) {
  900. var_manager = it->second;
  901. var_manager_map_.erase(it);
  902. }
  903. }
  904. if (var_manager != nullptr) {
  905. var_manager->Destory();
  906. delete var_manager;
  907. var_manager = nullptr;
  908. }
  909. }
  910. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：