
graph_execute.cc 22 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/execute/graph_execute.h"

#include <memory>
#include <string>

#include "graph/load/model_manager/model_manager.h"
#include "omm/csa_interact.h"

namespace ge {
GraphExecutor::GraphExecutor()
    : init_flag_(false),
      train_graph_flag_(false),
      sync_run_mutex_(nullptr),
      condition_(nullptr),
      graph_run_listener_(nullptr),
      graph_context_(nullptr),
      last_graph_id_(UINT32_MAX),
      malloc_flag_(false) {}
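// Release any host buffers still owned by the executor. rtFreeHost failures
// are only logged here, since a destructor cannot propagate an error status.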
GraphExecutor::~GraphExecutor() {
  outputs_desc_.clear();
  if (malloc_flag_) {
    for (auto &buffer_addr : buffer_addr_) {
      rtError_t rt_ret;
      rt_ret = rtFreeHost(buffer_addr);
      if (rt_ret != RT_ERROR_NONE) {
        GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
      }
    }
  }
  malloc_flag_ = false;
  buffer_addr_.clear();
}
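// Register the mutex/condition pair and listener used to block a synchronous
// run until the asynchronous model execution reports completion.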
Status GraphExecutor::SetCondition(std::mutex *mutex, std::condition_variable *cond,
                                   std::shared_ptr<GraphModelListener> listener) {
  if (mutex == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param mutex is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  if (cond == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param cond is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  if (listener == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param listener is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }

  sync_run_mutex_ = mutex;
  condition_ = cond;
  graph_run_listener_ = listener;
  init_flag_ = true;
  return SUCCESS;
}

Status GraphExecutor::SetGraphContext(GraphContextPtr graph_context_ptr) {
  if (graph_context_ptr == nullptr) {
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetGraphContext] input param graph_context_ptr is nullptr");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  graph_context_ = graph_context_ptr;
  return SUCCESS;
}

Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector<uint64_t> &batch_num, int32_t dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->SetDynamicSize(model_id, batch_num, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "SetDynamicSize failed");
    return ret;
  }
  return SUCCESS;
}

void GraphExecutor::SetTrainFlag(bool is_train_graph) { train_graph_flag_ = is_train_graph; }
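// Free the cached host buffers allocated by MallocInOutBuffer. On the first
// rtFreeHost failure, the entries that were already freed are erased and an
// error is returned, leaving the remaining buffers for a later retry.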
Status GraphExecutor::FreeInOutBuffer() {
  if (malloc_flag_) {
    for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) {
      rtError_t rt_ret;
      rt_ret = rtFreeHost(*iter);
      if (rt_ret != RT_ERROR_NONE) {
        GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
        (void)buffer_addr_.erase(buffer_addr_.begin(), iter);
        return GE_GRAPH_FREE_FAILED;
      }
    }
    buffer_addr_.clear();
    malloc_flag_ = false;
    return SUCCESS;
  } else {
    GELOGD("[GraphManager] not malloc buffer.");
    return SUCCESS;
  }
}
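// Allocate host buffers for the requested sizes. If the previous allocation
// used exactly the same sizes, the cached buffers are reused; otherwise the
// old buffers are released and new ones are allocated with rtMallocHost.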
Status GraphExecutor::MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr) {
  if (malloc_flag_) {
    auto all_size_same = true;
    if (buffer_size.size() == buffer_size_.size()) {
      for (size_t i = 0; i < buffer_size.size(); i++) {
        if (buffer_size[i] != buffer_size_[i]) {
          all_size_same = false;
          break;
        }
      }
    } else {
      all_size_same = false;
    }
    if (all_size_same) {
      data_addr = buffer_addr_;
      return SUCCESS;
    }
    buffer_size_.clear();
    auto rt_ret = FreeInOutBuffer();
    if (rt_ret != SUCCESS) {
      GELOGE(RT_FAILED, "[SubGraphInfo] MallocInOutBuffer free buffer failed, ret: 0x%X", rt_ret);
      return RT_FAILED;
    }
  }

  rtError_t rt_ret;
  for (size_t i = 0; i < buffer_size.size(); ++i) {
    void *tmp_buf = nullptr;
    rt_ret = rtMallocHost(&tmp_buf, buffer_size[i]);
    if (rt_ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "[GraphManager] subgraph malloc buffer failed, ret: 0x%X", rt_ret);
      return GE_GRAPH_MALLOC_FAILED;
    }
    malloc_flag_ = true;
    data_addr.push_back(tmp_buf);
    buffer_addr_.push_back(tmp_buf);
  }
  buffer_size_ = buffer_size;
  return SUCCESS;
}
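// Stage the input tensors into host buffers and build the InputData/OutputData
// blob lists consumed by DataInput. Input blobs are copied host-to-host via
// rtMemcpy; output blobs only reserve space for the results.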
Status GraphExecutor::PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data,
                                       OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc) {
  // Preprocessing input data
  graph_input_data.index = 0;
  graph_input_data.timeout = 0;
  graph_input_data.timestamp = 0;
  std::size_t inputSize = input_tensor.size();
  std::size_t output_size = output_desc.size();
  std::vector<uint64_t> bufferSizeVec;
  std::vector<void *> addrVec;

  for (std::size_t i = 0; i < inputSize; ++i) {
    const GeTensor *InTensor = &input_tensor[i];
    GE_CHECK_NOTNULL(InTensor);
    bufferSizeVec.push_back(InTensor->GetData().size());
  }

  for (const auto &desc : output_desc) {
    bufferSizeVec.push_back(desc.size);
  }

  Status ret = MallocInOutBuffer(bufferSizeVec, addrVec);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_MALLOC_FAILED, "[GraphExecutor] Malloc mem failed");
    return GE_GRAPH_MALLOC_FAILED;
  }

  for (std::size_t i = 0; i < input_tensor.size() && i < addrVec.size(); ++i) {
    const GeTensor *in_tensor = &input_tensor[i];
    GE_CHECK_NOTNULL(in_tensor);
    if ((addrVec[i] != nullptr) && (in_tensor->GetData().data() != nullptr)) {
      rtError_t rt_ret = rtMemcpy(addrVec[i], bufferSizeVec[i], in_tensor->GetData().data(),
                                  in_tensor->GetData().size(), RT_MEMCPY_HOST_TO_HOST);
      if (rt_ret != RT_ERROR_NONE) {
        GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
        return RT_FAILED;
      }
    }
    DataBuffer in_data_buf;
    in_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[i]);
    in_data_buf.length = in_tensor->GetData().size();
    in_data_buf.isDataSupportMemShare = false;
    graph_input_data.blobs.push_back(in_data_buf);
  }

  graph_output_data.index = 0;
  for (std::size_t j = 0; j < output_size; j++) {
    auto desc = output_desc[j];
    uint64_t buffer_size = desc.size;

    DataBuffer out_data_buf;
    out_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[inputSize + j]);
    out_data_buf.length = buffer_size;
    out_data_buf.isDataSupportMemShare = false;
    graph_output_data.blobs.push_back(out_data_buf);
  }
  return SUCCESS;
}
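// Run one synchronous execution: query the model's I/O descriptions, stage
// the inputs, push them to the model via DataInput, block on the registered
// condition variable until the run listener reports a result code, then copy
// each output blob into a GeTensor for the caller.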
Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
                                       std::vector<GeTensor> &output_tensor) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  if (model_manager->IsDynamicShape(model_id)) {
    GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id);
    return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor);
  }

  // Prepare input and output
  std::vector<InputOutputDescInfo> inputs_desc;
  std::vector<InputOutputDescInfo> output_desc;
  GELOGI("[ExecuteGraph] GetInputOutputDescInfo via new ome begin.");
  Status ret = GetInputOutputDescInfo(model_id, inputs_desc, output_desc);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_GET_IN_OUT_FAILED, "[GraphExecutor] GetInputOutputDescInfo failed, modelId=%u.", model_id);
    return GE_GRAPH_GET_IN_OUT_FAILED;
  }
  outputs_desc_.assign(output_desc.begin(), output_desc.end());

  InputData input_data;
  OutputData output_data;
  input_data.model_id = model_id;
  ret = PrepareInputData(input_tensor, input_data, output_data, output_desc);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_PREPARE_FAILED, "[GraphExecutor] PrepareInputData failed, modelId=%u.", model_id);
    return GE_GRAPH_PREPARE_FAILED;
  }

  if (graph_run_listener_->ResetResult() != SUCCESS) {
    GELOGE(GE_GRAPH_EXECUTE_FAILED, "Reset result failed");
    return GE_GRAPH_EXECUTE_FAILED;
  }

  // Run mode async
  GELOGI("[ExecuteGraph] DataInput via new ome begin.");
  ret = DataInput(input_data, output_data);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_DATA_INPUT_FAILED, "[GraphExecutor] push data failed, modelId=%u.", model_id);
    return GE_GRAPH_DATA_INPUT_FAILED;
  }
  GELOGI("[GraphExecutor] input data push to wrapper finish, waiting for result...");

  // Pending until async execute graph complete
  {
    std::unique_lock<std::mutex> ulock(*sync_run_mutex_);
    if (!graph_run_listener_->IsFinished()) {
      (*condition_).wait(ulock);
    }

    // Run graph return
    uint32_t result_code = graph_run_listener_->GetResultCode();
    if (result_code != SUCCESS && result_code != END_OF_SEQUENCE) {
      GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code,
             model_id);
      return GE_GRAPH_EXECUTE_FAILED;
    }
  }

  for (size_t i = 0; i < output_data.blobs.size(); ++i) {
    DataBuffer outputDataTmp = output_data.blobs[i];
    CHECK_FALSE_EXEC(outputDataTmp.length != 0,
                     GELOGE(GE_GRAPH_EXECUTE_FAILED, "Output data length is 0, cannot copy output.");
                     return GE_GRAPH_EXECUTE_FAILED);
    std::unique_ptr<uint8_t[]> outBufTmp(new (std::nothrow) uint8_t[outputDataTmp.length]);
    if (outBufTmp == nullptr) {
      GELOGE(FAILED, "Failed to allocate memory.");
      return FAILED;
    }
    GE_PRINT_DYNAMIC_MEMORY(new, "the output memory of data on training.", sizeof(uint8_t) * outputDataTmp.length)
    rtError_t ret_value = rtMemcpy(outBufTmp.get(), outputDataTmp.length, outputDataTmp.data, outputDataTmp.length,
                                   RT_MEMCPY_HOST_TO_HOST);
    CHECK_FALSE_EXEC(ret_value == RT_ERROR_NONE,
                     GELOGE(GE_GRAPH_EXECUTE_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret_value);
                     return GE_GRAPH_EXECUTE_FAILED);

    GeTensor outTensor;
    std::vector<int64_t> shapeDims;
    for (const auto &dim : output_desc[i].shape_info.dims) {
      shapeDims.push_back(dim);
    }

    GeShape outShape(shapeDims);
    outTensor.MutableTensorDesc().SetShape(outShape);
    outTensor.MutableTensorDesc().SetDataType((DataType)output_desc[i].data_type);
    (void)outTensor.SetData(outBufTmp.get(), outputDataTmp.length);
    output_tensor.push_back(outTensor);
  }

  GELOGI("[GraphExecutor] execute model success, modelId=%u.", model_id);
  return SUCCESS;
}
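// For each graph output, record the model id of every subgraph that produces
// that output, preserving the output order.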
void GraphExecutor::InitModelIdInfo(std::vector<uint32_t> &out_model_id_info,
                                    std::vector<SubGraphInfoPtr> &sub_graph_vec, uint32_t output_size) {
  for (uint32_t i = 0; i < output_size; i++) {
    for (size_t j = 0; j < sub_graph_vec.size(); j++) {
      if (sub_graph_vec[j]->GetOutputFlag().size() == output_size && sub_graph_vec[j]->GetOutputFlag().at(i)) {
        out_model_id_info.push_back(sub_graph_vec[j]->GetModelIdInfo().model_id);
      }
    }
  }
}

Status GraphExecutor::FreeExecuteMemory() {
  auto ret = FreeInOutBuffer();
  if (ret != SUCCESS) {
    GELOGE(ret, "[FreeExecuteMemory] FreeInOutBuffer Error!");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model,
                                   const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor) {
  if (graph_id != last_graph_id_) {
    auto ret = FreeExecuteMemory();
    if (ret != SUCCESS) {
      return ret;
    }
  }
  last_graph_id_ = graph_id;

  if (!init_flag_) {
    GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!");
    return GE_GRAPH_EXECUTE_NOT_INIT;
  }
  GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  Status ret = SyncExecuteModel(ge_root_model->GetModelId(), input_tensor, output_tensor);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!");
    return GE_GRAPH_SYNC_MODEL_FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
                                        const std::vector<InputTensorInfo> &input_tensor) {
  GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id);
  if (graph_id != last_graph_id_) {
    auto ret = FreeExecuteMemory();
    if (ret != SUCCESS) {
      return ret;
    }
  }
  last_graph_id_ = graph_id;
  GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!");
    return GE_GRAPH_SYNC_MODEL_FAILED;
  }

  GELOGI("[GraphExecutor] Async execute graph success, graph_id=%u", graph_id);
  return SUCCESS;
}
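// Queue the input tensors to the model via ModelManager::DataInputTensor.
// Exceptions are converted into error status codes, and failures are reported
// through CsaInteract before returning.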
Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    GELOGI("RunAsync begin. model_id %u", model_id);
    Status ret = model_manager->DataInputTensor(model_id, inputs);
    if (ret != SUCCESS) {
      GELOGE(ret, "RunAsync: DataInput failed");
      return ret;
    }
    GELOGI("RunAsync success.");
  } catch (std::bad_alloc &) {
    GELOGE(MEMALLOC_FAILED, "RunAsync failed, bad memory allocation occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return MEMALLOC_FAILED;
  } catch (...) {
    GELOGE(FAILED, "RunAsync failed, an unexpected exception occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return FAILED;
  }
  return SUCCESS;
}
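// Forward one batch of input/output blobs to ModelManager::DataInput,
// converting any exception into an error status and reporting the failure
// code through CsaInteract before returning.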
Status GraphExecutor::DataInput(const InputData &input_data, OutputData &output_data) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->DataInput(input_data, output_data);
    if (ret != SUCCESS) {
      GELOGE(ret, "DataInput: DataInput failed.");
      CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
      return ret;
    }
  } catch (std::bad_alloc &) {
    GELOGE(MEMALLOC_FAILED, "DataInput failed, bad memory allocation occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return MEMALLOC_FAILED;
  } catch (...) {
    GELOGE(FAILED, "DataInput failed, an unexpected exception occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
                                             vector<InputOutputDescInfo> &output_desc) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc);
    if (ret != SUCCESS) {
      GELOGE(ret, "GetInputOutputDescInfo failed.");
      CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
      return ret;
    }
  } catch (std::bad_alloc &) {
    GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return MEMALLOC_FAILED;
  } catch (...) {
    GELOGE(FAILED, "GetInputOutputDescInfo failed, an unexpected exception occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
                                             vector<InputOutputDescInfo> &output_desc,
                                             std::vector<uint32_t> &input_formats, std::vector<uint32_t> &out_formats,
                                             bool new_model_desc) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc, input_formats, out_formats,
                                                       new_model_desc);
    if (ret != SUCCESS) {
      GELOGE(ret, "GetInputOutputDescInfo failed.");
      CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
      return ret;
    }
  } catch (std::bad_alloc &) {
    GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return MEMALLOC_FAILED;
  } catch (...) {
    GELOGE(FAILED, "GetInputOutputDescInfo failed, an unexpected exception occurred!");
    CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
    return FAILED;
  }

  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Get dynamic batch_info
/// @param [in] model_id
/// @param [out] batch_info
/// @param [out] dynamic_type
/// @return execute result
///
Status GraphExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info,
                                          int32_t &dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetDynamicBatchInfo(model_id, batch_info, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetDynamicBatchInfo failed.");
    return ret;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Get combined dynamic dims info
/// @param [in] model_id
/// @param [out] batch_info
/// @return execute result
///
Status GraphExecutor::GetCombinedDynamicDims(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetCombinedDynamicDims(model_id, batch_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetCombinedDynamicDims failed.");
    return ret;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Get user designated shape order
/// @param [in] model_id
/// @param [out] user_input_shape_order
/// @return execute result
///
ge::Status GraphExecutor::GetUserDesignateShapeOrder(uint32_t model_id,
                                                     std::vector<std::string> &user_input_shape_order) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetUserDesignateShapeOrder(model_id, user_input_shape_order);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetUserDesignateShapeOrder failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetCurShape(model_id, batch_info, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetCurShape failed");
    return ret;
  }
  return SUCCESS;
}
Status GraphExecutor::GetModelAttr(uint32_t model_id, std::vector<string> &dynamic_output_shape_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetModelAttr(model_id, dynamic_output_shape_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetModelAttr failed");
    return ret;
  }
  return SUCCESS;
}
Status GraphExecutor::GetAippInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAippInfo(model_id, index, aipp_info);
  if (ret != SUCCESS) {
    GELOGW("GetAippInfo failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAippType(model_id, index, type, aipp_index);
  if (ret != SUCCESS) {
    GELOGW("GetAippType failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetOrigInputInfo(model_id, index, orig_input_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetOrigInputInfo failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index,
                                                std::vector<InputOutputDims> &input_dims,
                                                std::vector<InputOutputDims> &output_dims) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAllAippInputOutputDims(model_id, index, input_dims, output_dims);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetAllAippInputOutputDims failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id,
                                    OpDescInfo &op_desc_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetOpDescInfo(device_id, stream_id, task_id, op_desc_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetOpDescInfo failed.");
    return ret;
  }
  return SUCCESS;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, bridging the two. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training and inference, GE is invoked automatically and is not visible to the user. GE consists of two main parts, the GE API and GE Core; the detailed architecture is shown in the diagram below.
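To make the synchronous path in graph_execute.cc above concrete, here is a minimal usage sketch of GraphExecutor. It is an illustration of the SetCondition/ExecuteGraph call order only, not a verified snippet from the GE codebase: graph_id and ge_root_model are placeholders produced elsewhere by graph construction and model load, and the GraphModelListener constructor taking the mutex and condition variable is an assumption.

    // Sketch only: wiring GraphExecutor for one synchronous run.
    void RunGraphOnce(ge::GraphId graph_id, const ge::GeRootModelPtr &ge_root_model,
                      const std::vector<ge::GeTensor> &inputs, std::vector<ge::GeTensor> &outputs) {
      static std::mutex sync_mutex;
      static std::condition_variable cond;
      // Listener that signals `cond` when the asynchronous run finishes (assumed constructor).
      auto listener = std::make_shared<ge::GraphModelListener>(sync_mutex, cond);

      ge::GraphExecutor executor;
      // SetCondition must come first: ExecuteGraph returns GE_GRAPH_EXECUTE_NOT_INIT otherwise
      // (see the init_flag_ check in ExecuteGraph above).
      if (executor.SetCondition(&sync_mutex, &cond, listener) != ge::SUCCESS) {
        return;
      }
      // Blocks until the run listener reports a result code, then fills `outputs`.
      (void)executor.ExecuteGraph(graph_id, ge_root_model, inputs, outputs);
    }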