
rts_node_task.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hybrid/node_executor/rts/rts_node_task.h"
#include "hybrid/node_executor/rts/rts_task_factory.h"

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "graph/utils/node_utils.h"
#include "common/ge/ge_util.h"
#include "common/op/ge_op_utils.h"

namespace {
constexpr uint8_t kSwitchPredIndex = 0;
constexpr uint8_t kSwitchCompIndex = 1;

const static std::map<rtCondition_t, std::function<bool(int64_t, int64_t)>> kCompHandle = {
  {RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }},
  {RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }},
  {RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }},
  {RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }},
  {RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }},
  {RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }},
};
}  // namespace

namespace ge {
namespace hybrid {
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask);

REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask);

REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask);

Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) {
  auto tensor_value = task_context.GetInput(index);
  GE_CHECK_NOTNULL(tensor_value);
  auto tensor_desc = task_context.MutableInputDesc(index);
  GE_CHECK_NOTNULL(tensor_desc);
  auto data_type = tensor_desc->GetDataType();

  switch (data_type) {
#define CASE_TYPE(DT, VT)                                             \
  case (DT): {                                                        \
    VT data_val{};                                                    \
    GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val)); \
    value = static_cast<int64_t>(data_val);                           \
    break;                                                            \
  }
    // Just accept index data type.
    CASE_TYPE(DT_INT32, int32_t)
    CASE_TYPE(DT_INT64, int64_t)
#undef CASE_TYPE
    default: {
      GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return UNSUPPORTED;
    }
  }

  return SUCCESS;
}

Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  const auto &node_state = task_context.GetNodeState();
  node_state->SetSwitchIndex(0);

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGI("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) {
  uint32_t value = 0;
  if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) {
    GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str());
    return INTERNAL_ERROR;
  }
  rtCondition_t cond = static_cast<rtCondition_t>(value);
  const auto it = kCompHandle.find(cond);
  if (it == kCompHandle.end()) {
    GELOGE(INTERNAL_ERROR, "[%s] Get Condition: %u handle failed.", node->GetName().c_str(), value);
    return INTERNAL_ERROR;
  }

  comp_func_ = it->second;
  GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value);
  return SUCCESS;
}

Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  GE_CHECK_NOTNULL(comp_func_);

  int64_t pred_value = 0;
  GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value));
  int64_t comp_value = 0;
  GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value));

  bool switch_idx = comp_func_(pred_value, comp_value);
  auto node_state = task_context.GetNodeState();
  node_state->SetSwitchIndex(static_cast<int>(switch_idx));

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.",
         task_context.GetNodeName(), pred_value, comp_value, static_cast<int>(switch_idx));
  return SUCCESS;
}

Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  int index = task_context.GetNodeState()->GetMergeIndex();
  GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index);
  if (index < 0 || index >= task_context.NumInputs()) {
    GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.",
           task_context.GetNodeName(), task_context.NumInputs(), index);
    return INTERNAL_ERROR;
  }

  const auto in_x = task_context.MutableInput(index);  // x
  GE_CHECK_NOTNULL(in_x);
  GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x));  // y

  const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT);  // value_index
  GE_CHECK_NOTNULL(out_y);
  if (out_y->GetSize() > 0) {
    GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index),
                                RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream()));
  }

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  task_context.GetNodeState()->SetMergeIndex(-1);  // Invalidate for loop.
  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  auto input_desc = task_context.MutableInputDesc(0);
  GE_CHECK_NOTNULL(input_desc);
  int64_t copy_size = 0;
  GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));
  // copy_size would not be negative since GetTensorSizeInBytes returned successfully.
  if (copy_size > 0) {
    const auto in_v = task_context.MutableInput(0);
    const auto out_v = task_context.MutableOutput(0);
    GE_CHECK_NOTNULL(in_v);
    GE_CHECK_NOTNULL(out_v);
    GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
           in_v->GetSize(), out_v->GetSize(), copy_size);
    GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
                                RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
  } else {
    GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
  }

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  const auto in_x = task_context.GetInput(0);  // x
  GE_CHECK_NOTNULL(in_x);
  GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x));  // y

  const auto &node_state = task_context.GetNodeState();
  if (kNextIterationOpTypes.count(node_state->GetType()) > 0) {
    node_state->RunLoopNext();
  } else if (kExitOpTypes.count(node_state->GetType()) > 0) {
    node_state->RunLoopExit();
  }

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}

Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}

Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());

  if (done_callback) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return UNSUPPORTED;
}
}  // namespace hybrid
}  // namespace ge
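For reference, the condition-to-comparator dispatch that StreamSwitchNodeTask uses above can be illustrated in isolation. The following is a minimal standalone sketch, not part of the file: the RtCondition enum and the hard-coded values are stand-ins for the runtime's rtCondition_t and for the task's two scalar tensor inputs. In the real task, the comparator is resolved once in Init() from the ATTR_NAME_STREAM_SWITCH_COND attribute and applied in ExecuteAsync() to set the node's switch index to 0 or 1.

// Standalone sketch of the comparator-map dispatch pattern (assumed names).
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>

// Stand-in for the runtime's rtCondition_t.
enum class RtCondition { kEqual, kNotEqual, kGreater, kGreaterOrEqual, kLess, kLessOrEqual };

// Mirrors kCompHandle: each condition maps to a two-argument predicate.
static const std::map<RtCondition, std::function<bool(int64_t, int64_t)>> kCompare = {
    {RtCondition::kEqual,          [](int64_t p, int64_t c) { return p == c; }},
    {RtCondition::kNotEqual,       [](int64_t p, int64_t c) { return p != c; }},
    {RtCondition::kGreater,        [](int64_t p, int64_t c) { return p > c; }},
    {RtCondition::kGreaterOrEqual, [](int64_t p, int64_t c) { return p >= c; }},
    {RtCondition::kLess,           [](int64_t p, int64_t c) { return p < c; }},
    {RtCondition::kLessOrEqual,    [](int64_t p, int64_t c) { return p <= c; }},
};

int main() {
  // "Init": look up the comparator for the node's condition attribute.
  const auto it = kCompare.find(RtCondition::kGreater);
  if (it == kCompare.end()) {
    return 1;  // Unknown condition; the real task reports INTERNAL_ERROR here.
  }
  const auto &comp_func = it->second;

  // "ExecuteAsync": compare the two scalar inputs and derive the switch index.
  int64_t pred_value = 7;
  int64_t comp_value = 3;
  int switch_index = static_cast<int>(comp_func(pred_value, comp_value));
  std::cout << "switch index: " << switch_index << std::endl;  // prints 1
  return 0;
}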

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, serving as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.