
hybrid_model_async_executor.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hybrid/executor/hybrid_model_async_executor.h"
#include "graph/load/model_manager/model_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/ge_context.h"
#include "omm/csa_interact.h"

namespace ge {
namespace hybrid {
namespace {
const int kDataOutputIndex = 0;
}

HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model)
    : model_(model), run_flag_(false) {
}

HybridModelAsyncExecutor::~HybridModelAsyncExecutor() {
  if (stream_ != nullptr) {
    GE_CHK_RT(rtStreamDestroy(stream_));
  }
}

void HybridModelAsyncExecutor::SetDeviceId(uint32_t device_id) {
  device_id_ = device_id;
}

void HybridModelAsyncExecutor::SetModelId(uint32_t model_id) {
  model_id_ = model_id;
}

Status HybridModelAsyncExecutor::EnqueueData(const shared_ptr<InputDataWrapper> &data) {
  GE_CHK_STATUS_EXEC(data_inputer_->Push(data), return domi::DATA_QUEUE_ISFULL,
                     "Data queue is full, please call again later, model_id %u", model_id_);
  GELOGD("EnqueueData successfully. model_id = %u, data_index = %u", data->GetInput().model_id, data->GetInput().index);
  return SUCCESS;
}

Status HybridModelAsyncExecutor::Start(const std::shared_ptr<ModelListener> &listener) {
  GELOGD("HybridModelExecutor::Start IN, has listener = %d", listener != nullptr);
  std::lock_guard<std::mutex> lk(mu_);
  GE_CHK_BOOL_RET_STATUS(!run_flag_, INTERNAL_ERROR, "Model already started.");
  run_flag_ = true;
  listener_ = listener;
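  // Launch the execution loop on a worker thread; the session context is
  // propagated into the thread before RunInternal starts consuming data.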
  future_ = std::async(std::launch::async, [&]() -> Status {
    GetThreadLocalContext() = *executor_->GetContext()->ge_context;
    GetContext().SetSessionId(executor_->GetContext()->session_id);
    return RunInternal();
  });
  GE_CHK_BOOL_RET_STATUS(future_.valid(), INTERNAL_ERROR, "Failed to start.");
  GELOGD("HybridModelExecutor::Start successfully");
  return SUCCESS;
}

Status HybridModelAsyncExecutor::Stop() {
  std::lock_guard<std::mutex> lk(mu_);
  run_flag_ = false;
  data_inputer_->Stop();
  Status ret = SUCCESS;
  if (future_.valid()) {
    ret = future_.get();
  }
  if (stream_ != nullptr) {
    GE_CHK_RT(rtStreamDestroy(stream_));
    stream_ = nullptr;
  }
  return ret;
}

Status HybridModelAsyncExecutor::Init() {
  data_inputer_ = std::unique_ptr<DataInputer>(new(std::nothrow) DataInputer());
  GE_CHECK_NOTNULL(data_inputer_);
  GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT));
  executor_ = std::unique_ptr<HybridModelExecutor>(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_));
  GE_CHECK_NOTNULL(executor_);
  GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine");
  GE_CHK_STATUS_RET(InitInputDesc(), "Failed to init input tensors");
  return SUCCESS;
}

Status HybridModelAsyncExecutor::PreRun(InputData &current_data, HybridModelExecutor::ExecuteArgs &args) {
  GE_CHK_STATUS_RET(SyncVarData(), "Failed to sync var data");
  RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[SyncVarData] End");
  GE_CHK_STATUS_RET(PrepareInputs(current_data, args), "Failed to copy input data to model");
  RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[CopyInputData] End");
  return SUCCESS;
}

Status HybridModelAsyncExecutor::RunInternal() {
  auto device_id = static_cast<int32_t>(device_id_);
  GELOGD("Hybrid model start. model_id = %u, device_id = %u", model_id_, device_id_);
  GE_CHK_RT_RET(rtSetDevice(device_id));
  // Reset the device when this run thread finishes.
  GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); });
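  // Consumer loop: block on the input queue, run one iteration per input,
  // and report each result through the listener until Stop() clears run_flag_.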
  while (run_flag_) {
    std::shared_ptr<InputDataWrapper> data_wrapper;
    Status ret = data_inputer_->Pop(data_wrapper);
    if (data_wrapper == nullptr || ret != SUCCESS) {
      GELOGI("data_wrapper is null, ret = %u", ret);
      continue;
    }
    GELOGI("Getting the input data, model_id:%u", model_id_);
    GE_IF_BOOL_EXEC(!run_flag_, break);
    InputData current_data = data_wrapper->GetInput();
    GELOGI("Model thread Run begin, model id:%u, data index:%u.", model_id_, current_data.index);
    RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %lu] Start", iterator_count_);
    HybridModelExecutor::ExecuteArgs args;
    ret = PreRun(current_data, args);
    GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
        ret != SUCCESS, (void) HandleResult(ret, current_data.index, args, data_wrapper->GetOutput());
        CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC);
        continue, "PreRun failed.");  // [No need to check value]
    ret = executor_->Execute(args);
    ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput());
    if (ret != SUCCESS) {
      CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC);
      continue;
    }
    RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %lu] End", iterator_count_);
    iterator_count_++;
    GELOGI("run iterator count is %lu", iterator_count_);
  }
  CsaInteract::GetInstance().WriteInternalErrorCode();
  GELOGI("Model run end, model id:%u", model_id_);
  return SUCCESS;
}
Status HybridModelAsyncExecutor::HandleResult(Status exec_ret,
                                              uint32_t data_id,
                                              HybridModelExecutor::ExecuteArgs &args,
                                              OutputData *output_data) {
  GELOGD("Start to handle result. model id = %u, data index = %u, execution ret = %u", model_id_, data_id, exec_ret);
  std::vector<ge::OutputTensorInfo> output_tensor_info_list;
  if (args.is_eos) {
    GELOGI("End of sequence, model id = %u", model_id_);
    GE_CHK_STATUS_RET_NOLOG(OnComputeDone(data_id, END_OF_SEQUENCE, output_tensor_info_list));
    return SUCCESS;
  }
  if (exec_ret != SUCCESS) {
    GELOGE(exec_ret, "Failed to execute graph. model_id = %u", model_id_);
    return OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list);
  }
  GE_CHECK_NOTNULL(output_data);
  auto ret = CopyOutputs(args, output_data, output_tensor_info_list);
  if (ret != SUCCESS) {
    OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list);
    return INTERNAL_ERROR;
  }
  GELOGD("Executed graph successfully, model id = %u, data_index = %u", model_id_, data_id);
  return OnComputeDone(data_id, SUCCESS, output_tensor_info_list);
}
Status HybridModelAsyncExecutor::SyncVarData() {
  GELOGI("Sync var data, model id:%u", model_id_);
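  // If the model has a GLOBAL_STEP variable, refresh it on the device with
  // the current iteration count before this run.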
  TensorValue *global_step_var = model_->GetVariable(NODE_NAME_GLOBAL_STEP);
  if (global_step_var != nullptr) {
    std::vector<uint64_t> v_step;
    v_step.push_back(iterator_count_);
    GE_CHK_RT_RET(rtMemcpy(global_step_var->MutableData(),
                           global_step_var->GetSize(),
                           v_step.data(),
                           v_step.size() * sizeof(uint64_t),
                           RT_MEMCPY_HOST_TO_DEVICE));
  } else {
    GELOGD("No GLOBAL_STEP variable was found.");
  }
  return SUCCESS;
}
Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, HybridModelExecutor::ExecuteArgs &args) {
  if (current_data.blobs.size() < input_tensor_desc_.size()) {
    GELOGE(PARAM_INVALID, "Blob size mismatches, expect at least %zu, but got %zu",
           input_tensor_desc_.size(), current_data.blobs.size());
    return PARAM_INVALID;
  }
  auto allocator = NpuMemoryAllocator::GetAllocator(device_id_);
  GE_CHECK_NOTNULL(allocator);
  args.input_desc.resize(input_tensor_desc_.size());
  const std::vector<DataBuffer> &blobs = current_data.blobs;
  for (size_t input_index = 0; input_index < input_tensor_desc_.size(); ++input_index) {
    auto tensor_size = input_sizes_[input_index];
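    // Dynamic inputs carry their shapes alongside the data; update the tensor
    // desc with the runtime shape and recompute the required buffer size.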
    if (is_input_dynamic_[input_index]) {
      if (input_index >= current_data.shapes.size()) {
        GELOGE(PARAM_INVALID, "Shape index out of range, index = %zu, shape size = %zu",
               input_index, current_data.shapes.size());
        return PARAM_INVALID;
      }
      auto &tensor_desc = input_tensor_desc_[input_index];
      tensor_desc->SetShape(GeShape(current_data.shapes[input_index]));
      args.input_desc[input_index] = tensor_desc;
      GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str());
      GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size),
                              "Failed to calc tensor size, index = %zu, shape = [%s]",
                              input_index,
                              tensor_desc->GetShape().ToString().c_str());
      GELOGD("Input tensor[%zu] size = %zu", input_index, tensor_size);
    }
    GE_CHECK_GE(tensor_size, 0);
    AllocationAttr attr;
    if (GetContext().GetHostExecFlag()) {
      attr.SetMemType(HOST_DDR);
    }
    auto tensor_buffer = TensorBuffer::Create(allocator, tensor_size, &attr);
    GE_CHECK_NOTNULL(tensor_buffer);
    args.inputs.emplace_back(std::shared_ptr<TensorBuffer>(tensor_buffer.release()));
    GELOGD("To copy input data for input[%zu]", input_index);
    const DataBuffer &data_buf = blobs[input_index];
    auto mem_size = static_cast<uint64_t>(tensor_size);
    GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length,
                           PARAM_INVALID,
                           "input data size (%lu) exceeds the size required by the model (%lu), check failed.",
                           data_buf.length,
                           mem_size);
    GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%zu] memaddr[%p] mem_size[%zu] datasize[%lu]",
           model_->root_runtime_param_.graph_id,
           input_index,
           args.inputs[input_index].GetData(),
           mem_size,
           data_buf.length);
    GE_CHK_RT_RET(rtMemcpy(args.inputs[input_index].MutableData(),
                           mem_size,
                           data_buf.data,
                           data_buf.length,
                           RT_MEMCPY_HOST_TO_DEVICE));
  }
  return SUCCESS;
}
Status HybridModelAsyncExecutor::InitInputDesc() {
  int input_index = 0;
  for (const auto &input_node : model_->GetRootGraphItem()->GetInputNodes()) {
    GELOGD("Init input[%u], node = %s, is_dynamic = %d",
           input_index,
           input_node->NodeName().c_str(),
           input_node->is_dynamic);
    auto output_desc = input_node->MutableOutputDesc(kDataOutputIndex);
    GE_CHECK_NOTNULL(output_desc);
    int64_t tensor_size = -1;
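    // Static inputs have a fixed size that can be computed up front; dynamic
    // inputs keep -1 here and are sized at execution time in PrepareInputs.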
    if (!input_node->is_dynamic) {
      GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetSize(*output_desc, tensor_size),
                              "Failed to get size from %s",
                              input_node->NodeName().c_str());
      if (tensor_size == 0) {
        GELOGW("[%s] Tensor size == 0", input_node->NodeName().c_str());
        GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*output_desc, tensor_size),
                                "Failed to calc tensor size");
        GELOGD("[%s] Tensor size updated to %ld", input_node->NodeName().c_str(), tensor_size);
      }
    }
    input_sizes_.emplace(input_index, tensor_size);
    input_tensor_desc_.emplace(input_index, output_desc);
    is_input_dynamic_.push_back(input_node->is_dynamic);
    input_index += 1;
  }
  return SUCCESS;
}

Status HybridModelAsyncExecutor::OnComputeDone(uint32_t data_index, uint32_t result_code,
                                               std::vector<ge::OutputTensorInfo> &outputs) {
  GELOGD("OnComputeDone. model id = %u, data index = %u, execution ret = %u", model_id_, data_index, result_code);
  if (listener_ != nullptr) {
    GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_index, result_code, outputs),
                  "OnComputeDone failed");
  }
  return result_code;
}
Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &args,
                                             OutputData *output_data,
                                             std::vector<ge::OutputTensorInfo> &outputs) {
  // copy output data from op to designated position
  std::vector<ConstGeTensorDescPtr> &output_tensor_desc_list = args.output_desc;
  std::vector<TensorValue> &output_tensors = args.outputs;
  if (output_tensor_desc_list.size() != output_tensors.size()) {
    GELOGE(INTERNAL_ERROR,
           "Output sizes mismatch. From op_desc = %zu, and from output tensors = %zu",
           output_tensor_desc_list.size(),
           output_tensors.size());
    return INTERNAL_ERROR;
  }
  GELOGD("Number of outputs = %zu", output_tensor_desc_list.size());
  for (size_t i = 0; i < output_tensors.size(); ++i) {
    GELOGD("Start to process output[%zu]", i);
    auto &output_tensor = output_tensors[i];
    auto &tensor_desc = output_tensor_desc_list.at(i);
    GE_CHECK_NOTNULL(tensor_desc);
    int64_t output_size = -1;
    GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->GetShape(),
                                                           tensor_desc->GetFormat(),
                                                           tensor_desc->GetDataType(),
                                                           output_size),
                            "Failed to calc tensor size for output[%zu]. shape = [%s], type = %s, format = %s",
                            i,
                            tensor_desc->GetShape().ToString().c_str(),
                            TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(),
                            TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str());
    GELOGD("Got tensor size for output[%zu] successfully. shape = [%s], type = %s, format = %s, size = %ld",
           i,
           tensor_desc->GetShape().ToString().c_str(),
           TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(),
           TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str(),
           output_size);
    GE_CHECK_GE(output_size, 0);
    GE_CHECK_LE(output_size, UINT32_MAX);
    if (output_tensor.GetSize() < static_cast<size_t>(output_size)) {
      GELOGE(INTERNAL_ERROR,
             "output[%zu] tensor size(%zu) is not enough for output shape [%s]",
             i, output_tensor.GetSize(), tensor_desc->GetShape().ToString().c_str());
      return INTERNAL_ERROR;
    }
    ge::OutputTensorInfo output;
    output.data_type = static_cast<uint32_t>(tensor_desc->GetDataType());
    output.dims = tensor_desc->GetShape().GetDims();
    output.length = output_size;
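    // Copy the device output into a freshly allocated host buffer; the blob in
    // output_data aliases the buffer, which is owned by the OutputTensorInfo.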
    if (output_size > 0) {
      std::unique_ptr<uint8_t[]> data_buf(new(std::nothrow) uint8_t[output_size]);
      GE_CHECK_NOTNULL(data_buf);
      GE_CHK_RT_RET(rtMemcpy(data_buf.get(),
                             output_size,
                             output_tensor.GetData(),
                             output_size,
                             RT_MEMCPY_DEVICE_TO_HOST));
      output.data = std::move(data_buf);
      // data_buf is empty after the move above, so read the pointer from output.data.
      output_data->blobs.emplace_back(output.data.get(), static_cast<uint32_t>(output_size), false);
    } else {
      GELOGW("Output[%zu] is empty. shape = [%s]", i, tensor_desc->GetShape().ToString().c_str());
      output.data = nullptr;
      output_data->blobs.emplace_back(nullptr, 0U, false);
    }
    outputs.emplace_back(std::move(output));
    GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld",
           i,
           TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(),
           tensor_desc->GetShape().ToString().c_str(),
           output_size);
  }
  return SUCCESS;
}
Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs,
                                         const std::vector<GeTensorDesc> &input_desc,
                                         std::vector<DataBuffer> &outputs,
                                         std::vector<GeTensorDesc> &output_desc) {
  GELOGI("Start to execute model.");
  HybridModelExecutor::ExecuteArgs args;
  args.inputs.resize(inputs.size());
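  // Wrap the caller-provided buffers directly as input tensor values; no copy
  // is made here, so the buffers must stay valid for the whole execution.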
  for (size_t i = 0; i < inputs.size(); ++i) {
    TensorValue tensor_value(inputs[i].data, inputs[i].length);
    args.inputs[i] = tensor_value;
  }
  GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model.");
  for (const auto &output_tensor_desc : args.output_desc) {
    output_desc.emplace_back(*output_tensor_desc);
  }
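  // Validate each user-provided output buffer against the actual output size,
  // then copy the result over on the device.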
  for (size_t i = 0; i < args.outputs.size(); ++i) {
    int64_t output_real_size = 0;
    ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc[i], output_real_size);
    if (graph_status != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Get tensor size in bytes failed.");
      return FAILED;
    }
    if (output_real_size > 0) {
      if (outputs[i].length < static_cast<uint64_t>(output_real_size)) {
        GELOGE(FAILED, "output[%zu]: the buffer size given by the user (%lu) must be "
               "greater than or equal to the real output size (%ld)",
               i, outputs[i].length, output_real_size);
        return FAILED;
      }
      GE_CHK_RT_RET(rtMemcpy(outputs[i].data, outputs[i].length, args.outputs[i].GetData(), output_real_size,
                             RT_MEMCPY_DEVICE_TO_DEVICE));
    }
    outputs[i].length = output_real_size;
  }
  return SUCCESS;
}
Status HybridModelAsyncExecutor::Execute(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  GELOGD("Start to execute model.");
  // prepare inputs
  InputData input_data;
  for (auto &tensor : inputs) {
    DataBuffer buffer;
    buffer.data = const_cast<uint8_t *>(tensor.GetData().GetData());
    buffer.length = tensor.GetData().size();
    input_data.blobs.emplace_back(buffer);
    input_data.shapes.emplace_back(tensor.GetTensorDesc().GetShape().GetDims());
  }
  HybridModelExecutor::ExecuteArgs args;
  GE_CHK_STATUS_RET(PrepareInputs(input_data, args), "Failed to copy input data to model");
  GELOGD("Done copying input data successfully.");
  GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model.");
  std::vector<ge::OutputTensorInfo> output_tensor_info_list;
  OutputData output_data;
  GE_CHK_STATUS_RET(CopyOutputs(args, &output_data, output_tensor_info_list), "Failed to copy outputs.");
  GELOGD("Done copying output data successfully. output count = %zu", output_tensor_info_list.size());
  int out_index = 0;
  outputs.resize(output_tensor_info_list.size());
  for (auto &out_tensor_info : output_tensor_info_list) {
    auto &ge_tensor = outputs[out_index];
    if (out_tensor_info.length > 0) {
      GE_CHK_GRAPH_STATUS_RET(ge_tensor.SetData(out_tensor_info.data.get(), out_tensor_info.length),
                              "Failed to set output[%d].", out_index);
    }
    ge_tensor.MutableTensorDesc() = *args.output_desc[out_index];
    GELOGD("Set output[%d], tensor size = %ld, shape = [%s]",
           out_index,
           out_tensor_info.length,
           ge_tensor.MutableTensorDesc().MutableShape().ToString().c_str());
    ++out_index;
  }
  return SUCCESS;
}
}  // namespace hybrid
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware, serving as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that runs efficiently on the target hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts: GE API and GE Core.