You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_context.cc 9.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/build/run_context.h"

  #include <utility>

  #include "framework/common/util.h"
  #include "framework/common/debug/ge_log.h"
  #include "graph/debug/ge_attr_define.h"
  #include "common/omg_util.h"
  21. namespace ge {
  22. RunContextUtil::~RunContextUtil() { DestroyRtModelResources(); }
  23. Status RunContextUtil::InitMemInfo(uint8_t *data_mem_base, uint64_t data_mem_size,
  24. std::map<int64_t, uint8_t *> mem_type_to_data_mem_base,
  25. std::map<int64_t, uint64_t> mem_type_to_data_mem_size, uint8_t *weight_mem_base,
  26. uint64_t weight_mem_size) {
  27. if ((data_mem_size > 0) && (data_mem_base == nullptr)) {
  28. REPORT_INNER_ERROR("E19999", "InitMemInfo param data_mem_base is null but data_mem_size = %lu", data_mem_size);
  29. GELOGE(PARAM_INVALID, "[Check][Param] InitMemInfo param data_mem_base is null but data_mem_size = %lu.",
  30. data_mem_size);
  31. return PARAM_INVALID;
  32. }
  33. if ((weight_mem_size > 0) && (weight_mem_base == nullptr)) {
  34. REPORT_INNER_ERROR("E19999", "InitMemInfo param weight_mem_base is null but weight_mem_size = %lu",
  35. weight_mem_size);
  36. GELOGE(PARAM_INVALID, "[Check][Param] InitMemInfo param weight_mem_base is null but weight_mem_size = %lu.",
  37. weight_mem_size);
  38. return PARAM_INVALID;
  39. }
  40. if (mem_type_to_data_mem_base.empty() || mem_type_to_data_mem_size.empty() ||
  41. mem_type_to_data_mem_base.size() != mem_type_to_data_mem_size.size()) {
  42. REPORT_INNER_ERROR("E19999", "InitMemInfo param mem_type_to_data_mem_base size[%zu] "
  43. "is not equal to the size of mem_type_to_data_mem_size[%zu].",
  44. mem_type_to_data_mem_base.size(), mem_type_to_data_mem_size.size());
  45. GELOGE(PARAM_INVALID,
  46. "[Check][Param] InitMemInfo param mem_type_to_data_mem_base size[%zu] is not equal to the size of "
  47. "mem_type_to_data_mem_size[%zu].", mem_type_to_data_mem_base.size(), mem_type_to_data_mem_size.size());
  48. return PARAM_INVALID;
  49. }
  50. data_mem_base_ = data_mem_base;
  51. data_mem_size_ = data_mem_size;
  52. weight_mem_base_ = weight_mem_base;
  53. weight_mem_size_ = weight_mem_size;
  54. mem_type_to_data_mem_base_ = mem_type_to_data_mem_base;
  55. mem_type_to_data_mem_size_ = mem_type_to_data_mem_size;
  56. return SUCCESS;
  57. }
  58. Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t event_num, uint32_t label_num) {
  59. // Create rt model
  60. rtError_t rt_ret = rtModelCreate(&rt_model_, 0);
  61. if (rt_ret != RT_ERROR_NONE) {
  62. REPORT_CALL_ERROR("E19999", "call rtModelCreate failed, ret:%d,", static_cast<int>(rt_ret));
  63. GELOGE(RT_FAILED, "[Call][RtModelCreate] failed. rt_ret = %d", static_cast<int>(rt_ret));
  64. return RT_FAILED;
  65. }
  66. // Create rt Stream and bind with model
  67. for (uint32_t i = 0; i < stream_num; ++i) {
  68. rtStream_t stream = nullptr;
  69. rt_ret = rtStreamCreate(&stream, 0);
  70. if (rt_ret != RT_ERROR_NONE) {
  71. REPORT_CALL_ERROR("E19999", "call rtStreamCreate failed, ret:%d, index:%u,",
  72. static_cast<int>(rt_ret), i);
  73. GELOGE(RT_FAILED, "[Call][RtStreamCreate] failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  74. return RT_FAILED;
  75. }
  76. stream_list_.emplace_back(stream);
  77. rt_ret = rtModelBindStream(rt_model_, stream, 0);
  78. if (rt_ret != RT_ERROR_NONE) {
  79. REPORT_CALL_ERROR("E19999", "call rtModelBindStream failed, ret:%d, index:%u,",
  80. static_cast<int>(rt_ret), i);
  81. GELOGE(RT_FAILED, "[Bind][StreamAndModel] failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  82. return RT_FAILED;
  83. }
  84. }
  85. // Create rt event
  86. uint32_t create_flag = static_cast<uint32_t>((event_num > kEventReuseThreshold) ? RT_EVENT_WITH_FLAG :
  87. RT_EVENT_DEFAULT);
  88. for (uint32_t i = 0; i < event_num; ++i) {
  89. rtEvent_t event = nullptr;
  90. rt_ret = rtEventCreateWithFlag(&event, create_flag);
  91. if (rt_ret != RT_ERROR_NONE) {
  92. REPORT_CALL_ERROR("E19999", "call rtEventCreate failed, ret:%d, index:%u,",
  93. static_cast<int>(rt_ret), i);
  94. GELOGE(RT_FAILED, "[Call][RtEventCreate] failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  95. return RT_FAILED;
  96. }
  97. event_list_.emplace_back(event);
  98. }
  99. // Create rt label
  100. for (uint32_t i = 0; i < label_num; ++i) {
  101. rtLabel_t label = nullptr;
  102. rt_ret = rtLabelCreateV2(&label, rt_model_);
  103. if (rt_ret != RT_ERROR_NONE) {
  104. REPORT_CALL_ERROR("E19999", "call rtLabelCreateV2 failed, ret:%d, index:%u,",
  105. static_cast<int>(rt_ret), i);
  106. GELOGE(RT_FAILED, "[Call][RtLabelCreate] failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
  107. return RT_FAILED;
  108. }
  109. label_list_.emplace_back(label);
  110. }
  111. return SUCCESS;
  112. }
  113. void RunContextUtil::DestroyRtModelResources() noexcept {
  114. rtError_t rt_ret;
  115. for (size_t i = 0; i < stream_list_.size(); i++) {
  116. // Unbind stream to model first
  117. (void)rtModelUnbindStream(rt_model_, stream_list_[i]);
  118. rt_ret = rtStreamDestroy(stream_list_[i]);
  119. if (rt_ret != RT_ERROR_NONE) {
  120. GELOGW("Destroy stream failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  121. }
  122. }
  123. stream_list_.clear();
  124. for (size_t i = 0; i < event_list_.size(); i++) {
  125. rt_ret = rtEventDestroy(event_list_[i]);
  126. if (rt_ret != RT_ERROR_NONE) {
  127. GELOGW("Destroy event failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  128. }
  129. }
  130. event_list_.clear();
  131. for (size_t i = 0; i < label_list_.size(); ++i) {
  132. rt_ret = rtLabelDestroy(label_list_[i]);
  133. if (rt_ret != RT_ERROR_NONE) {
  134. GELOGW("Destroy label failed. rt_ret = %d, index = %zu.", static_cast<int>(rt_ret), i);
  135. }
  136. }
  137. label_list_.clear();
  138. if (rt_model_ != nullptr) {
  139. rt_ret = rtModelDestroy(rt_model_);
  140. if (rt_ret != RT_ERROR_NONE) {
  141. GELOGW("Destroy rt model failed. rt_ret = %d.", static_cast<int>(rt_ret));
  142. }
  143. rt_model_ = nullptr;
  144. }
  145. }
  146. Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &graph, Buffer &buffer,
  147. const uint64_t session_id) {
  148. GELOGD("Begin to Create RunContext, session_id = %lu", session_id);
  149. // check params
  150. if (graph == nullptr) {
  151. REPORT_INNER_ERROR("E19999", "Check param graph nullptr, session_id:%lu,", session_id);
  152. GELOGE(PARAM_INVALID, "[Check][Param] CreateRunContext param graph is null. session_id=%lu", session_id);
  153. return PARAM_INVALID;
  154. }
  155. uint32_t stream_num = 0;
  156. if (!AttrUtils::GetInt(&model, ATTR_MODEL_STREAM_NUM, stream_num)) {
  157. REPORT_INNER_ERROR("E19999", "Get Attr:%s failed from model, session_id:%lu,",
  158. ATTR_MODEL_STREAM_NUM.c_str(), session_id);
  159. GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model. session_id=%lu",
  160. ATTR_MODEL_STREAM_NUM.c_str(), session_id);
  161. return INTERNAL_ERROR;
  162. }
  163. GELOGD("Stream_num = %u", stream_num);
  164. uint32_t event_num = 0;
  165. if (!AttrUtils::GetInt(&model, ATTR_MODEL_EVENT_NUM, event_num)) {
  166. REPORT_INNER_ERROR("E19999", "Get Attr:%s failed from model, session_id:%lu,",
  167. ATTR_MODEL_EVENT_NUM.c_str(), session_id);
  168. GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model, session_id:%lu,",
  169. ATTR_MODEL_EVENT_NUM.c_str(), session_id);
  170. return INTERNAL_ERROR;
  171. }
  172. GELOGD("Event_num = %u", event_num);
  173. uint32_t label_num = 0;
  174. if (!AttrUtils::GetInt(&model, ATTR_MODEL_LABEL_NUM, label_num)) {
  175. REPORT_INNER_ERROR("E19999", "Get Attr:%s failed from model, session_id:%lu,",
  176. ATTR_MODEL_LABEL_NUM.c_str(), session_id);
  177. GELOGE(INTERNAL_ERROR, "[Get][Attr] %s failed from model, session_id:%lu,",
  178. ATTR_MODEL_LABEL_NUM.c_str(), session_id);
  179. return INTERNAL_ERROR;
  180. }
  181. GELOGD("Label_num = %u", label_num);
  182. Status ret = CreateRtModelResources(stream_num, event_num, label_num);
  183. if (ret != SUCCESS) {
  184. GELOGE(ret, "[Create][RtModelResources] failed. session_id=%lu", session_id);
  185. DestroyRtModelResources();
  186. return ret;
  187. }
  188. GELOGI("CreateRunContext: data_mem_base_ = %p, weight_mem_base_ = %p, memory_size = %lu, weight_size = %lu",
  189. data_mem_base_, weight_mem_base_, data_mem_size_, weight_mem_size_);
  190. PrintMemInfo();
  191. run_context_ = {rt_model_,
  192. nullptr,
  193. session_id,
  194. data_mem_size_,
  195. data_mem_base_,
  196. mem_type_to_data_mem_size_,
  197. mem_type_to_data_mem_base_,
  198. weight_mem_size_,
  199. weight_mem_base_,
  200. buffer,
  201. stream_list_,
  202. event_list_,
  203. label_list_};
  204. return SUCCESS;
  205. }
  206. void RunContextUtil::PrintMemInfo() {
  207. for (auto iter : mem_type_to_data_mem_base_) {
  208. GELOGD("CreateRunContext: memory type = %ld, data memory base = %p", iter.first, iter.second);
  209. }
  210. for (auto iter : mem_type_to_data_mem_size_) {
  211. GELOGD("CreateRunContext: memory type = %ld, data memory size = %lu", iter.first, iter.second);
  212. }
  213. }
  214. RunContext &RunContextUtil::GetRunContext() { return run_context_; }
  215. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：