You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_execute.cc 25 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/execute/graph_execute.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#include "omm/csa_interact.h"
  22. namespace ge {
  23. using Uint32Pair = pair<uint32_t, uint32_t>;
  24. const uint32_t kInvalidModelId = UINT32_MAX;
  25. GraphExecutor::GraphExecutor()
  26. : init_flag_(false),
  27. train_graph_flag_(false),
  28. sync_run_mutex_(nullptr),
  29. condition_(nullptr),
  30. graph_run_listener_(nullptr),
  31. graph_context_(nullptr),
  32. last_graph_id_(UINT32_MAX),
  33. malloc_flag_(false) {}
  34. GraphExecutor::~GraphExecutor() {
  35. outputs_desc_.clear();
  36. if (malloc_flag_) {
  37. for (auto &buffer_addr : buffer_addr_) {
  38. rtError_t rt_ret;
  39. rt_ret = rtFreeHost(buffer_addr);
  40. if (rt_ret != RT_ERROR_NONE) {
  41. GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
  42. }
  43. }
  44. }
  45. malloc_flag_ = false;
  46. buffer_addr_.clear();
  47. }
  48. Status GraphExecutor::SetCondition(std::mutex *mutex, std::condition_variable *cond,
  49. std::shared_ptr<GraphModelListener> listener) {
  50. if (mutex == nullptr) {
  51. GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param mutex is nullptr.");
  52. return GE_GRAPH_PARAM_NULLPTR;
  53. }
  54. if (cond == nullptr) {
  55. GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param cond is nullptr.");
  56. return GE_GRAPH_PARAM_NULLPTR;
  57. }
  58. if (listener == nullptr) {
  59. GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param listener is nullptr.");
  60. return GE_GRAPH_PARAM_NULLPTR;
  61. }
  62. sync_run_mutex_ = mutex;
  63. condition_ = cond;
  64. graph_run_listener_ = listener;
  65. init_flag_ = true;
  66. return SUCCESS;
  67. }
  68. Status GraphExecutor::SetGraphContext(GraphContextPtr graph_context_ptr) {
  69. if (graph_context_ptr == nullptr) {
  70. GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetGraphContext] input param graph_context_ptr is nullptr");
  71. return GE_GRAPH_PARAM_NULLPTR;
  72. }
  73. graph_context_ = graph_context_ptr;
  74. return SUCCESS;
  75. }
  76. Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector<uint64_t> &batch_num, int32_t dynamic_type) {
  77. auto model_manager = ge::ModelManager::GetInstance();
  78. GE_CHECK_NOTNULL(model_manager);
  79. Status ret = model_manager->SetDynamicSize(model_id, batch_num, dynamic_type);
  80. if (ret != SUCCESS) {
  81. GELOGE(ret, "SetDynamicSize failed");
  82. return ret;
  83. }
  84. return SUCCESS;
  85. }
  86. void GraphExecutor::SetTrainFlag(bool is_train_graph) { train_graph_flag_ = is_train_graph; }
  87. Status GraphExecutor::FreeInOutBuffer() {
  88. if (malloc_flag_) {
  89. for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) {
  90. rtError_t rt_ret;
  91. rt_ret = rtFreeHost(*iter);
  92. if (rt_ret != RT_ERROR_NONE) {
  93. GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
  94. (void)buffer_addr_.erase(buffer_addr_.begin(), iter);
  95. return GE_GRAPH_FREE_FAILED;
  96. }
  97. }
  98. buffer_addr_.clear();
  99. malloc_flag_ = false;
  100. return SUCCESS;
  101. } else {
  102. GELOGD("[GraphManager] not malloc buffer.");
  103. return SUCCESS;
  104. }
  105. }
  106. Status GraphExecutor::MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr) {
  107. if (malloc_flag_) {
  108. auto all_size_same = true;
  109. if (buffer_size.size() == buffer_size_.size()) {
  110. for (size_t i = 0; i < buffer_size.size(); i++) {
  111. if (buffer_size[i] != buffer_size_[i]) {
  112. all_size_same = false;
  113. break;
  114. }
  115. }
  116. } else {
  117. all_size_same = false;
  118. }
  119. if (all_size_same) {
  120. data_addr = buffer_addr_;
  121. return SUCCESS;
  122. }
  123. buffer_size_.clear();
  124. auto rt_ret = FreeInOutBuffer();
  125. if (rt_ret != SUCCESS) {
  126. GELOGE(RT_FAILED, "[SubGraphInfo] MallocInOutBuffer free buffer failed, ret: 0x%X", rt_ret);
  127. return RT_FAILED;
  128. }
  129. }
  130. rtError_t rt_ret;
  131. for (size_t i = 0; i < buffer_size.size(); ++i) {
  132. void *tmp_buf = nullptr;
  133. rt_ret = rtMallocHost(&tmp_buf, buffer_size[i]);
  134. if (rt_ret != RT_ERROR_NONE) {
  135. GELOGE(RT_FAILED, "[GraphManager] subgraph malloc buffer failed, ret: 0x%X", rt_ret);
  136. return GE_GRAPH_MALLOC_FAILED;
  137. }
  138. malloc_flag_ = true;
  139. data_addr.push_back(tmp_buf);
  140. buffer_addr_.push_back(tmp_buf);
  141. }
  142. buffer_size_ = buffer_size;
  143. return SUCCESS;
  144. }
  145. Status GraphExecutor::PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data,
  146. OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc) {
  147. // Preprocessing input data
  148. graph_input_data.index = 0;
  149. graph_input_data.timeout = 0;
  150. graph_input_data.timestamp = 0;
  151. std::size_t inputSize = input_tensor.size();
  152. std::size_t output_size = output_desc.size();
  153. std::vector<uint64_t> bufferSizeVec;
  154. std::vector<void *> addrVec;
  155. for (std::size_t i = 0; i < inputSize; ++i) {
  156. const GeTensor *InTensor = &input_tensor[i];
  157. GE_CHECK_NOTNULL(InTensor);
  158. bufferSizeVec.push_back(InTensor->GetData().size());
  159. }
  160. for (const auto &desc : output_desc) {
  161. bufferSizeVec.push_back(desc.size);
  162. }
  163. Status ret = MallocInOutBuffer(bufferSizeVec, addrVec);
  164. if (ret != SUCCESS) {
  165. GELOGE(GE_GRAPH_MALLOC_FAILED, "[GraphExecutor] Malloc mem failed");
  166. return GE_GRAPH_MALLOC_FAILED;
  167. }
  168. for (std::size_t i = 0; i < input_tensor.size() && i < addrVec.size(); ++i) {
  169. const GeTensor *in_tensor = &input_tensor[i];
  170. GE_CHECK_NOTNULL(in_tensor);
  171. if ((addrVec[i] != nullptr) && (in_tensor->GetData().data() != nullptr)) {
  172. rtError_t rt_ret = rtMemcpy(addrVec[i], bufferSizeVec[i], in_tensor->GetData().data(),
  173. in_tensor->GetData().size(), RT_MEMCPY_HOST_TO_HOST);
  174. if (rt_ret != RT_ERROR_NONE) {
  175. GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
  176. return RT_FAILED;
  177. }
  178. }
  179. DataBuffer in_data_buf;
  180. in_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[i]);
  181. in_data_buf.length = in_tensor->GetData().size();
  182. in_data_buf.isDataSupportMemShare = false;
  183. graph_input_data.blobs.push_back(in_data_buf);
  184. }
  185. graph_output_data.index = 0;
  186. for (std::size_t j = 0; j < output_size; j++) {
  187. auto desc = output_desc[j];
  188. uint64_t buffer_size = desc.size;
  189. DataBuffer out_data_buf;
  190. out_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[inputSize + j]);
  191. out_data_buf.length = buffer_size;
  192. out_data_buf.isDataSupportMemShare = false;
  193. graph_output_data.blobs.push_back(out_data_buf);
  194. }
  195. return SUCCESS;
  196. }
  197. Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
  198. std::vector<GeTensor> &output_tensor) {
  199. auto model_manager = ge::ModelManager::GetInstance();
  200. GE_CHECK_NOTNULL(model_manager);
  201. if (model_manager->IsDynamicShape(model_id)) {
  202. GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id);
  203. return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor);
  204. }
  205. // Prepare input and output
  206. std::vector<InputOutputDescInfo> inputs_desc;
  207. std::vector<InputOutputDescInfo> output_desc;
  208. GELOGI("[ExecuteGraph] GetInputOutputDescInfo via new ome begin.");
  209. Status ret = GetInputOutputDescInfo(model_id, inputs_desc, output_desc);
  210. if (ret != SUCCESS) {
  211. GELOGE(GE_GRAPH_GET_IN_OUT_FAILED, "[GraphExecutor] GetInputOutputDescInfo failed, modelId=%u.", model_id);
  212. return GE_GRAPH_GET_IN_OUT_FAILED;
  213. }
  214. outputs_desc_.assign(output_desc.begin(), output_desc.end());
  215. InputData input_data;
  216. OutputData output_data;
  217. input_data.model_id = model_id;
  218. ret = PrepareInputData(input_tensor, input_data, output_data, output_desc);
  219. if (ret != SUCCESS) {
  220. GELOGE(GE_GRAPH_PREPARE_FAILED, "[GraphExecutor] PrepareInputData failed, modelId=%u.", model_id);
  221. return GE_GRAPH_PREPARE_FAILED;
  222. }
  223. if (graph_run_listener_->ResetResult() != SUCCESS) {
  224. GELOGE(GE_GRAPH_EXECUTE_FAILED, "Reset result failed");
  225. return GE_GRAPH_EXECUTE_FAILED;
  226. }
  227. // Run mode async
  228. GELOGI("[ExecuteGraph] DataInput via new ome begin.");
  229. ret = DataInput(input_data, output_data);
  230. if (ret != SUCCESS) {
  231. GELOGE(GE_GRAPH_DATA_INPUT_FAILED, "[GraphExecutor] push data failed, modelId=%u.", model_id);
  232. return GE_GRAPH_DATA_INPUT_FAILED;
  233. }
  234. GELOGI("[GraphExecutor] input data push to wrapper finish, waiting for result...");
  235. // Pending until async execute graph complete
  236. {
  237. std::unique_lock<std::mutex> ulock(*sync_run_mutex_);
  238. if (!graph_run_listener_->IsFinished()) {
  239. (*condition_).wait(ulock);
  240. }
  241. // Run graph return
  242. uint32_t result_code = graph_run_listener_->GetResultCode();
  243. if (result_code != SUCCESS && result_code != END_OF_SEQUENCE) {
  244. GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code,
  245. model_id);
  246. return GE_GRAPH_EXECUTE_FAILED;
  247. }
  248. }
  249. for (size_t i = 0; i < output_data.blobs.size(); ++i) {
  250. DataBuffer outputDataTmp = output_data.blobs[i];
  251. CHECK_FALSE_EXEC(outputDataTmp.length != 0,
  252. GELOGE(GE_GRAPH_EXECUTE_FAILED, "Failed to allocate memory, length is 0.");
  253. return GE_GRAPH_EXECUTE_FAILED);
  254. std::unique_ptr<uint8_t> outBufTmp(new (std::nothrow) uint8_t[outputDataTmp.length]);
  255. if (outBufTmp == nullptr) {
  256. GELOGE(FAILED, "Failed to allocate memory.");
  257. return FAILED;
  258. }
  259. GE_PRINT_DYNAMIC_MEMORY(new, "the output memory of data on training.", sizeof(uint8_t) * outputDataTmp.length)
  260. rtError_t ret_value = rtMemcpy(outBufTmp.get(), outputDataTmp.length, outputDataTmp.data, outputDataTmp.length,
  261. RT_MEMCPY_HOST_TO_HOST);
  262. CHECK_FALSE_EXEC(ret_value == RT_ERROR_NONE,
  263. GELOGE(GE_GRAPH_EXECUTE_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret);
  264. return GE_GRAPH_EXECUTE_FAILED);
  265. GeTensor outTensor;
  266. std::vector<int64_t> shapeDims;
  267. for (const auto &dim : output_desc[i].shape_info.dims) {
  268. shapeDims.push_back(dim);
  269. }
  270. GeShape outShape(shapeDims);
  271. outTensor.MutableTensorDesc().SetShape(outShape);
  272. outTensor.MutableTensorDesc().SetDataType((DataType)output_desc[i].data_type);
  273. (void)outTensor.SetData(outBufTmp.get(), outputDataTmp.length);
  274. output_tensor.push_back(outTensor);
  275. }
  276. GELOGI("[GraphExecutor] execute model success, modelId=%u.", model_id);
  277. return SUCCESS;
  278. }
  279. void GraphExecutor::InitModelIdInfo(std::vector<uint32_t> &out_model_id_info,
  280. std::vector<SubGraphInfoPtr> &sub_graph_vec, uint32_t output_size) {
  281. for (uint32_t i = 0; i < output_size; i++) {
  282. for (size_t j = 0; j < sub_graph_vec.size(); j++) {
  283. if (sub_graph_vec[j]->GetOutputFlag().size() == output_size && sub_graph_vec[j]->GetOutputFlag().at(i)) {
  284. out_model_id_info.push_back(sub_graph_vec[j]->GetModelIdInfo().model_id);
  285. }
  286. }
  287. }
  288. }
  289. Status GraphExecutor::FreeExecuteMemory() {
  290. auto ret = FreeInOutBuffer();
  291. if (ret != SUCCESS) {
  292. GELOGE(ret, "[FreeExecuteMemory] FreeInOutBuffer Error!");
  293. return ret;
  294. }
  295. return SUCCESS;
  296. }
  297. Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model,
  298. const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor) {
  299. if (graph_id != last_graph_id_) {
  300. auto ret = FreeExecuteMemory();
  301. if (ret != SUCCESS) {
  302. return ret;
  303. }
  304. }
  305. last_graph_id_ = graph_id;
  306. if (!init_flag_) {
  307. GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!");
  308. return GE_GRAPH_EXECUTE_NOT_INIT;
  309. }
  310. GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  311. Status ret = SyncExecuteModel(ge_root_model->GetModelId(), input_tensor, output_tensor);
  312. if (ret != SUCCESS) {
  313. GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!");
  314. return GE_GRAPH_SYNC_MODEL_FAILED;
  315. }
  316. return SUCCESS;
  317. }
  318. Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
  319. const std::vector<InputTensorInfo> &input_tensor,
  320. const RunAsyncCallback& callback) {
  321. GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id);
  322. if (graph_id != last_graph_id_) {
  323. auto ret = FreeExecuteMemory();
  324. if (ret != SUCCESS) {
  325. return ret;
  326. }
  327. }
  328. last_graph_id_ = graph_id;
  329. GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  330. Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback);
  331. if (ret != SUCCESS) {
  332. GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!");
  333. return GE_GRAPH_SYNC_MODEL_FAILED;
  334. }
  335. GELOGI("[GraphExecutor] Async execute graph success, graph_id=%u", graph_id);
  336. return SUCCESS;
  337. }
  338. bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) {
  339. return lhs.second < rhs.second;
  340. }
  341. uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) {
  342. std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId();
  343. if (model_ids.empty()) {
  344. return kInvalidModelId;
  345. }
  346. if (model_ids.size() == 1) {
  347. return ge_root_model->GetModelId();
  348. }
  349. std::vector<Uint32Pair> model_id_to_loads;
  350. auto model_manager = ModelManager::GetInstance();
  351. GE_CHECK_NOTNULL(model_manager);
  352. for (auto model_id : model_ids) {
  353. auto davinci_model = model_manager->GetModel(model_id);
  354. auto hybrid_model = model_manager->GetHybridModel(model_id);
  355. if (hybrid_model == nullptr) {
  356. GE_CHECK_NOTNULL(davinci_model);
  357. }
  358. uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() :
  359. davinci_model->GetDataInputerSize();
  360. uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) :
  361. static_cast<uint32_t>(davinci_model->GetRunningFlag());
  362. uint32_t load = input_load + running_load;
  363. if (load == 0) {
  364. return model_id;
  365. }
  366. model_id_to_loads.emplace_back(model_id, load);
  367. }
  368. sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad);
  369. if (model_id_to_loads.empty()) {
  370. return kInvalidModelId;
  371. }
  372. return model_id_to_loads.begin()->first;
  373. }
  374. Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model,
  375. const RunAsyncCallback &callback) {
  376. auto model_manager = ge::ModelManager::GetInstance();
  377. GE_CHECK_NOTNULL(model_manager);
  378. if (model_manager->IsNeedHybridLoad(*ge_root_model)) {
  379. auto model = model_manager->GetHybridModel(model_id);
  380. GE_CHECK_NOTNULL(model);
  381. if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
  382. GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
  383. return FAILED;
  384. }
  385. } else {
  386. auto model = model_manager->GetModel(model_id);
  387. GE_CHECK_NOTNULL(model);
  388. if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
  389. GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
  390. return FAILED;
  391. }
  392. }
  393. return SUCCESS;
  394. }
  395. Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs,
  396. const RunAsyncCallback &callback) {
  397. uint32_t model_id = GetExecuteModelId(ge_root_model);
  398. if (model_id == kInvalidModelId) {
  399. GELOGE(INTERNAL_ERROR, "No valid model id.");
  400. return INTERNAL_ERROR;
  401. }
  402. try {
  403. auto model_manager = ge::ModelManager::GetInstance();
  404. GE_CHECK_NOTNULL(model_manager);
  405. GELOGI("RunAsync begin.model_id %u", model_id);
  406. if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) {
  407. GELOGE(FAILED, "RunAsync: SetCallBack for model fail");
  408. return FAILED;
  409. }
  410. Status ret = model_manager->DataInputTensor(model_id, inputs);
  411. if (ret != SUCCESS) {
  412. GELOGE(ret, "RunAsync: DataInput fail");
  413. return ret;
  414. }
  415. GELOGI("RunAsync success.");
  416. } catch (std::bad_alloc &) {
  417. GELOGE(MEMALLOC_FAILED, "RunAsync failed, bad memory allocation occur !");
  418. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  419. return MEMALLOC_FAILED;
  420. } catch (...) {
  421. GELOGE(FAILED, "RunAsync failed, some exceptions occur !");
  422. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  423. return FAILED;
  424. }
  425. return SUCCESS;
  426. }
  427. Status GraphExecutor::DataInput(const InputData &input_data, OutputData &output_data) {
  428. try {
  429. auto model_manager = ge::ModelManager::GetInstance();
  430. GE_CHECK_NOTNULL(model_manager);
  431. Status ret = model_manager->DataInput(input_data, output_data);
  432. if (ret != SUCCESS) {
  433. GELOGE(ret, "DataInput: DataInput failed.");
  434. CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  435. return ret;
  436. }
  437. } catch (std::bad_alloc &) {
  438. GELOGE(MEMALLOC_FAILED, "DataInput failed, bad memory allocation occur !");
  439. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  440. return MEMALLOC_FAILED;
  441. } catch (...) {
  442. GELOGE(FAILED, "DataInput failed, some exceptions occur !");
  443. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  444. return FAILED;
  445. }
  446. return SUCCESS;
  447. }
  448. Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
  449. vector<InputOutputDescInfo> &output_desc) {
  450. try {
  451. auto model_manager = ge::ModelManager::GetInstance();
  452. GE_CHECK_NOTNULL(model_manager);
  453. Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc);
  454. if (ret != SUCCESS) {
  455. GELOGE(ret, "GetInputOutputDescInfo failed.");
  456. CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  457. return ret;
  458. }
  459. } catch (std::bad_alloc &) {
  460. GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occur !");
  461. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  462. return MEMALLOC_FAILED;
  463. } catch (...) {
  464. GELOGE(FAILED, "GetInputOutputDescInfo failed, some exceptions occur !");
  465. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  466. return FAILED;
  467. }
  468. return SUCCESS;
  469. }
  470. Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
  471. vector<InputOutputDescInfo> &output_desc,
  472. std::vector<uint32_t> &input_formats, std::vector<uint32_t> &out_formats,
  473. bool new_model_desc) {
  474. try {
  475. auto model_manager = ge::ModelManager::GetInstance();
  476. GE_CHECK_NOTNULL(model_manager);
  477. Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc, input_formats, out_formats,
  478. new_model_desc);
  479. if (ret != SUCCESS) {
  480. GELOGE(ret, "GetInputOutputDescInfo failed.");
  481. CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  482. return ret;
  483. }
  484. } catch (std::bad_alloc &) {
  485. GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occur !");
  486. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  487. return MEMALLOC_FAILED;
  488. } catch (...) {
  489. GELOGE(FAILED, "GetInputOutputDescInfo failed, some exceptions occur !");
  490. CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
  491. return FAILED;
  492. }
  493. return SUCCESS;
  494. }
  495. ///
  496. /// @ingroup ge
  497. /// @brief Get dynamic batch_info
  498. /// @param [in] model_id
  499. /// @param [out] batch_info
  500. /// @param [out] dynamic_type
  501. /// @return execute result
  502. ///
  503. Status GraphExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info,
  504. int32_t &dynamic_type) {
  505. auto model_manager = ge::ModelManager::GetInstance();
  506. GE_CHECK_NOTNULL(model_manager);
  507. Status ret = model_manager->GetDynamicBatchInfo(model_id, batch_info, dynamic_type);
  508. if (ret != SUCCESS) {
  509. GELOGE(ret, "GetDynamicBatchInfo failed.");
  510. return ret;
  511. }
  512. return SUCCESS;
  513. }
  514. ///
  515. /// @ingroup ge
  516. /// @brief Get combined dynamic dims info
  517. /// @param [in] model_id
  518. /// @param [out] batch_info
  519. /// @return execute result
  520. ///
  521. Status GraphExecutor::GetCombinedDynamicDims(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info) {
  522. auto model_manager = ge::ModelManager::GetInstance();
  523. GE_CHECK_NOTNULL(model_manager);
  524. Status ret = model_manager->GetCombinedDynamicDims(model_id, batch_info);
  525. if (ret != SUCCESS) {
  526. GELOGE(ret, "GetCombinedDynamicDims failed.");
  527. return ret;
  528. }
  529. return SUCCESS;
  530. }
  531. ///
  532. /// @ingroup ge
  533. /// @brief Get user designate shape order
  534. /// @param [in] model_id
  535. /// @param [out] user_input_shape_order
  536. /// @return execute result
  537. ///
  538. ge::Status GraphExecutor::GetUserDesignateShapeOrder(uint32_t model_id,
  539. std::vector<std::string> &user_input_shape_order) {
  540. auto model_manager = ge::ModelManager::GetInstance();
  541. GE_CHECK_NOTNULL(model_manager);
  542. Status ret = model_manager->GetUserDesignateShapeOrder(model_id, user_input_shape_order);
  543. if (ret != SUCCESS) {
  544. GELOGE(ret, "GetUserDesignateShapeOrder failed.");
  545. return ret;
  546. }
  547. return SUCCESS;
  548. }
  549. Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) {
  550. auto model_manager = ge::ModelManager::GetInstance();
  551. GE_CHECK_NOTNULL(model_manager);
  552. Status ret = model_manager->GetCurShape(model_id, batch_info, dynamic_type);
  553. if (ret != SUCCESS) {
  554. GELOGE(ret, "GetCurShape failed");
  555. return ret;
  556. }
  557. return SUCCESS;
  558. }
  559. Status GraphExecutor::GetModelAttr(uint32_t model_id, std::vector<string> &dynamic_output_shape_info) {
  560. auto model_manager = ge::ModelManager::GetInstance();
  561. GE_CHECK_NOTNULL(model_manager);
  562. Status ret = model_manager->GetModelAttr(model_id, dynamic_output_shape_info);
  563. if (ret != SUCCESS) {
  564. GELOGE(FAILED, "GetModelAttr failed");
  565. return ret;
  566. }
  567. return SUCCESS;
  568. }
  569. Status GraphExecutor::GetAippInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) {
  570. auto model_manager = ge::ModelManager::GetInstance();
  571. GE_CHECK_NOTNULL(model_manager);
  572. Status ret = model_manager->GetAippInfo(model_id, index, aipp_info);
  573. if (ret != SUCCESS) {
  574. GELOGW("GetAIPPInfo is not success.");
  575. return ret;
  576. }
  577. return SUCCESS;
  578. }
  579. Status GraphExecutor::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) {
  580. auto model_manager = ge::ModelManager::GetInstance();
  581. GE_CHECK_NOTNULL(model_manager);
  582. Status ret = model_manager->GetAippType(model_id, index, type, aipp_index);
  583. if (ret != SUCCESS) {
  584. GELOGW("Get aipp type is not success.");
  585. return ret;
  586. }
  587. return SUCCESS;
  588. }
  589. Status GraphExecutor::GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info) {
  590. auto model_manager = ge::ModelManager::GetInstance();
  591. GE_CHECK_NOTNULL(model_manager);
  592. Status ret = model_manager->GetOrigInputInfo(model_id, index, orig_input_info);
  593. if (ret != SUCCESS) {
  594. GELOGE(ret, "GetOrigInputInfo failed.");
  595. return ret;
  596. }
  597. return SUCCESS;
  598. }
  599. Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index,
  600. std::vector<InputOutputDims> &input_dims,
  601. std::vector<InputOutputDims> &output_dims) {
  602. auto model_manager = ge::ModelManager::GetInstance();
  603. GE_CHECK_NOTNULL(model_manager);
  604. Status ret = model_manager->GetAllAippInputOutputDims(model_id, index, input_dims, output_dims);
  605. if (ret != SUCCESS) {
  606. GELOGE(ret, "GetAllAippInputOutputDims failed.");
  607. return ret;
  608. }
  609. return SUCCESS;
  610. }
  611. Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id,
  612. OpDescInfo &op_desc_info) {
  613. auto model_manager = ge::ModelManager::GetInstance();
  614. GE_CHECK_NOTNULL(model_manager);
  615. Status ret = model_manager->GetOpDescInfo(device_id, stream_id, task_id, op_desc_info);
  616. if (ret != SUCCESS) {
  617. GELOGE(ret, "GetOpDescInfo failed.");
  618. return ret;
  619. }
  620. return SUCCESS;
  621. }
  622. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示