You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

session_manager.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_manager.h"
  17. #include <memory>
  18. #include <utility>
  19. #include "common/ge/ge_util.h"
  20. #include "framework/common/debug/ge_log.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/load/new_model_manager/model_manager.h"
  23. #include "graph/manager/util/rt_context_util.h"
  24. using std::map;
  25. using std::string;
  26. using std::vector;
  27. namespace ge {
  28. Status SessionManager::Initialize(const std::map<std::string, std::string> &options) {
  29. if (init_flag_) {
  30. GELOGW("Session Manager has been initialized.");
  31. return SUCCESS;
  32. }
  33. init_flag_ = true;
  34. return SUCCESS;
  35. }
  36. Status SessionManager::Finalize() {
  37. if (!init_flag_) {
  38. GELOGW("Session Manager has not been initialized.");
  39. return SUCCESS;
  40. }
  41. std::lock_guard<std::mutex> lock(mutex_);
  42. for (auto iter = session_manager_map_.begin(); iter != session_manager_map_.end(); ++iter) {
  43. (void)iter->second->Finalize();
  44. }
  45. session_manager_map_.clear();
  46. init_flag_ = false;
  47. return SUCCESS;
  48. }
  49. Status SessionManager::SetRtContext(SessionId session_id, rtContext_t rt_context) {
  50. GELOGI("set rt_context RT_CTX_NORMAL_MODE, device id:%u.", GetContext().DeviceId());
  51. GE_CHK_RT_RET(rtCtxCreate(&rt_context, RT_CTX_NORMAL_MODE, static_cast<int32_t>(GetContext().DeviceId())));
  52. GE_CHK_RT_RET(rtCtxSetCurrent(rt_context));
  53. RtContextUtil::GetInstance().AddRtContext(session_id, rt_context);
  54. return SUCCESS;
  55. }
  56. Status SessionManager::CreateSession(const std::map<std::string, std::string> &options, SessionId &session_id) {
  57. if (!init_flag_) {
  58. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  59. return GE_SESSION_MANAGER_NOT_INIT;
  60. }
  61. SessionId next_session_id = 0;
  62. std::lock_guard<std::mutex> lock(mutex_);
  63. Status nextSessionIdRet = GetNextSessionId(next_session_id);
  64. if (nextSessionIdRet != SUCCESS) {
  65. return nextSessionIdRet;
  66. }
  67. SessionPtr sessionPtr = MakeShared<InnerSession>(next_session_id, options);
  68. if (sessionPtr == nullptr) {
  69. return MEMALLOC_FAILED;
  70. }
  71. Status ret = sessionPtr->Initialize();
  72. if (ret != SUCCESS) {
  73. return ret;
  74. }
  75. (void)session_manager_map_.emplace(std::pair<SessionId, SessionPtr>(next_session_id, sessionPtr));
  76. session_id = next_session_id;
  77. // create a context
  78. ret = SetRtContext(session_id, rtContext_t());
  79. return ret;
  80. }
  81. Status SessionManager::DestroySession(SessionId session_id) {
  82. if (!init_flag_) {
  83. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  84. return GE_SESSION_MANAGER_NOT_INIT;
  85. }
  86. std::lock_guard<std::mutex> lock(mutex_);
  87. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  88. if (it == session_manager_map_.end()) {
  89. return GE_SESSION_NOT_EXIST;
  90. }
  91. if (ModelManager::GetInstance() != nullptr) {
  92. ModelManager::GetInstance()->DestroyAicpuSession(session_id);
  93. }
  94. // Unified destruct rt_context
  95. RtContextUtil::GetInstance().DestroyRtContexts(session_id);
  96. SessionPtr innerSession = it->second;
  97. Status ret = innerSession->Finalize();
  98. if (ret != SUCCESS) {
  99. return ret;
  100. }
  101. (void)session_manager_map_.erase(session_id);
  102. return ret;
  103. }
  104. Status SessionManager::GetVariable(SessionId session_id, const std::string &name, Tensor &val) {
  105. if (!init_flag_) {
  106. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  107. return GE_SESSION_MANAGER_NOT_INIT;
  108. }
  109. SessionPtr innerSession = nullptr;
  110. {
  111. std::lock_guard<std::mutex> lock(mutex_);
  112. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  113. if (it == session_manager_map_.end()) {
  114. return GE_SESSION_NOT_EXIST;
  115. } else {
  116. innerSession = it->second;
  117. }
  118. }
  119. return innerSession->GetVariable(name, val);
  120. }
  121. Status SessionManager::AddGraph(SessionId session_id, uint32_t graph_id, const Graph &graph) {
  122. std::map<std::string, std::string> options;
  123. return AddGraph(session_id, graph_id, graph, options);
  124. }
  125. Status SessionManager::AddGraph(SessionId session_id, uint32_t graph_id, const Graph &graph,
  126. const std::map<std::string, std::string> &options) {
  127. if (!init_flag_) {
  128. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  129. return GE_SESSION_MANAGER_NOT_INIT;
  130. }
  131. SessionPtr innerSession = nullptr;
  132. {
  133. std::lock_guard<std::mutex> lock(mutex_);
  134. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  135. if (it == session_manager_map_.end()) {
  136. return GE_SESSION_NOT_EXIST;
  137. } else {
  138. innerSession = it->second;
  139. }
  140. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  141. GE_CHECK_NOTNULL(compute_graph);
  142. std::string session_graph_id = std::to_string(session_id) + "_" + std::to_string(graph_id);
  143. if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
  144. GELOGW("Set graph session_graph_id attr failed.");
  145. } else {
  146. GELOGD("Set graph session_graph_id attr to [%s]", session_graph_id.c_str());
  147. }
  148. for (auto graph : compute_graph->GetAllSubgraphs()) {
  149. AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
  150. }
  151. }
  152. return innerSession->AddGraph(graph_id, graph, options);
  153. }
  154. Status SessionManager::AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph,
  155. const std::map<std::string, std::string> &options) {
  156. if (!init_flag_) {
  157. GELOGE(GE_SESSION_MANAGER_NOT_INIT);
  158. return GE_SESSION_MANAGER_NOT_INIT;
  159. }
  160. SessionPtr innerSession = nullptr;
  161. {
  162. std::lock_guard<std::mutex> lock(mutex_);
  163. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  164. if (it == session_manager_map_.end()) {
  165. return GE_SESSION_NOT_EXIST;
  166. } else {
  167. innerSession = it->second;
  168. }
  169. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  170. GE_CHECK_NOTNULL(compute_graph);
  171. std::string session_graph_id = std::to_string(session_id) + "_" + std::to_string(graph_id);
  172. if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) {
  173. GELOGW("Set graph session_graph_id attr failed.");
  174. } else {
  175. GELOGD("Set graph session_graph_id attr to [%s]", session_graph_id.c_str());
  176. }
  177. for (auto graph : compute_graph->GetAllSubgraphs()) {
  178. AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id);
  179. }
  180. }
  181. return innerSession->AddGraphWithCopy(graph_id, graph, options);
  182. }
  183. Status SessionManager::RunGraph(SessionId session_id, uint32_t graph_id, const std::vector<Tensor> &inputs,
  184. std::vector<Tensor> &outputs) {
  185. if (!init_flag_) {
  186. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  187. return GE_SESSION_MANAGER_NOT_INIT;
  188. }
  189. SessionPtr innerSession = nullptr;
  190. {
  191. std::lock_guard<std::mutex> lock(mutex_);
  192. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  193. if (it == session_manager_map_.end()) {
  194. return GE_SESSION_NOT_EXIST;
  195. } else {
  196. innerSession = it->second;
  197. }
  198. }
  199. return innerSession->RunGraph(graph_id, inputs, outputs);
  200. }
  201. Status SessionManager::RemoveGraph(SessionId session_id, uint32_t graph_id) {
  202. if (!init_flag_) {
  203. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  204. return GE_SESSION_MANAGER_NOT_INIT;
  205. }
  206. SessionPtr innerSession = nullptr;
  207. {
  208. std::lock_guard<std::mutex> lock(mutex_);
  209. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  210. if (it == session_manager_map_.end()) {
  211. return GE_SESSION_NOT_EXIST;
  212. } else {
  213. innerSession = it->second;
  214. }
  215. }
  216. return innerSession->RemoveGraph(graph_id);
  217. }
  218. bool SessionManager::HasSession(SessionId session_id) {
  219. if (!init_flag_) {
  220. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  221. return false;
  222. }
  223. return session_manager_map_.find(session_id) != session_manager_map_.end();
  224. }
  225. Status SessionManager::GetNextSessionId(SessionId &next_session_id) {
  226. if (!init_flag_) {
  227. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  228. return GE_SESSION_MANAGER_NOT_INIT;
  229. }
  230. static SessionId session_id = 0;
  231. next_session_id = session_id++;
  232. return SUCCESS;
  233. }
  234. Status SessionManager::RegisterCallBackFunc(
  235. SessionId session_id, const std::string &key,
  236. const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback) {
  237. if (!init_flag_) {
  238. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  239. return GE_SESSION_MANAGER_NOT_INIT;
  240. }
  241. SessionPtr innerSession = nullptr;
  242. {
  243. std::lock_guard<std::mutex> lock(mutex_);
  244. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  245. if (it == session_manager_map_.end()) {
  246. return GE_SESSION_NOT_EXIST;
  247. } else {
  248. innerSession = it->second;
  249. }
  250. }
  251. return innerSession->RegisterCallBackFunc(key, callback);
  252. }
  253. Status SessionManager::BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs) {
  254. if (!init_flag_) {
  255. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  256. return GE_SESSION_MANAGER_NOT_INIT;
  257. }
  258. SessionPtr innerSession = nullptr;
  259. {
  260. std::lock_guard<std::mutex> lock(mutex_);
  261. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  262. if (it == session_manager_map_.end()) {
  263. return GE_SESSION_NOT_EXIST;
  264. } else {
  265. innerSession = it->second;
  266. }
  267. }
  268. return innerSession->BuildGraph(graph_id, inputs);
  269. }
  270. Status SessionManager::RunGraphAsync(SessionId session_id, uint32_t graph_id,
  271. const std::vector<InputTensorInfo> &inputs, RunAsyncCallback callback) {
  272. if (!init_flag_) {
  273. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  274. return GE_SESSION_MANAGER_NOT_INIT;
  275. }
  276. SessionPtr innerSession = nullptr;
  277. {
  278. std::lock_guard<std::mutex> lock(mutex_);
  279. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  280. if (it == session_manager_map_.end()) {
  281. return GE_SESSION_NOT_EXIST;
  282. } else {
  283. innerSession = it->second;
  284. }
  285. }
  286. return innerSession->RunGraphAsync(graph_id, inputs, callback);
  287. }
  288. Status SessionManager::GetVariables(SessionId session_id, const std::vector<std::string> &var_names,
  289. std::vector<Tensor> &var_values) {
  290. // step 0: init session manager
  291. if (!init_flag_) {
  292. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  293. return GE_SESSION_MANAGER_NOT_INIT;
  294. }
  295. SessionPtr innerSession = nullptr;
  296. {
  297. std::lock_guard<std::mutex> lock(mutex_);
  298. std::map<SessionId, SessionPtr>::iterator it = session_manager_map_.find(session_id);
  299. if (it == session_manager_map_.end()) {
  300. return GE_SESSION_NOT_EXIST;
  301. } else {
  302. innerSession = it->second;
  303. }
  304. }
  305. // step 1: get all variable
  306. std::map<std::string, GeTensorDesc> all_variables;
  307. Status ret = innerSession->GetAllVariables(all_variables);
  308. if (ret != SUCCESS) {
  309. GELOGE(FAILED, "Get all variables failed.");
  310. return FAILED;
  311. }
  312. // srep 2: create check point graph
  313. Graph graph = Graph("checkpoint");
  314. ret = innerSession->GenCheckPointGraph(all_variables, graph);
  315. if (ret != SUCCESS) {
  316. GELOGE(FAILED, "Build check point graph failed.");
  317. return FAILED;
  318. }
  319. // step 3: run check point graph
  320. uint32_t graph_id = GetCurrentSecondTimestap();
  321. ret = AddGraph(session_id, graph_id, graph);
  322. if (ret != SUCCESS) {
  323. GELOGE(FAILED, "Add check point graph failed.");
  324. return FAILED;
  325. }
  326. vector<Tensor> inputs;
  327. vector<Tensor> outputs;
  328. ret = RunGraph(session_id, graph_id, inputs, outputs);
  329. if (ret != SUCCESS) {
  330. GELOGE(FAILED, "Run check point graph failed.");
  331. return FAILED;
  332. }
  333. // step 4: save variables
  334. ret = innerSession->SaveVariables(graph, var_names, outputs, var_values);
  335. GELOGD("[SessionManager] outputs size is [%zu], var values size is [%zu].", outputs.size(), var_values.size());
  336. if (ret != SUCCESS) {
  337. GELOGE(FAILED, "Save variables failed.");
  338. return FAILED;
  339. }
  340. // step 5: remove graph
  341. ret = innerSession->RemoveGraph(graph_id);
  342. if (ret != SUCCESS) {
  343. GELOGE(FAILED, "Remove graph failed.");
  344. return FAILED;
  345. }
  346. return ret;
  347. }
  348. bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) {
  349. if (!init_flag_) {
  350. GELOGE(GE_SESSION_MANAGER_NOT_INIT, "Session manager is not initialized.");
  351. return true;
  352. }
  353. SessionPtr innerSession = nullptr;
  354. {
  355. std::lock_guard<std::mutex> lock(mutex_);
  356. auto it = session_manager_map_.find(session_id);
  357. if (it == session_manager_map_.end()) {
  358. GELOGE(GE_SESSION_NOT_EXIST, "The session %lu does not exists", session_id);
  359. return true;
  360. } else {
  361. innerSession = it->second;
  362. }
  363. }
  364. return innerSession->IsGraphNeedRebuild(graph_id);
  365. }
  366. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: