
graph_execute.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/execute/graph_execute.h"

#include <memory>
#include <string>

#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"

namespace ge {
using Uint32Pair = pair<uint32_t, uint32_t>;
const uint32_t kInvalidModelId = UINT32_MAX;

GraphExecutor::GraphExecutor()
    : init_flag_(false),
      train_graph_flag_(false),
      sync_run_mutex_(nullptr),
      condition_(nullptr),
      graph_run_listener_(nullptr),
      graph_context_(nullptr),
      last_graph_id_(UINT32_MAX),
      malloc_flag_(false) {}

GraphExecutor::~GraphExecutor() {
  outputs_desc_.clear();
  if (malloc_flag_) {
    for (auto &buffer_addr : buffer_addr_) {
      rtError_t rt_ret;
      rt_ret = rtFreeHost(buffer_addr);
      if (rt_ret != RT_ERROR_NONE) {
        REPORT_CALL_ERROR("E19999", "Call rtFreeHost failed, ret:0x%X", rt_ret);
        GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
      }
    }
  }
  malloc_flag_ = false;
  buffer_addr_.clear();
}
Status GraphExecutor::SetCondition(std::mutex *mutex, std::condition_variable *cond,
                                   std::shared_ptr<GraphModelListener> listener) {
  if (mutex == nullptr) {
    REPORT_INNER_ERROR("E19999", "Check param mutex nullptr");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param mutex is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  if (cond == nullptr) {
    REPORT_INNER_ERROR("E19999", "Check param cond nullptr");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param cond is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  if (listener == nullptr) {
    REPORT_INNER_ERROR("E19999", "Check param listener nullptr");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetCondition] input param listener is nullptr.");
    return GE_GRAPH_PARAM_NULLPTR;
  }

  sync_run_mutex_ = mutex;
  condition_ = cond;
  graph_run_listener_ = listener;
  init_flag_ = true;
  return SUCCESS;
}

Status GraphExecutor::SetGraphContext(GraphContextPtr graph_context_ptr) {
  if (graph_context_ptr == nullptr) {
    REPORT_INNER_ERROR("E19999", "Check param graph_context_ptr nullptr");
    GELOGE(GE_GRAPH_PARAM_NULLPTR, "[SetGraphContext] input param graph_context_ptr is nullptr");
    return GE_GRAPH_PARAM_NULLPTR;
  }
  graph_context_ = graph_context_ptr;
  return SUCCESS;
}

Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector<uint64_t> &batch_num, int32_t dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->SetDynamicSize(model_id, batch_num, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "SetDynamicSize failed");
    return ret;
  }
  return SUCCESS;
}

void GraphExecutor::SetTrainFlag(bool is_train_graph) { train_graph_flag_ = is_train_graph; }

Status GraphExecutor::FreeInOutBuffer() {
  if (malloc_flag_) {
    for (auto iter = buffer_addr_.begin(); iter != buffer_addr_.end(); ++iter) {
      rtError_t rt_ret;
      rt_ret = rtFreeHost(*iter);
      if (rt_ret != RT_ERROR_NONE) {
        REPORT_CALL_ERROR("E19999", "Call rtFreeHost failed, ret:0x%X", rt_ret);
        GELOGE(RT_FAILED, "[GraphManager] subgraph free buffer failed, ret: 0x%X", rt_ret);
        (void)buffer_addr_.erase(buffer_addr_.begin(), iter);
        return GE_GRAPH_FREE_FAILED;
      }
    }
    buffer_addr_.clear();
    malloc_flag_ = false;
    return SUCCESS;
  } else {
    GELOGD("[GraphManager] not malloc buffer.");
    return SUCCESS;
  }
}
Status GraphExecutor::MallocInOutBuffer(const std::vector<uint64_t> &buffer_size, std::vector<void *> &data_addr) {
  if (malloc_flag_) {
    auto all_size_same = true;
    if (buffer_size.size() == buffer_size_.size()) {
      for (size_t i = 0; i < buffer_size.size(); i++) {
        if (buffer_size[i] != buffer_size_[i]) {
          all_size_same = false;
          break;
        }
      }
    } else {
      all_size_same = false;
    }
    if (all_size_same) {
      data_addr = buffer_addr_;
      return SUCCESS;
    }
    buffer_size_.clear();
    auto rt_ret = FreeInOutBuffer();
    if (rt_ret != SUCCESS) {
      GELOGE(RT_FAILED, "[SubGraphInfo] MallocInOutBuffer free buffer failed, ret: 0x%X", rt_ret);
      return RT_FAILED;
    }
  }

  rtError_t rt_ret;
  for (size_t i = 0; i < buffer_size.size(); ++i) {
    void *tmp_buf = nullptr;
    rt_ret = rtMallocHost(&tmp_buf, buffer_size[i]);
    if (rt_ret != RT_ERROR_NONE) {
      REPORT_CALL_ERROR("E19999", "Call rtMallocHost failed, size:%lu, ret:0x%X",
                        buffer_size[i], rt_ret);
      GELOGE(RT_FAILED, "[GraphManager] subgraph malloc buffer failed, ret: 0x%X", rt_ret);
      return GE_GRAPH_MALLOC_FAILED;
    }
    malloc_flag_ = true;
    data_addr.push_back(tmp_buf);
    buffer_addr_.push_back(tmp_buf);
  }
  buffer_size_ = buffer_size;
  return SUCCESS;
}
Status GraphExecutor::PrepareInputData(const std::vector<GeTensor> &input_tensor, InputData &graph_input_data,
                                       OutputData &graph_output_data, std::vector<InputOutputDescInfo> &output_desc) {
  // Preprocessing input data
  graph_input_data.index = 0;
  graph_input_data.timeout = 0;
  graph_input_data.timestamp = 0;
  std::size_t inputSize = input_tensor.size();
  std::size_t output_size = output_desc.size();
  std::vector<uint64_t> bufferSizeVec;
  std::vector<void *> addrVec;

  for (std::size_t i = 0; i < inputSize; ++i) {
    const GeTensor *InTensor = &input_tensor[i];
    GE_CHECK_NOTNULL(InTensor);
    bufferSizeVec.push_back(InTensor->GetData().size());
  }

  for (const auto &desc : output_desc) {
    bufferSizeVec.push_back(desc.size);
  }

  Status ret = MallocInOutBuffer(bufferSizeVec, addrVec);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_MALLOC_FAILED, "[GraphExecutor] Malloc mem failed");
    return GE_GRAPH_MALLOC_FAILED;
  }

  for (std::size_t i = 0; i < input_tensor.size() && i < addrVec.size(); ++i) {
    const GeTensor *in_tensor = &input_tensor[i];
    GE_CHECK_NOTNULL(in_tensor);
    if ((addrVec[i] != nullptr) && (in_tensor->GetData().data() != nullptr)) {
      rtError_t rt_ret = rtMemcpy(addrVec[i], bufferSizeVec[i], in_tensor->GetData().data(),
                                  in_tensor->GetData().size(), RT_MEMCPY_HOST_TO_HOST);
      if (rt_ret != RT_ERROR_NONE) {
        REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, dst_size:%lu, src_size:%zu, ret:0x%X",
                          bufferSizeVec[i], in_tensor->GetData().size(), rt_ret);
        GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
        return RT_FAILED;
      }
    }

    DataBuffer in_data_buf;
    in_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[i]);
    in_data_buf.length = in_tensor->GetData().size();
    in_data_buf.isDataSupportMemShare = false;
    graph_input_data.blobs.push_back(in_data_buf);
  }

  graph_output_data.index = 0;

  for (std::size_t j = 0; j < output_size; j++) {
    auto desc = output_desc[j];
    uint64_t buffer_size = desc.size;

    DataBuffer out_data_buf;
    out_data_buf.data = reinterpret_cast<uint8_t *>(addrVec[inputSize + j]);
    out_data_buf.length = buffer_size;
    out_data_buf.isDataSupportMemShare = false;
    graph_output_data.blobs.push_back(out_data_buf);
  }

  return SUCCESS;
}
Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor,
                                       std::vector<GeTensor> &output_tensor) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  if (model_manager->IsDynamicShape(model_id)) {
    GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id);
    return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor);
  }

  // Prepare input and output
  std::vector<InputOutputDescInfo> inputs_desc;
  std::vector<InputOutputDescInfo> output_desc;
  GELOGI("[ExecuteGraph] GetInputOutputDescInfo via new ome begin.");
  Status ret = GetInputOutputDescInfo(model_id, inputs_desc, output_desc);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_GET_IN_OUT_FAILED, "[GraphExecutor] GetInputOutputDescInfo failed, modelId=%u.", model_id);
    return GE_GRAPH_GET_IN_OUT_FAILED;
  }
  outputs_desc_.assign(output_desc.begin(), output_desc.end());

  InputData input_data;
  OutputData output_data;
  input_data.model_id = model_id;
  ret = PrepareInputData(input_tensor, input_data, output_data, output_desc);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_PREPARE_FAILED, "[GraphExecutor] PrepareInputData failed, modelId=%u.", model_id);
    return GE_GRAPH_PREPARE_FAILED;
  }

  if (graph_run_listener_->ResetResult() != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Call graph_run_listener_.ResetResult fail, model_id:%u",
                      model_id);
    GELOGE(GE_GRAPH_EXECUTE_FAILED, "Reset result failed");
    return GE_GRAPH_EXECUTE_FAILED;
  }

  // Run mode async
  GELOGI("[ExecuteGraph] DataInput via new ome begin.");
  ret = DataInput(input_data, output_data);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_DATA_INPUT_FAILED, "[GraphExecutor] push data failed, modelId=%u.", model_id);
    return GE_GRAPH_DATA_INPUT_FAILED;
  }
  GELOGI("[GraphExecutor] input data push to wrapper finish, waiting for result...");

  // Pending until async execute graph complete
  {
    std::unique_lock<std::mutex> ulock(*sync_run_mutex_);
    if (!graph_run_listener_->IsFinished()) {
      (*condition_).wait(ulock);
    }

    // Run graph return
    uint32_t result_code = graph_run_listener_->GetResultCode();
    if (result_code != SUCCESS && result_code != END_OF_SEQUENCE) {
      REPORT_CALL_ERROR("E19999", "Graph_run_listener_ run fail, result:%u, model_id:%u",
                        result_code, model_id);
      GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code,
             model_id);
      return GE_GRAPH_EXECUTE_FAILED;
    }
  }
  for (size_t i = 0; i < output_data.blobs.size(); ++i) {
    DataBuffer outputDataTmp = output_data.blobs[i];
    CHECK_FALSE_EXEC(outputDataTmp.length != 0,
                     REPORT_INNER_ERROR("E19999", "Param output_data.length is 0 in model:%u, check invalid",
                                        model_id);
                     GELOGE(GE_GRAPH_EXECUTE_FAILED, "Failed to allocate memory, length is 0.");
                     return GE_GRAPH_EXECUTE_FAILED);
    std::unique_ptr<uint8_t> outBufTmp(new (std::nothrow) uint8_t[outputDataTmp.length]);
    if (outBufTmp == nullptr) {
      REPORT_CALL_ERROR("E19999", "New output buffer fail, length:%lu, model:%u",
                        outputDataTmp.length, model_id);
      GELOGE(FAILED, "Failed to allocate memory.");
      return FAILED;
    }
    GE_PRINT_DYNAMIC_MEMORY(new, "the output memory of data on training.", sizeof(uint8_t) * outputDataTmp.length)
    rtError_t ret_value = rtMemcpy(outBufTmp.get(), outputDataTmp.length, outputDataTmp.data, outputDataTmp.length,
                                   RT_MEMCPY_HOST_TO_HOST);
    CHECK_FALSE_EXEC(ret_value == RT_ERROR_NONE,
                     REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, dst_size:%lu, src_size:%zu, ret:0x%X",
                                       outputDataTmp.length, outputDataTmp.length, ret_value);
                     GELOGE(GE_GRAPH_EXECUTE_FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret_value);
                     return GE_GRAPH_EXECUTE_FAILED);

    GeTensor outTensor;
    std::vector<int64_t> shapeDims;
    for (const auto &dim : output_desc[i].shape_info.dims) {
      shapeDims.push_back(dim);
    }
    GeShape outShape(shapeDims);
    outTensor.MutableTensorDesc().SetShape(outShape);
    outTensor.MutableTensorDesc().SetDataType((DataType)output_desc[i].data_type);
    (void)outTensor.SetData(outBufTmp.get(), outputDataTmp.length);
    output_tensor.push_back(outTensor);
  }

  GELOGI("[GraphExecutor] execute model success, modelId=%u.", model_id);
  return SUCCESS;
}
void GraphExecutor::InitModelIdInfo(std::vector<uint32_t> &out_model_id_info,
                                    std::vector<SubGraphInfoPtr> &sub_graph_vec, uint32_t output_size) {
  for (uint32_t i = 0; i < output_size; i++) {
    for (size_t j = 0; j < sub_graph_vec.size(); j++) {
      if (sub_graph_vec[j]->GetOutputFlag().size() == output_size && sub_graph_vec[j]->GetOutputFlag().at(i)) {
        out_model_id_info.push_back(sub_graph_vec[j]->GetModelIdInfo().model_id);
      }
    }
  }
}

Status GraphExecutor::FreeExecuteMemory() {
  auto ret = FreeInOutBuffer();
  if (ret != SUCCESS) {
    GELOGE(ret, "[FreeExecuteMemory] FreeInOutBuffer Error!");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model,
                                   const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor) {
  if (graph_id != last_graph_id_) {
    auto ret = FreeExecuteMemory();
    if (ret != SUCCESS) {
      return ret;
    }
  }
  last_graph_id_ = graph_id;

  if (!init_flag_) {
    REPORT_INNER_ERROR("E19999", "No SetCondition called before, graph:%u, check invalid",
                       graph_id);
    GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!");
    return GE_GRAPH_EXECUTE_NOT_INIT;
  }
  GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  Status ret = SyncExecuteModel(ge_root_model->GetModelId(), input_tensor, output_tensor);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!");
    return GE_GRAPH_SYNC_MODEL_FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model,
                                        const std::vector<InputTensorInfo> &input_tensor,
                                        const RunAsyncCallback &callback) {
  GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id);
  if (graph_id != last_graph_id_) {
    auto ret = FreeExecuteMemory();
    if (ret != SUCCESS) {
      return ret;
    }
  }
  last_graph_id_ = graph_id;
  GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback);
  if (ret != SUCCESS) {
    GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!");
    return GE_GRAPH_SYNC_MODEL_FAILED;
  }

  GELOGI("[GraphExecutor] Async execute graph success, graph_id=%u", graph_id);
  return SUCCESS;
}
Status GraphExecutor::GetExecuteData(const std::vector<GeTensor> &input_tensor, std::vector<DataBuffer> &blobs,
                                     std::vector<GeTensorDesc> &tensor_desc) {
  for (const auto &tensor : input_tensor) {
    DataBuffer in_data_buf;
    // check placement
    in_data_buf.data = const_cast<uint8_t *>(tensor.GetData().data());
    in_data_buf.length = tensor.GetData().size();
    in_data_buf.isDataSupportMemShare = false;
    blobs.emplace_back(in_data_buf);
    tensor_desc.emplace_back(tensor.GetTensorDesc());
  }
  return SUCCESS;
}

Status GraphExecutor::ExecuteGraphWithStream(GraphId graph_id,
                                             rtStream_t stream,
                                             const GeRootModelPtr &ge_root_model,
                                             const std::vector<GeTensor> &input_tensor,
                                             std::vector<GeTensor> &output_tensor) {
  GELOGI("[GraphExecutor] Start to execute graph with stream, graph id = %u, stream = %p.", graph_id, stream);
  if (!init_flag_) {
    REPORT_INNER_ERROR("E19999", "No SetCondition called before, graph id = %u, stream = %p, check invalid.",
                       graph_id, stream);
    GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!");
    return GE_GRAPH_EXECUTE_NOT_INIT;
  }

  if (graph_id != last_graph_id_) {
    auto ret = FreeExecuteMemory();
    if (ret != SUCCESS) {
      return ret;
    }
  }
  last_graph_id_ = graph_id;

  GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED);
  auto model_id = ge_root_model->GetModelId();
  InputData input_data;
  input_data.index = 0;
  input_data.model_id = model_id;
  std::vector<GeTensorDesc> input_desc;
  auto ret = GetExecuteData(input_tensor, input_data.blobs, input_desc);
  if (ret != SUCCESS) {
    return ret;
  }

  OutputData output_data;
  output_data.index = 0;
  output_data.model_id = model_id;
  std::vector<GeTensorDesc> output_desc;
  ret = GetExecuteData(output_tensor, output_data.blobs, output_desc);
  if (ret != SUCCESS) {
    return ret;
  }

  auto async_mode = true;
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  ret = model_manager->ExecuteModel(model_id, stream, async_mode, input_data, input_desc, output_data, output_desc);
  if (ret != SUCCESS) {
    return ret;
  }

  GELOGI("[GraphExecutor] Async execute graph with stream success graph id = %u, stream = %p.", graph_id, stream);
  return SUCCESS;
}
bool CompareByLoad(const Uint32Pair &lhs, const Uint32Pair &rhs) {
  return lhs.second < rhs.second;
}

uint32_t GraphExecutor::GetExecuteModelId(const GeRootModelPtr &ge_root_model) {
  std::vector<uint32_t> model_ids = ge_root_model->GetAllModelId();
  if (model_ids.empty()) {
    return kInvalidModelId;
  }
  if (model_ids.size() == 1) {
    return ge_root_model->GetModelId();
  }
  std::vector<Uint32Pair> model_id_to_loads;
  auto model_manager = ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  for (auto model_id : model_ids) {
    auto davinci_model = model_manager->GetModel(model_id);
    auto hybrid_model = model_manager->GetHybridModel(model_id);
    if (hybrid_model == nullptr) {
      GE_CHECK_NOTNULL(davinci_model);
    }
    uint32_t input_load = hybrid_model != nullptr ? hybrid_model->GetDataInputerSize() :
                                                    davinci_model->GetDataInputerSize();
    uint32_t running_load = hybrid_model != nullptr ? static_cast<uint32_t>(hybrid_model->GetRunningFlag()) :
                                                      static_cast<uint32_t>(davinci_model->GetRunningFlag());
    uint32_t load = input_load + running_load;
    if (load == 0) {
      return model_id;
    }
    model_id_to_loads.emplace_back(model_id, load);
  }
  sort(model_id_to_loads.begin(), model_id_to_loads.end(), CompareByLoad);
  if (model_id_to_loads.empty()) {
    return kInvalidModelId;
  }
  return model_id_to_loads.begin()->first;
}

Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model,
                                  const RunAsyncCallback &callback) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  if (model_manager->IsNeedHybridLoad(*ge_root_model)) {
    auto model = model_manager->GetHybridModel(model_id);
    GE_CHECK_NOTNULL(model);
    if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
      GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
      return FAILED;
    }
  } else {
    auto model = model_manager->GetModel(model_id);
    GE_CHECK_NOTNULL(model);
    if (model->SetRunAsyncListenerCallback(callback) != SUCCESS) {
      GELOGE(FAILED, "SetRunAsyncListenerCallback failed.");
      return FAILED;
    }
  }
  return SUCCESS;
}
Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<InputTensorInfo> &inputs,
                                        const RunAsyncCallback &callback) {
  uint32_t model_id = GetExecuteModelId(ge_root_model);
  if (model_id == kInvalidModelId) {
    GELOGE(INTERNAL_ERROR, "No valid model id.");
    return INTERNAL_ERROR;
  }
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    GELOGI("RunAsync begin.model_id %u", model_id);
    if (SetCallback(model_id, ge_root_model, callback) != SUCCESS) {
      GELOGE(FAILED, "RunAsync: SetCallBack for model fail");
      return FAILED;
    }

    Status ret = model_manager->DataInputTensor(model_id, inputs);
    if (ret != SUCCESS) {
      GELOGE(ret, "RunAsync: DataInput fail");
      return ret;
    }

    GELOGI("RunAsync success.");
  } catch (std::bad_alloc &) {
    REPORT_INNER_ERROR("E19999", "Bad memory allocation exception occur failed");
    GELOGE(MEMALLOC_FAILED, "RunAsync failed, bad memory allocation occur !");
    return MEMALLOC_FAILED;
  } catch (...) {
    REPORT_INNER_ERROR("E19999", "Some exceptions occur failed");
    GELOGE(FAILED, "RunAsync failed, some exceptions occur !");
    return FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::DataInput(const InputData &input_data, OutputData &output_data) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->DataInput(input_data, output_data);
    if (ret != SUCCESS) {
      GELOGE(ret, "DataInput: DataInput failed.");
      return ret;
    }
  } catch (std::bad_alloc &) {
    REPORT_INNER_ERROR("E19999", "Bad memory allocation exception occur failed");
    GELOGE(MEMALLOC_FAILED, "DataInput failed, bad memory allocation occur !");
    return MEMALLOC_FAILED;
  } catch (...) {
    REPORT_INNER_ERROR("E19999", "Some exceptions occur failed");
    GELOGE(FAILED, "DataInput failed, some exceptions occur !");
    return FAILED;
  }

  return SUCCESS;
}
Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
                                             vector<InputOutputDescInfo> &output_desc) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc);
    if (ret != SUCCESS) {
      GELOGE(ret, "GetInputOutputDescInfo failed.");
      return ret;
    }
  } catch (std::bad_alloc &) {
    REPORT_INNER_ERROR("E19999", "Bad memory allocation exception occur failed");
    GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occur !");
    return MEMALLOC_FAILED;
  } catch (...) {
    REPORT_INNER_ERROR("E19999", "Some exceptions occur failed");
    GELOGE(FAILED, "GetInputOutputDescInfo failed, some exceptions occur !");
    return FAILED;
  }

  return SUCCESS;
}

Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
                                             vector<InputOutputDescInfo> &output_desc,
                                             std::vector<uint32_t> &input_formats, std::vector<uint32_t> &out_formats,
                                             bool new_model_desc) {
  try {
    auto model_manager = ge::ModelManager::GetInstance();
    GE_CHECK_NOTNULL(model_manager);
    Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc, input_formats, out_formats,
                                                       new_model_desc);
    if (ret != SUCCESS) {
      GELOGE(ret, "GetInputOutputDescInfo failed.");
      return ret;
    }
  } catch (std::bad_alloc &) {
    REPORT_INNER_ERROR("E19999", "Bad memory allocation exception occur failed");
    GELOGE(MEMALLOC_FAILED, "GetInputOutputDescInfo failed, bad memory allocation occur !");
    return MEMALLOC_FAILED;
  } catch (...) {
    REPORT_INNER_ERROR("E19999", "Some exceptions occur failed");
    GELOGE(FAILED, "GetInputOutputDescInfo failed, some exceptions occur !");
    return FAILED;
  }

  return SUCCESS;
}
///
/// @ingroup ge
/// @brief Get dynamic batch_info
/// @param [in] model_id
/// @param [out] batch_info
/// @param [out] dynamic_type
/// @return execute result
///
Status GraphExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info,
                                          int32_t &dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetDynamicBatchInfo(model_id, batch_info, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetDynamicBatchInfo failed.");
    return ret;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Get combined dynamic dims info
/// @param [in] model_id
/// @param [out] batch_info
/// @return execute result
///
Status GraphExecutor::GetCombinedDynamicDims(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetCombinedDynamicDims(model_id, batch_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetCombinedDynamicDims failed.");
    return ret;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Get user designate shape order
/// @param [in] model_id
/// @param [out] user_input_shape_order
/// @return execute result
///
ge::Status GraphExecutor::GetUserDesignateShapeOrder(uint32_t model_id,
                                                     std::vector<std::string> &user_input_shape_order) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetUserDesignateShapeOrder(model_id, user_input_shape_order);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetUserDesignateShapeOrder failed.");
    return ret;
  }
  return SUCCESS;
}
Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetCurShape(model_id, batch_info, dynamic_type);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetCurShape failed");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetOpAttr(uint32_t model_id, const std::string &op_name, const std::string &attr_name,
                                std::string &attr_value) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetOpAttr(model_id, op_name, attr_name, attr_value);
  if (ret != SUCCESS) {
    GELOGE(ret, "[Get][OpAttr]Get op:%s attr:%s failed.", op_name.c_str(), attr_name.c_str());
    REPORT_CALL_ERROR("E19999", "Get op:%s attr:%s failed.", op_name.c_str(), attr_name.c_str());
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetModelAttr(uint32_t model_id, std::vector<string> &dynamic_output_shape_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetModelAttr(model_id, dynamic_output_shape_info);
  if (ret != SUCCESS) {
    GELOGE(FAILED, "GetModelAttr failed");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetAippInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAippInfo(model_id, index, aipp_info);
  if (ret != SUCCESS) {
    GELOGW("GetAIPPInfo is not success.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetAippType(uint32_t model_id, uint32_t index, InputAippType &type, size_t &aipp_index) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAippType(model_id, index, type, aipp_index);
  if (ret != SUCCESS) {
    GELOGW("Get aipp type is not success.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetOrigInputInfo(model_id, index, orig_input_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetOrigInputInfo failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index,
                                                std::vector<InputOutputDims> &input_dims,
                                                std::vector<InputOutputDims> &output_dims) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetAllAippInputOutputDims(model_id, index, input_dims, output_dims);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetAllAippInputOutputDims failed.");
    return ret;
  }
  return SUCCESS;
}

Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id,
                                    OpDescInfo &op_desc_info) {
  auto model_manager = ge::ModelManager::GetInstance();
  GE_CHECK_NOTNULL(model_manager);
  Status ret = model_manager->GetOpDescInfo(device_id, stream_id, task_id, op_desc_info);
  if (ret != SUCCESS) {
    GELOGE(ret, "GetOpDescInfo failed.");
    return ret;
  }
  return SUCCESS;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph handed down by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
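
As a rough illustration of how an application reaches the execution path shown in graph_execute.cc, the following is a minimal sketch that drives GE through the public GE API layer rather than calling GraphExecutor directly. It assumes the ge_api.h interface (GEInitialize, Session, AddGraph, RunGraph, GEFinalize) and a ge::Graph already built with the graph construction API; the option maps are left empty and the function name is illustrative, not part of the original source.

// Minimal usage sketch (assumed GE API from "ge/ge_api.h"; not taken from this file).
#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"

int RunWithGe(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs) {
  std::map<std::string, std::string> options;  // global/session options, left empty here
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;
  }
  int result = -1;
  {
    ge::Session session(options);  // GE Core is driven through a Session
    const uint32_t graph_id = 1U;
    std::vector<ge::Tensor> outputs;
    // RunGraph compiles/loads the graph on first use and then executes it;
    // internally, execution ends up in GraphExecutor::ExecuteGraph shown above.
    if (session.AddGraph(graph_id, graph) == ge::SUCCESS &&
        session.RunGraph(graph_id, inputs, outputs) == ge::SUCCESS) {
      result = 0;
    }
  }
  (void)ge::GEFinalize();
  return result;
}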