You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_var_manager.cc 45 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/manager/graph_var_manager.h"
  17. #include "graph/debug/ge_attr_define.h"
  18. #include "graph/manager/graph_mem_manager.h"
  19. #include "graph/manager/trans_var_data_utils.h"
  20. #include "graph/utils/type_utils.h"
  21. #include "graph/ge_context.h"
  22. using std::map;
  23. using std::string;
  24. using std::vector;
  25. namespace ge {
  26. VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
  27. VarResource::~VarResource() {
  28. var_offset_map_.clear();
  29. var_addr_mgr_map_.clear();
  30. cur_var_tensor_desc_map_.clear();
  31. var_broad_cast_info_.clear();
  32. }
  33. ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  34. rtMemType_t &memory_type) {
  35. if (dev_ptr == nullptr) {
  36. REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, "
  37. "check invalid", var_name.c_str(), session_id_);
  38. GELOGE(FAILED, "[Check][Param] Param dev_ptr is nullptr, var_name:%s, session_id:%lu",
  39. var_name.c_str(), session_id_);
  40. return FAILED;
  41. }
  42. std::string var_key = VarKey(var_name, tensor_desc);
  43. GELOGD("VarResource::GetVarAddr , var_key = %s", var_key.c_str());
  44. auto iter = var_addr_mgr_map_.find(var_key);
  45. if (iter == var_addr_mgr_map_.end()) {
  46. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  47. "check invalid", var_key.c_str(), var_name.c_str(),
  48. session_id_);
  49. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  50. var_key.c_str(), var_name.c_str(), session_id_);
  51. return FAILED;
  52. }
  53. *dev_ptr = iter->second.address;
  54. memory_type = iter->second.memory_type;
  55. return SUCCESS;
  56. }
  57. void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  58. var_addr_mgr_map = var_addr_mgr_map_;
  59. }
  60. void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  61. rtMemType_t memory_type) {
  62. std::string var_key = VarKey(var_name, tensor_desc);
  63. GELOGI("VarResource::SetVarAddr , var_key = %s, mem_type:%u", var_key.c_str(), memory_type);
  64. if (var_addr_mgr_map_.count(var_key) == 0) {
  65. GELOGI("SetVarAddr node_name %s, tensor_desc type %s, format %s", var_name.c_str(),
  66. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  67. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  68. VarAddrMgr var_addr_mgr;
  69. var_addr_mgr.address = dev_ptr;
  70. var_addr_mgr.tensor_desc = tensor_desc;
  71. var_addr_mgr_map_[var_key] = var_addr_mgr;
  72. }
  73. cur_var_tensor_desc_map_[var_name] = tensor_desc;
  74. }
  75. ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  76. rtMemType_t memory_type) {
  77. std::string var_key = VarKey(var_name, tensor_desc);
  78. GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
  79. if (var_addr_mgr_map_.count(var_key) == 0) {
  80. uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  81. if (memory_type != RT_MEMORY_RDMA_HBM) {
  82. logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
  83. }
  84. GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
  85. TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  86. TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  87. VarAddrMgr var_addr_mgr;
  88. var_addr_mgr.address = reinterpret_cast<uint8_t *>(static_cast<std::uintptr_t>(logic_address));
  89. var_addr_mgr.offset = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
  90. var_addr_mgr.tensor_desc = tensor_desc;
  91. var_addr_mgr.memory_type = memory_type;
  92. var_addr_mgr_map_[var_key] = var_addr_mgr;
  93. var_offset_map_[logic_address] = memory_type;
  94. return SUCCESS;
  95. }
  96. REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, "
  97. "check invalid", var_key.c_str(), var_name.c_str(),
  98. session_id_);
  99. GELOGE(FAILED, "[Check][Param] var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu",
  100. var_key.c_str(), var_name.c_str(), session_id_);
  101. return FAILED;
  102. }
  103. bool VarResource::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  104. std::string var_key = VarKey(var_name, tensor_desc);
  105. return var_addr_mgr_map_.count(var_key) != 0;
  106. }
  107. bool VarResource::IsVarExist(const std::string &var_name) { return cur_var_tensor_desc_map_.count(var_name) != 0; }
  108. std::string VarResource::VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  109. std::string var_key(var_name);
  110. var_key.append(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())))
  111. .append("_")
  112. .append(std::to_string(static_cast<int32_t>(tensor_desc.GetDataType())));
  113. return var_key;
  114. }
  115. ge::Status VarResource::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  116. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  117. return FAILED;
  118. }
  119. tensor_desc = cur_var_tensor_desc_map_[var_name];
  120. return SUCCESS;
  121. }
  122. ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::OpDescPtr &op_desc) {
  123. if (cur_var_tensor_desc_map_.count(var_name) == 0) {
  124. GELOGI("There is no this node[%s] in var tensor_desc map. so no need renew!", var_name.c_str());
  125. return SUCCESS;
  126. }
  127. if (op_desc == nullptr) {
  128. REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid",
  129. var_name.c_str(), session_id_);
  130. GELOGE(FAILED, "[Check][Param] input opdesc is nullptr, var_name:%s, session_id:%lu",
  131. var_name.c_str(), session_id_);
  132. return FAILED;
  133. }
  134. ge::GeTensorDesc curr_desc;
  135. ge::Status ret = GetCurVarDesc(var_name, curr_desc);
  136. if (ret != SUCCESS) {
  137. GELOGE(FAILED, "[Get][CurVarDesc] fail, var_name:%s, session_id:%lu", var_name.c_str(), session_id_);
  138. return FAILED;
  139. }
  140. std::string key = VarKey(var_name, curr_desc);
  141. curr_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  142. curr_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  143. cur_var_tensor_desc_map_[var_name] = curr_desc;
  144. auto iter = var_addr_mgr_map_.find(key);
  145. if (iter == var_addr_mgr_map_.end()) {
  146. REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), "
  147. "check invalid", key.c_str(), var_name.c_str(),
  148. session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  149. GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s)",
  150. key.c_str(), var_name.c_str(), session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  151. return FAILED;
  152. }
  153. auto val = iter->second;
  154. val.tensor_desc.SetOriginFormat((op_desc->GetOutputDesc(0)).GetOriginFormat());
  155. val.tensor_desc.SetFormat((op_desc->GetOutputDesc(0)).GetFormat());
  156. var_addr_mgr_map_.erase(iter);
  157. key = VarKey(var_name, curr_desc);
  158. var_addr_mgr_map_[key] = val;
  159. return SUCCESS;
  160. }
  161. void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  162. var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info;
  163. }
  164. ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  165. if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) {
  166. return FAILED;
  167. }
  168. broad_cast_info = var_broad_cast_info_[graph_id][var_name];
  169. return SUCCESS;
  170. }
  171. ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name,
  172. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  173. GE_CHECK_NOTNULL(base_ptr);
  174. GELOGI("SyncVarData2BroadCast graph_id: %u, var_name: %s.", graph_id, var_name.c_str());
  175. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  176. uint8_t *dst_addr = base_ptr + var_broadcast_info.input_offset;
  177. return ge::TransVarDataUtils::SyncVarData2BroadCast(var_name, var_tensor_desc, dst_addr,
  178. var_broadcast_info.input_size, session_id_);
  179. }
  180. ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  181. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  182. GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str());
  183. VarBroadCastInfo var_broadcast_info = var_broad_cast_info_[graph_id][var_name];
  184. // subgraph base_ptr could be nullptr, task it as base 0
  185. uint8_t *dst_addr = base_ptr + var_broadcast_info.output_offset;
  186. return ge::TransVarDataUtils::SyncBroadCastData2Var(dst_addr, var_broadcast_info.output_size, var_name,
  187. var_tensor_desc, session_id_);
  188. }
  189. ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_name,
  190. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  191. return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
  192. }
  193. bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
  194. rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
  195. if (var_offset_map_.count(offset) > 0) {
  196. return var_offset_map_[offset];
  197. }
  198. return RT_MEMORY_RESERVED;
  199. }
  200. VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
  201. auto iter = var_to_trans_road_.find(var_name);
  202. if (iter == var_to_trans_road_.end()) {
  203. return nullptr;
  204. } else {
  205. return &(iter->second);
  206. }
  207. }
  208. Status VarResource::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  209. auto iter = var_names_to_changed_graph_id_.find(var_name);
  210. if (iter == var_names_to_changed_graph_id_.end()) {
  211. return FAILED;
  212. } else {
  213. graph_id = iter->second;
  214. return SUCCESS;
  215. }
  216. }
  217. Status VarResource::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  218. auto iter = var_names_to_allocated_graph_id_.find(var_name);
  219. if (iter == var_names_to_allocated_graph_id_.end()) {
  220. return FAILED;
  221. } else {
  222. graph_id = iter->second;
  223. return SUCCESS;
  224. }
  225. }
  226. Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  227. if (GetAllocatedGraphId(var_name, graph_id) == SUCCESS) {
  228. GELOGW("VarManager var[%s] has been allocated in graph[%d]", var_name.c_str(), graph_id);
  229. return SUCCESS;
  230. }
  231. var_names_to_allocated_graph_id_[var_name] = graph_id;
  232. return SUCCESS;
  233. }
  234. MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
  235. MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
  236. switch (mem_type) {
  237. case RT_MEMORY_HBM:
  238. return new (std::nothrow) HbmMemResource();
  239. case RT_MEMORY_RDMA_HBM:
  240. return new (std::nothrow) RdmaMemResource();
  241. default:
  242. return nullptr;
  243. }
  244. }
  245. Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
  246. size_t &mem_offset) {
  247. size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  248. uint64_t real_size = size;
  249. total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
  250. if (total_size_ < var_mem_size_) {
  251. REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid"
  252. "", total_size_, var_mem_size_, size, var_name.c_str());
  253. GELOGE(PARAM_INVALID, "[Check][Param] total_size_:%lu is smaller than var_mem_size_:%lu, var_name:%s",
  254. total_size_, var_mem_size_, var_name.c_str());
  255. return PARAM_INVALID;
  256. }
  257. uint64_t free_size = total_size_ - var_mem_size_;
  258. if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) {
  259. REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid",
  260. free_size, size, var_name.c_str());
  261. GELOGE(PARAM_INVALID, "[Check][Param] Out of memory: current var size[%lu] exceeds total var size[%lu]",
  262. size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_);
  263. return PARAM_INVALID;
  264. }
  265. mem_offset = var_mem_size_;
  266. // offset for next, align 512 BYTE
  267. size = size + kSessionMemAlignSize;
  268. var_mem_size_ = var_mem_size_ + size;
  269. // align 512 BYTE
  270. var_mem_size_ = var_mem_size_ + kSessionMemAlignSize;
  271. GELOGI(
  272. "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]"
  273. "offset to [%zu] size[%lu] realsize[%lu].",
  274. session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size);
  275. return SUCCESS;
  276. }
  277. Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
  278. uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
  279. if (buffer == nullptr) {
  280. REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s",
  281. size, var_name.c_str());
  282. GELOGE(MEMALLOC_FAILED, "[Malloc][RdmaMemory] for node %s failed, size = %lu", var_name.c_str(), size);
  283. return MEMALLOC_FAILED;
  284. }
  285. address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
  286. var_mem_size_ += size;
  287. GELOGI("[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] addr to [%p] size[%lu].",
  288. session_id, var_name.c_str(), 0, buffer, size);
  289. return SUCCESS;
  290. }
  291. uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
  292. void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
  293. VarManager::VarManager(uint64_t session_id)
  294. : version_(SessionVersion::OTHER_VERSION),
  295. session_id_(session_id),
  296. device_id_(0),
  297. job_id_(0),
  298. graph_mem_max_size_(kGraphMemoryManagerMallocMaxSize),
  299. var_mem_max_size_(kMemoryVarManagerMallocSize),
  300. var_mem_logic_base_(kMemoryVarLogicBase),
  301. use_max_mem_size_(kUseMaxMemorySize) {}
  302. VarManager *VarManager::Instance(uint64_t session_id) {
  303. GELOGD("VarManager::Instance, session id = %lu", session_id);
  304. return VarManagerPool::Instance().GetVarManager(session_id);
  305. }
  306. void VarManager::Destory() {
  307. std::lock_guard<std::recursive_mutex> lock(mutex_);
  308. GELOGI("VarManager::Destory, session id = %lu.", session_id_);
  309. version_ = SessionVersion::OTHER_VERSION;
  310. device_id_ = 0;
  311. session_id_ = 0;
  312. for (auto &memory_resource : mem_resource_map_) {
  313. if (memory_resource.second != nullptr) {
  314. delete memory_resource.second;
  315. memory_resource.second = nullptr;
  316. }
  317. }
  318. mem_resource_map_.clear();
  319. }
  320. ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id,
  321. const uint64_t &job_id) {
  322. std::lock_guard<std::recursive_mutex> lock(mutex_);
  323. GELOGI("VarManager::Init, session id = %lu.", session_id);
  324. if (var_resource_ == nullptr) {
  325. version_ = version;
  326. device_id_ = device_id;
  327. session_id_ = session_id;
  328. job_id_ = job_id;
  329. var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_));
  330. if (var_resource_ == nullptr) {
  331. GELOGW("VarManager init failed session id = %lu.", session_id);
  332. return ge::INTERNAL_ERROR;
  333. }
  334. } else {
  335. GELOGW("VarManager::has been inited, session id = %lu.", session_id);
  336. }
  337. return SUCCESS;
  338. }
  339. const uint64_t &VarManager::SessionId() const {
  340. std::lock_guard<std::recursive_mutex> lock(mutex_);
  341. return session_id_;
  342. }
  343. const uint32_t &VarManager::DeviceId() const {
  344. std::lock_guard<std::recursive_mutex> lock(mutex_);
  345. return device_id_;
  346. }
  347. const uint64_t &VarManager::JobId() const {
  348. std::lock_guard<std::recursive_mutex> lock(mutex_);
  349. return job_id_;
  350. }
  351. ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr,
  352. rtMemType_t memory_type) {
  353. GELOGI("VarManager::SetVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  354. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  355. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  356. std::lock_guard<std::recursive_mutex> lock(mutex_);
  357. if (var_resource_ == nullptr) {
  358. GELOGW("VarManager has not been init.");
  359. return ge::INTERNAL_ERROR;
  360. }
  361. var_resource_->SetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  362. return ge::SUCCESS;
  363. }
  364. ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address,
  365. rtMemType_t memory_type) {
  366. GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  367. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  368. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  369. std::lock_guard<std::recursive_mutex> lock(mutex_);
  370. if (var_resource_ == nullptr) {
  371. GELOGW("VarManager has not been init.");
  372. return ge::INTERNAL_ERROR;
  373. }
  374. var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type);
  375. return ge::SUCCESS;
  376. }
  377. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr,
  378. rtMemType_t &memory_type) {
  379. std::lock_guard<std::recursive_mutex> lock(mutex_);
  380. GELOGD("VarManager::GetVarAddr var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  381. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  382. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  383. if (var_resource_ == nullptr) {
  384. GELOGW("VarManager has not been init.");
  385. return ge::INTERNAL_ERROR;
  386. }
  387. auto ret = var_resource_->GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  388. if (ret != SUCCESS) {
  389. GELOGW("GetVarAddr fail.");
  390. return ge::INTERNAL_ERROR;
  391. }
  392. return SUCCESS;
  393. }
  394. ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr) {
  395. std::lock_guard<std::recursive_mutex> lock(mutex_);
  396. rtMemType_t memory_type = RT_MEMORY_HBM;
  397. return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type);
  398. }
  399. void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) {
  400. var_resource_->GetAllVarAddrMgr(var_addr_mgr_map);
  401. }
  402. int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) {
  403. std::lock_guard<std::recursive_mutex> lock(mutex_);
  404. MemResource *mem_resource = nullptr;
  405. auto iter = mem_resource_map_.find(memory_type);
  406. if (iter == mem_resource_map_.end()) {
  407. return 0;
  408. } else {
  409. mem_resource = iter->second;
  410. }
  411. if (mem_resource == nullptr) {
  412. REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu",
  413. memory_type, session_id_);
  414. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%d, session_id:%lu",
  415. memory_type, session_id_);
  416. return 0;
  417. }
  418. return mem_resource->GetVarMemSize();
  419. }
  420. Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
  421. std::lock_guard<std::recursive_mutex> lock(mutex_);
  422. MemResource *mem_resource = nullptr;
  423. auto iter = mem_resource_map_.find(memory_type);
  424. if (iter == mem_resource_map_.end()) {
  425. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  426. if (mem_resource == nullptr) {
  427. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  428. memory_type, session_id_);
  429. GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu",
  430. memory_type, session_id_);
  431. return ge::INTERNAL_ERROR;
  432. } else {
  433. mem_resource_map_[memory_type] = mem_resource;
  434. }
  435. } else {
  436. mem_resource = iter->second;
  437. }
  438. if (mem_resource == nullptr) {
  439. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  440. memory_type, session_id_);
  441. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu",
  442. memory_type, session_id_);
  443. return FAILED;
  444. }
  445. mem_resource->UpdateVarMemSize(mem_size);
  446. return SUCCESS;
  447. }
  448. ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc,
  449. rtMemType_t memory_type) {
  450. std::lock_guard<std::recursive_mutex> lock(mutex_);
  451. GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(),
  452. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  453. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str());
  454. int64_t tensor_desc_size = 0;
  455. size_t mem_offset = 0;
  456. ge::Status result = TensorUtils::GetSize(tensor_desc, tensor_desc_size);
  457. if (result != ge::SUCCESS) {
  458. REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu",
  459. var_name.c_str(), memory_type, session_id_);
  460. GELOGE(result, "[Get][Size] from tensor fail, var_name:%s, memory_type:%u, session_id:%lu",
  461. var_name.c_str(), memory_type, session_id_);
  462. return result;
  463. }
  464. MemResource *mem_resource = nullptr;
  465. auto it = mem_resource_map_.find(memory_type);
  466. if (it == mem_resource_map_.end()) {
  467. mem_resource = MemResource::BuildMemResourceFromType(memory_type);
  468. if (mem_resource == nullptr) {
  469. REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu",
  470. memory_type, session_id_);
  471. GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu.",
  472. memory_type, session_id_);
  473. return ge::INTERNAL_ERROR;
  474. } else {
  475. mem_resource_map_[memory_type] = mem_resource;
  476. }
  477. } else {
  478. mem_resource = it->second;
  479. }
  480. if (mem_resource == nullptr) {
  481. REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu",
  482. memory_type, session_id_);
  483. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu.",
  484. memory_type, session_id_);
  485. return ge::INTERNAL_ERROR;
  486. }
  487. if (var_resource_ == nullptr) {
  488. REPORT_INNER_ERROR("E19999", "VarManager has not been init, memory_type:%d, session_id:%lu, "
  489. "check invalid", memory_type, session_id_);
  490. GELOGW("VarManager has not been init.");
  491. return ge::INTERNAL_ERROR;
  492. }
  493. ge::GeTensorDesc cur_tensor_desc;
  494. int64_t cur_tensor_desc_size = 0;
  495. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  496. // reuse old format variable memory
  497. if (result == SUCCESS) {
  498. result = var_resource_->GetVarAddr(
  499. var_name, cur_tensor_desc, reinterpret_cast<uint8_t **>(reinterpret_cast<uintptr_t>(&mem_offset)), memory_type);
  500. if (result == SUCCESS) {
  501. result = TensorUtils::GetSize(cur_tensor_desc, cur_tensor_desc_size);
  502. GELOGD("tensor_desc_size is %ld, cur_tensor_desc_size is %ld, memoffset is %zu", tensor_desc_size,
  503. cur_tensor_desc_size, mem_offset);
  504. }
  505. }
  506. bool can_not_reuse_old_memory = (result != SUCCESS) || (tensor_desc_size > cur_tensor_desc_size);
  507. if (can_not_reuse_old_memory) {
  508. result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset);
  509. if (result != SUCCESS) {
  510. GELOGE(ge::INTERNAL_ERROR, "[Assign][VarMem] by offset failed, session_id:%lu.", session_id_);
  511. return ge::INTERNAL_ERROR;
  512. }
  513. result = var_resource_->SaveVarAddr(
  514. var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  515. if (result != SUCCESS) {
  516. GELOGE(ge::INTERNAL_ERROR, "[Save][VarAddr] by offset failed, memory type:%u, session_id:%lu.",
  517. memory_type, session_id_);
  518. return ge::INTERNAL_ERROR;
  519. }
  520. }
  521. // old not exist only save new tensor
  522. result = var_resource_->GetCurVarDesc(var_name, cur_tensor_desc);
  523. if (result != SUCCESS) {
  524. var_resource_->SetVarAddr(var_name, tensor_desc,
  525. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  526. return SUCCESS;
  527. }
  528. bool format_changed = cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() ||
  529. cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() ||
  530. cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims();
  531. if (format_changed) {
  532. GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(),
  533. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(),
  534. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  535. tensor_desc.GetShape().GetDims().size(),
  536. ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(),
  537. ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(),
  538. cur_tensor_desc.GetShape().GetDims().size());
  539. var_resource_->SetVarAddr(var_name, tensor_desc,
  540. reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type);
  541. }
  542. return SUCCESS;
  543. }
  544. bool VarManager::IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc) {
  545. std::lock_guard<std::recursive_mutex> lock(mutex_);
  546. GELOGD("VarManager::IsVarExist var_name = %s, data_type = %s, data_format = %s", var_name.c_str(),
  547. ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
  548. ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
  549. if (var_resource_ == nullptr) {
  550. GELOGW("VarManager has not been init.");
  551. return false;
  552. }
  553. return var_resource_->IsVarExist(var_name, tensor_desc);
  554. }
  555. bool VarManager::IsVarExist(const std::string &var_name) {
  556. std::lock_guard<std::recursive_mutex> lock(mutex_);
  557. if (var_resource_ == nullptr) {
  558. GELOGW("VarManager has not been init.");
  559. return false;
  560. }
  561. return var_resource_->IsVarExist(var_name);
  562. }
  563. ge::Status VarManager::SyncVarData(uint32_t graph_id, const std::string &var_name, const GeTensorDesc &var_tensor_desc,
  564. uint8_t *base_ptr) {
  565. std::lock_guard<std::recursive_mutex> lock(mutex_);
  566. if (var_resource_ == nullptr) {
  567. GELOGW("VarManager has not been init.");
  568. return ge::INTERNAL_ERROR;
  569. }
  570. return var_resource_->SyncVarData(graph_id, var_name, var_tensor_desc, base_ptr);
  571. }
  572. ge::Status VarManager::GetCurVarDesc(const std::string &var_name, ge::GeTensorDesc &tensor_desc) {
  573. std::lock_guard<std::recursive_mutex> lock(mutex_);
  574. GELOGI("VarManager::GetCurVarDesc var_name = %s.", var_name.c_str());
  575. if (var_resource_ == nullptr) {
  576. GELOGW("VarManager has not been init.");
  577. return ge::INTERNAL_ERROR;
  578. }
  579. return var_resource_->GetCurVarDesc(var_name, tensor_desc);
  580. }
  581. ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info) {
  582. std::lock_guard<std::recursive_mutex> lock(mutex_);
  583. GELOGI(
  584. "VarManager::SaveBroadCastInfo var_name = %s, broadcast name = %s, "
  585. "idx = %d, input_offset = %ld, input_size = %lu, output_offset = %ld, "
  586. "output_size = %lu",
  587. broad_cast_info.var_name.c_str(), broad_cast_info.broadcast_name.c_str(), broad_cast_info.idx,
  588. broad_cast_info.input_offset, broad_cast_info.input_size, broad_cast_info.output_offset,
  589. broad_cast_info.output_size);
  590. if (var_resource_ == nullptr) {
  591. GELOGW("VarManager has not been init.");
  592. return ge::INTERNAL_ERROR;
  593. }
  594. var_resource_->SaveBroadCastInfo(graph_id, broad_cast_info);
  595. return SUCCESS;
  596. }
  597. ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) {
  598. std::lock_guard<std::recursive_mutex> lock(mutex_);
  599. if (var_resource_ == nullptr) {
  600. GELOGW("VarManager has not been init.");
  601. return ge::INTERNAL_ERROR;
  602. }
  603. return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info);
  604. }
  605. ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) {
  606. std::lock_guard<std::recursive_mutex> lock(mutex_);
  607. GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str());
  608. if (var_resource_ == nullptr) {
  609. REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid",
  610. op_desc->GetName().c_str(), op_desc->GetType().c_str(),
  611. session_id_);
  612. GELOGE(ge::INTERNAL_ERROR, "[Check][Param] VarManager has not been init, op:%s(%s), session_id:%lu",
  613. op_desc->GetName().c_str(), op_desc->GetType().c_str(), session_id_);
  614. return ge::INTERNAL_ERROR;
  615. }
  616. return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc));
  617. }
  618. ge::Status VarManager::SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name,
  619. const GeTensorDesc &var_tensor_desc, uint8_t *base_ptr) {
  620. std::lock_guard<std::recursive_mutex> lock(mutex_);
  621. if (var_resource_ == nullptr) {
  622. GELOGW("VarManager has not been init.");
  623. return ge::INTERNAL_ERROR;
  624. }
  625. return var_resource_->SyncBroadCastData2Var(graph_id, var_name, var_tensor_desc, base_ptr);
  626. }
  627. bool VarManager::IsVarAddr(const int64_t &offset) {
  628. std::lock_guard<std::recursive_mutex> lock(mutex_);
  629. if (var_resource_ == nullptr) {
  630. GELOGD("VarManager has not been init.");
  631. return false;
  632. }
  633. return var_resource_->IsVarAddr(offset);
  634. }
  635. rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
  636. std::lock_guard<std::recursive_mutex> lock(mutex_);
  637. if (var_resource_ == nullptr) {
  638. GELOGW("VarManager has not been init.");
  639. return RT_MEMORY_RESERVED;
  640. }
  641. return var_resource_->GetVarMemType(offset);
  642. }
  643. ge::Status VarManager::MallocVarMemory(size_t memory_size) {
  644. std::lock_guard<std::recursive_mutex> lock(mutex_);
  645. uint8_t *var_mem_base = nullptr;
  646. string memory_key = std::to_string(session_id_);
  647. // malloc variable memory
  648. size_t var_memory_size = memory_size;
  649. // align 512 BYTE
  650. var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
  651. const string purpose("variables and constant op memory in training network.");
  652. var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size);
  653. if (var_mem_base == nullptr) {
  654. GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s",
  655. var_memory_size, memory_key.c_str());
  656. return ge::INTERNAL_ERROR;
  657. }
  658. return SUCCESS;
  659. }
  660. uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
  661. std::lock_guard<std::recursive_mutex> lock(mutex_);
  662. if (memory_type == RT_MEMORY_RDMA_HBM) {
  663. return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
  664. }
  665. string memory_key = std::to_string(session_id_);
  666. return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key);
  667. }
  668. uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
  669. std::lock_guard<std::recursive_mutex> lock(mutex_);
  670. if (memory_type == RT_MEMORY_RDMA_HBM) {
  671. return logic_addr;
  672. }
  673. string mem_key = std::to_string(session_id_);
  674. uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key);
  675. if (mem_base == nullptr) {
  676. return nullptr;
  677. }
  678. uint8_t *mem_addr =
  679. logic_addr + reinterpret_cast<intptr_t>(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase();
  680. return mem_addr;
  681. }
  682. ge::Status VarManager::FreeVarMemory() {
  683. std::lock_guard<std::recursive_mutex> lock(mutex_);
  684. string memory_key = std::to_string(SessionId());
  685. return MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(memory_key);
  686. }
  687. ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) {
  688. std::lock_guard<std::recursive_mutex> lock(mutex_);
  689. if (var_resource_ == nullptr) {
  690. GELOGW("VarManager has not been init.");
  691. return ge::INTERNAL_ERROR;
  692. }
  693. return var_resource_->SetTransRoad(var_name, trans_road);
  694. }
  695. VarTransRoad *VarManager::GetTransRoad(const std::string &var_name) {
  696. std::lock_guard<std::recursive_mutex> lock(mutex_);
  697. if (var_resource_ == nullptr) {
  698. GELOGW("VarManager has not been init.");
  699. return nullptr;
  700. }
  701. return var_resource_->GetTransRoad(var_name);
  702. }
  703. Status VarManager::SetChangedGraphId(const std::string &var_name, uint32_t graph_id) {
  704. std::lock_guard<std::recursive_mutex> lock(mutex_);
  705. if (var_resource_ == nullptr) {
  706. GELOGW("VarManager has not been init.");
  707. return INTERNAL_ERROR;
  708. }
  709. return var_resource_->SetChangedGraphId(var_name, graph_id);
  710. }
  711. Status VarManager::GetChangedGraphId(const std::string &var_name, uint32_t &graph_id) {
  712. std::lock_guard<std::recursive_mutex> lock(mutex_);
  713. if (var_resource_ == nullptr) {
  714. GELOGW("VarManager has not been init.");
  715. return INTERNAL_ERROR;
  716. }
  717. return var_resource_->GetChangedGraphId(var_name, graph_id);
  718. }
  719. Status VarManager::GetTotalMemorySize(size_t &total_mem_size) {
  720. rtError_t rt_ret = rtSetDevice(GetContext().DeviceId());
  721. if (rt_ret != RT_ERROR_NONE) {
  722. REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X",
  723. GetContext().DeviceId(), rt_ret);
  724. GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
  725. return RT_FAILED;
  726. }
  727. size_t free_mem = 0;
  728. rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free_mem, &total_mem_size);
  729. if (rt_ret != RT_ERROR_NONE) {
  730. REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret);
  731. GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret);
  732. return RT_FAILED;
  733. }
  734. if (total_mem_size == 0) {
  735. rt_ret = rtMemGetInfoEx(RT_MEMORYINFO_DDR, &free_mem, &total_mem_size);
  736. if (rt_ret != RT_ERROR_NONE) {
  737. REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret);
  738. GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret);
  739. return RT_FAILED;
  740. }
  741. }
  742. rt_ret = rtDeviceReset(GetContext().DeviceId());
  743. if (rt_ret != RT_ERROR_NONE) {
  744. REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X",
  745. GetContext().DeviceId(), rt_ret);
  746. GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret);
  747. return RT_FAILED;
  748. }
  749. return SUCCESS;
  750. }
  751. Status VarManager::SetMemoryMallocSize(const map<string, string> &options) {
  752. size_t total_mem_size = 0;
  753. GE_CHK_STATUS_RET_NOLOG(VarManager::GetTotalMemorySize(total_mem_size));
  754. GEEVENT("Total memory size is %zu", total_mem_size);
  755. graph_mem_max_size_ = floor(total_mem_size * kGraphMemoryManagerMallocRatio);
  756. var_mem_max_size_ = floor(total_mem_size * kVarMemoryManagerMallocRatio);
  757. auto it1 = options.find(GRAPH_MEMORY_MAX_SIZE);
  758. if (it1 != options.end()) {
  759. string graph_memory_manager_malloc_max_size = it1->second;
  760. ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_);
  761. if (ret != SUCCESS) {
  762. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  763. return ge::GE_GRAPH_OPTIONS_INVALID;
  764. }
  765. }
  766. auto it2 = options.find(VARIABLE_MEMORY_MAX_SIZE);
  767. if (it2 != options.end()) {
  768. string memory_var_manager_malloc_size = it2->second;
  769. ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_);
  770. if (ret != SUCCESS) {
  771. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_);
  772. return ge::GE_GRAPH_OPTIONS_INVALID;
  773. }
  774. }
  775. GEEVENT("The graph_mem_max_size is %zu and the var_mem_max_size is %zu", graph_mem_max_size_, var_mem_max_size_);
  776. var_mem_logic_base_ = graph_mem_max_size_ + kGraphMemoryBuffer;
  777. if (var_mem_logic_base_ > kMaxMemorySize) {
  778. REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  779. var_mem_logic_base_, kMaxMemorySize, session_id_);
  780. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kMemoryVarLogicBase:%zu can not exceed "
  781. "max memory size:%zu, session_id:%lu.", var_mem_logic_base_, kMaxMemorySize, session_id_);
  782. return ge::GE_GRAPH_OPTIONS_INVALID;
  783. }
  784. use_max_mem_size_ = graph_mem_max_size_ + var_mem_max_size_;
  785. if (use_max_mem_size_ > kMaxMemorySize) {
  786. REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid",
  787. use_max_mem_size_, kMaxMemorySize, session_id_);
  788. GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kUseMaxMemorySize:%zu can not exceed "
  789. "max memory size:%zu, session_id:%lu.", use_max_mem_size_, kMaxMemorySize, session_id_);
  790. return ge::GE_GRAPH_OPTIONS_INVALID;
  791. }
  792. GELOGI("Set memory malloc size successfully");
  793. return SUCCESS;
  794. }
  795. Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) {
  796. if (memory_size.empty()) {
  797. REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid",
  798. session_id_);
  799. GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Check][Param] Memory malloc size input is empty, session_id:%lu.", session_id_);
  800. return GE_GRAPH_OPTIONS_INVALID;
  801. }
  802. // split string by '*'
  803. vector<string> splits;
  804. std::istringstream str(memory_size);
  805. string str_split;
  806. while (getline(str, str_split, '*')) {
  807. splits.emplace_back(str_split);
  808. }
  809. result = 1;
  810. for (string split : splits) {
  811. // Trim
  812. auto it = split.find_first_not_of(" ");
  813. if (it != string::npos) {
  814. split.erase(0, it);
  815. }
  816. it = split.find_last_not_of(" ");
  817. if (it != string::npos) {
  818. split.erase(it + 1);
  819. }
  820. for (char c : split) {
  821. if (!isdigit(c)) {
  822. REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid",
  823. memory_size.c_str(), session_id_);
  824. GELOGE(GE_GRAPH_OPTIONS_INVALID,
  825. "[Check][Param] Memory malloc size:%s input contains non digit, session_id:%lu.",
  826. memory_size.c_str(), session_id_);
  827. return GE_GRAPH_OPTIONS_INVALID;
  828. }
  829. }
  830. uint64_t num = std::strtoul(split.c_str(), nullptr, 0);
  831. GE_IF_BOOL_EXEC(TypeUtils::CheckUint64MulOverflow(result, static_cast<uint32_t>(num)),
  832. REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, "
  833. "check invalid", memory_size.c_str(),
  834. session_id_);
  835. GELOGE(FAILED, "[Check][Param] Param memory_size:%s will overflow after multi all, session_id:%lu",
  836. memory_size.c_str(), session_id_);
  837. return FAILED);
  838. if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) {
  839. REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, "
  840. "check invalid", memory_size.c_str(), kMaxMemorySize,
  841. session_id_);
  842. GELOGE(FAILED, "[Check][Param] Input memory size can not exceed max memory size:%zu, session_id:%lu.",
  843. kMaxMemorySize, session_id_);
  844. return FAILED;
  845. }
  846. result *= static_cast<size_t>(num);
  847. }
  848. return SUCCESS;
  849. }
  850. void VarManager::RemoveChangedGraphId(const std::string &var_name) {
  851. std::lock_guard<std::recursive_mutex> lock(mutex_);
  852. if (var_resource_ == nullptr) {
  853. GELOGW("VarManager has not been init.");
  854. return;
  855. }
  856. var_resource_->RemoveChangedGraphId(var_name);
  857. }
  858. Status VarManager::SetAllocatedGraphId(const std::string &var_name, uint32_t graph_id) {
  859. std::lock_guard<std::recursive_mutex> lock(mutex_);
  860. if (var_resource_ == nullptr) {
  861. GELOGW("VarManager has not been init.");
  862. return INTERNAL_ERROR;
  863. }
  864. return var_resource_->SetAllocatedGraphId(var_name, graph_id);
  865. }
  866. Status VarManager::GetAllocatedGraphId(const std::string &var_name, uint32_t &graph_id) {
  867. std::lock_guard<std::recursive_mutex> lock(mutex_);
  868. if (var_resource_ == nullptr) {
  869. GELOGW("VarManager has not been init.");
  870. return INTERNAL_ERROR;
  871. }
  872. return var_resource_->GetAllocatedGraphId(var_name, graph_id);
  873. }
  874. void VarManager::RemoveAllocatedGraphId(const std::string &var_name) {
  875. std::lock_guard<std::recursive_mutex> lock(mutex_);
  876. if (var_resource_ == nullptr) {
  877. GELOGW("VarManager has not been init.");
  878. return;
  879. }
  880. var_resource_->RemoveAllocatedGraphId(var_name);
  881. }
  882. Status VarManager::GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables) {
  883. std::lock_guard<std::recursive_mutex> lock(mutex_);
  884. if (var_resource_ == nullptr) {
  885. GELOGW("VarManager has not been inited.");
  886. return INTERNAL_ERROR;
  887. }
  888. auto new_variable_desc = var_resource_->GetAllVarDesc();
  889. if (new_variable_desc.size() == 0) {
  890. GELOGW("VarManager don't have variables.");
  891. return INTERNAL_ERROR;
  892. }
  893. for (auto iter = new_variable_desc.begin(); iter != new_variable_desc.end(); ++iter) {
  894. auto trans_road = var_resource_->GetTransRoad(iter->first);
  895. if (trans_road == nullptr || trans_road->empty()) {
  896. GELOGI("The variable %s does not have any trans road", iter->first.c_str());
  897. all_variables[iter->first] = iter->second;
  898. continue;
  899. }
  900. // get origin trans info : the first trans node info
  901. auto origin_trans_node_info = trans_road->at(0);
  902. all_variables[iter->first] = origin_trans_node_info.input;
  903. }
  904. return SUCCESS;
  905. }
  906. VarManagerPool::~VarManagerPool() { Destory(); }
  907. VarManagerPool &VarManagerPool::Instance() {
  908. static VarManagerPool var_manager_pool;
  909. return var_manager_pool;
  910. }
  911. void VarManagerPool::Destory() noexcept {
  912. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  913. for (auto &it : var_manager_map_) {
  914. VarManager *var_manager = it.second;
  915. if (var_manager != nullptr) {
  916. var_manager->Destory();
  917. delete var_manager;
  918. var_manager = nullptr;
  919. }
  920. }
  921. var_manager_map_.clear();
  922. }
  923. ge::Status VarManagerPool::Init() const { return SUCCESS; }
  924. VarManager *VarManagerPool::GetVarManager(uint64_t session_id) {
  925. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  926. auto it = var_manager_map_.find(session_id);
  927. if (it != var_manager_map_.end()) {
  928. GELOGD("VarManagerPool::GetVarManager");
  929. return it->second;
  930. }
  931. VarManager *var_manager = new (std::nothrow) VarManager(session_id);
  932. if (var_manager == nullptr) {
  933. REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id);
  934. GELOGE(INTERNAL_ERROR, "[New][VarManager] fail, session_id:%lu", session_id);
  935. static VarManager new_var_manager(0);
  936. return &new_var_manager;
  937. }
  938. var_manager_map_[session_id] = var_manager;
  939. return var_manager;
  940. }
  941. void VarManagerPool::RemoveVarManager(uint64_t session_id) {
  942. VarManager *var_manager = nullptr;
  943. {
  944. std::lock_guard<std::mutex> lock(var_manager_mutex_);
  945. auto it = var_manager_map_.find(session_id);
  946. if (it != var_manager_map_.end()) {
  947. var_manager = it->second;
  948. var_manager_map_.erase(it);
  949. }
  950. }
  951. if (var_manager != nullptr) {
  952. var_manager->Destory();
  953. delete var_manager;
  954. var_manager = nullptr;
  955. }
  956. }
  957. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。