You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

model_v2_executor.h 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. /**
  2. * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
  17. #define AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <vector>
#include "graph/compute_graph.h"
#include "graph/ge_error_codes.h"
#include "model_desc.h"
#include "runtime/stream.h"
#include "exe_graph/runtime/tensor.h"
  24. namespace gert {
  25. enum SubExeGraphType { kInitExeGraph, kMainExeGraph, kDeInitExeGraph, kSubExeGraphTypeEnd };
  26. static constexpr char *kSubExeGraphTypeStrs[kSubExeGraphTypeEnd] = {(char *)"Init", (char *)"Main", (char *)"DeInit"};
  27. inline const char *GetSubExeGraphTypeStr(SubExeGraphType type) {
  28. return kSubExeGraphTypeStrs[type];
  29. }
  30. class ResourceGuard {
  31. public:
  32. void *ResetExecutionData(std::unique_ptr<uint8_t[]> execution_data);
  33. void ResetAnyValue(std::unique_ptr<uint8_t[]> any_values, size_t count);
  34. void PushNode(void *node);
  35. void PushWatcher(void *watcher);
  36. void *ResetNodesArray(std::unique_ptr<uint8_t[]> nodes_array);
  37. void *ResetStartNodesArray(std::unique_ptr<uint8_t[]> start_nodes_array);
  38. void *ResetNodesIndgreeArray(std::unique_ptr<uint8_t[]> nodes_indgree_array);
  39. void *ResetNodesWaitIndgreeArray(std::unique_ptr<uint8_t[]> nodes_indgree_array);
  40. void *ResetInputsArray(std::unique_ptr<uint8_t[]> inputs_array);
  41. void *ResetOutputsArray(std::unique_ptr<uint8_t[]> outputs_array);
  42. void *ResetWatchersArray(std::unique_ptr<uint8_t[]> watchers_array);
  43. void *ResetReadyQueue(void *ready_queue);
  44. void *ResetBuffer(std::unique_ptr<uint8_t[]> buffer);
  45. void *ResetComputeNodeInfo(std::unique_ptr<uint8_t[]> compute_node_info);
  46. void *ResetKernelExtendInfo(std::unique_ptr<uint8_t[]> kernel_extend_info);
  47. void *ResetModelDesc(std::unique_ptr<uint8_t[]> model_desc);
  48. ~ResourceGuard();
  49. private:
  50. std::unique_ptr<uint8_t[]> execution_data_holder_;
  51. size_t any_values_num_;
  52. std::unique_ptr<uint8_t[]> any_values_guard_;
  53. std::vector<std::unique_ptr<void, decltype(&free)>> nodes_guarder_;
  54. std::vector<std::unique_ptr<void, decltype(&free)>> watchers_guarder_;
  55. std::unique_ptr<uint8_t[]> continuous_buffer_guarder_;
  56. std::unique_ptr<uint8_t[]> buffer_guarder_;
  57. std::unique_ptr<uint8_t[]> compute_node_info_guarder_;
  58. std::unique_ptr<uint8_t[]> kernel_extend_info_guarder_;
  59. std::unique_ptr<uint8_t[]> model_desc_guarder_;
  60. std::unique_ptr<uint8_t[]> nodes_array_guarder_;
  61. std::unique_ptr<uint8_t[]> start_nodes_array_guarder_;
  62. std::unique_ptr<uint8_t[]> nodes_indgree_array_guarder_;
  63. std::unique_ptr<uint8_t[]> nodes_wait_indgree_array_guarder_;
  64. std::unique_ptr<uint8_t[]> inputs_array_guarder_;
  65. std::unique_ptr<uint8_t[]> outputs_array_guarder_;
  66. std::unique_ptr<uint8_t[]> watchers_array_guarder_;
  67. std::unique_ptr<void, decltype(&free)> ready_queue_guarder_{nullptr, nullptr};
  68. };
  69. struct ModelExecuteArg {
  70. rtStream_t stream;
  71. };
  72. static_assert(std::is_standard_layout<ModelExecuteArg>::value, "The class ModelExecuteArg must be a POD");
  73. class ExeGraphExecutor {
  74. public:
  75. // todo unload时释放anyvalue资源
  76. ge::graphStatus Load() {
  77. return ge::GRAPH_SUCCESS;
  78. }
  79. ge::graphStatus UnLoad() {
  80. return ge::GRAPH_SUCCESS;
  81. }
  82. /**
  83. * 设置图执行的输入/输出,需要注意的是,使用者需要自己保证inputs/outputs刷新完全!!!
  84. */
  85. ge::graphStatus SpecifyInputs(void **inputs, size_t start, size_t num);
  86. ge::graphStatus SpecifyOutputs(void **outputs, size_t num);
  87. ge::graphStatus Execute();
  88. const void *GetExecutionData() const {
  89. return execution_data_;
  90. }
  91. ResourceGuard &GetResourceGuard();
  92. void *SetExecutionData(std::unique_ptr<uint8_t[]> execution_data);
  93. private:
  94. friend class ModelV2ExecutorTestHelper;
  95. void *execution_data_;
  96. ResourceGuard resource_guard_;
  97. };
/**
 * Executor for a compiled model: holds one ExeGraphExecutor per
 * SubExeGraphType (Init/Main/DeInit) and drives them over a runtime stream.
 * Non-copyable and non-movable; instances are produced only through Create()
 * or the ModelV2ExecutorBuilder friend (the constructor is private).
 */
class ModelV2Executor {
 public:
  // Builds an executor from the compiled root graph; returns nullptr-capable
  // unique_ptr on failure (exact failure behavior defined in the .cpp).
  static std::unique_ptr<ModelV2Executor> Create(const ge::ComputeGraphPtr &root_graph);
  ge::graphStatus Load();
  // Asynchronous execution on the stream carried by arg.
  ge::graphStatus Execute(const ModelExecuteArg &arg, Tensor **inputs, size_t input_num, Tensor **outputs,
                          size_t output_num);
  // Synchronous execution; presumably uses default_stream_ — confirm in .cpp.
  ge::graphStatus ExecuteSync(Tensor **inputs, size_t input_num, Tensor **outputs, size_t output_num);
  ge::graphStatus UnLoad();
  const ModelDesc &GetModelDesc() const;
  // Stores the raw pointer; ownership is not taken here (see model_desc_).
  void SetModelDesc(ModelDesc *model_desc);
  ModelV2Executor(const ModelV2Executor &) = delete;
  ModelV2Executor(ModelV2Executor &&) = delete;
  ModelV2Executor &operator=(const ModelV2Executor &) = delete;
  ModelV2Executor &operator=(ModelV2Executor &&) = delete;

 private:
  friend class ModelV2ExecutorBuilder;
  friend class ModelV2ExecutorTestHelper;
  ModelV2Executor() = default;

 private:
  // One executor per sub graph, indexed by SubExeGraphType.
  std::array<ExeGraphExecutor, kSubExeGraphTypeEnd> graphs_;
  ResourceGuard resource_guard_;
  // Raw pointer set via SetModelDesc(); lifetime managed elsewhere —
  // NOTE(review): confirm the owner against ResourceGuard::ResetModelDesc.
  ModelDesc *model_desc_ = nullptr;
  rtStream_t default_stream_ = nullptr;
};
  122. } // namespace gert
  123. #endif // AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示