You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

rts_node_task.cc 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "hybrid/node_executor/rts/rts_node_task.h"
  17. #include "hybrid/node_executor/rts/rts_task_factory.h"
  18. #include "graph/debug/ge_attr_define.h"
  19. #include "graph/utils/tensor_utils.h"
  20. #include "graph/utils/type_utils.h"
  21. #include "graph/utils/node_utils.h"
  22. #include "common/ge/ge_util.h"
  23. #include "common/op/ge_op_utils.h"
  24. namespace {
  25. constexpr uint8_t kSwitchPredIndex = 0;
  26. constexpr uint8_t kSwitchCompIndex = 1;
  27. const static std::map<rtCondition_t, std::function<bool(int64_t, int64_t)>> kCompHandle = {
  28. {RT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value == comp_value; }},
  29. {RT_NOT_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value != comp_value; }},
  30. {RT_GREATER, [](int64_t pred_value, int64_t comp_value) { return pred_value > comp_value; }},
  31. {RT_GREATER_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value >= comp_value; }},
  32. {RT_LESS, [](int64_t pred_value, int64_t comp_value) { return pred_value < comp_value; }},
  33. {RT_LESS_OR_EQUAL, [](int64_t pred_value, int64_t comp_value) { return pred_value <= comp_value; }},
  34. };
  35. }
  36. namespace ge {
  37. namespace hybrid {
  // Associates each RTS op type with the node task class that handles it.
  // NOTE(review): the macro presumably registers a creator with the factory
  // declared in rts_task_factory.h; confirm there.
  //
  // Stream control ops.
  38. REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
  39. REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
  40. REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
  // Ops that all share PassThroughNodeTask, which forwards input 0 to output 0.
  41. REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
  42. REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
  43. REGISTER_RTS_TASK_CREATOR(LOOPCOND, PassThroughNodeTask);
  44. REGISTER_RTS_TASK_CREATOR(NEXTITERATION, PassThroughNodeTask);
  45. REGISTER_RTS_TASK_CREATOR(REFNEXTITERATION, PassThroughNodeTask);
  46. REGISTER_RTS_TASK_CREATOR(EXIT, PassThroughNodeTask);
  47. REGISTER_RTS_TASK_CREATOR(REFEXIT, PassThroughNodeTask);
  // Label ops; the task classes below currently return UNSUPPORTED from ExecuteAsync.
  48. REGISTER_RTS_TASK_CREATOR(LABELSET, LabelSetNodeTask);
  49. REGISTER_RTS_TASK_CREATOR(LABELGOTO, LabelGotoNodeTask);
  50. REGISTER_RTS_TASK_CREATOR(LABELGOTOEX, LabelGotoNodeTask);
  51. REGISTER_RTS_TASK_CREATOR(LABELSWITCH, LabelSwitchNodeTask);
  52. REGISTER_RTS_TASK_CREATOR(LABELSWITCHBYINDEX, LabelSwitchNodeTask);
  53. Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t index, int64_t &value) {
  54. auto tensor_value = task_context.GetInput(index);
  55. GE_CHECK_NOTNULL(tensor_value);
  56. auto tensor_desc = task_context.MutableInputDesc(index);
  57. GE_CHECK_NOTNULL(tensor_desc);
  58. auto data_type = tensor_desc->GetDataType();
  59. switch (data_type) {
  60. #define CASE_TYPE(DT, VT) \
  61. case (DT): { \
  62. VT data_val{}; \
  63. GE_CHK_STATUS_RET(tensor_value->CopyScalarValueToHost(data_val)); \
  64. value = static_cast<int64_t>(data_val); \
  65. break; \
  66. }
  67. // Just accept index data type.
  68. CASE_TYPE(DT_INT32, int32_t)
  69. CASE_TYPE(DT_INT64, int64_t)
  70. #undef CASE_TYPE
  71. default: {
  72. GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str());
  73. return UNSUPPORTED;
  74. }
  75. }
  76. return SUCCESS;
  77. }
  78. Status StreamActiveNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  79. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  80. const auto &node_state = task_context.GetNodeState();
  81. node_state->RunStreamActive();
  82. if (done_callback) {
  83. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  84. }
  85. GELOGI("[%s] Done executing successfully.", task_context.GetNodeName());
  86. return SUCCESS;
  87. }
  88. Status StreamSwitchNodeTask::Init(const HybridModel &model, const NodePtr &node) {
  89. uint32_t value = 0;
  90. if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, value)) {
  91. GELOGE(INTERNAL_ERROR, "[%s] Get %s failed.", node->GetName().c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str());
  92. return INTERNAL_ERROR;
  93. }
  94. rtCondition_t cond = static_cast<rtCondition_t>(value);
  95. const auto it = kCompHandle.find(cond);
  96. if (it == kCompHandle.end()) {
  97. GELOGE(INTERNAL_ERROR, "[%s] Get Condition: %u handle failed.", node->GetName().c_str(), value);
  98. return INTERNAL_ERROR;
  99. }
  100. comp_func_ = it->second;
  101. GELOGD("[%s] Done initialization successfully, condition is %u.", node->GetName().c_str(), value);
  102. return SUCCESS;
  103. }
  104. Status StreamSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  105. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  106. GE_CHECK_NOTNULL(comp_func_);
  107. int64_t pred_value = 0;
  108. GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchPredIndex, pred_value));
  109. int64_t comp_value = 0;
  110. GE_CHK_STATUS_RET(GetScalarIndexValue(task_context, kSwitchCompIndex, comp_value));
  111. bool switch_idx = comp_func_(pred_value, comp_value);
  112. auto node_state = task_context.GetNodeState();
  113. node_state->SetSwitchIndex(static_cast<int>(switch_idx));
  114. if (done_callback) {
  115. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  116. }
  117. GELOGI("[%s] Done executing successfully, pred value: %ld, comp value: %ld, switch index: %d.",
  118. task_context.GetNodeName(), pred_value, comp_value, static_cast<int>(switch_idx));
  119. return SUCCESS;
  120. }
  121. Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  122. int index = task_context.GetNodeState()->GetMergeIndex();
  123. GELOGD("[%s] Start to execute, merge index: %d.", task_context.GetNodeName(), index);
  124. if (index < 0 || index >= task_context.NumInputs()) {
  125. GELOGE(INTERNAL_ERROR, "[%s] Invalid merge param, inputs num: %d, merge index: %d.",
  126. task_context.GetNodeName(), task_context.NumInputs(), index);
  127. return INTERNAL_ERROR;
  128. }
  129. const auto in_x = task_context.MutableInput(index); // x
  130. GE_CHECK_NOTNULL(in_x);
  131. GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(MERGE_DATA_OUTPUT, *in_x)); // y
  132. const auto out_y = task_context.MutableOutput(MERGE_INDEX_OUTPUT); // value_index
  133. GE_CHECK_NOTNULL(out_y);
  134. if (out_y->GetSize() > 0) {
  135. GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), &index, sizeof(index),
  136. RT_MEMCPY_HOST_TO_DEVICE_EX, task_context.GetStream()));
  137. }
  138. if (done_callback) {
  139. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  140. }
  141. task_context.GetNodeState()->SetMergeIndex(-1); // Invalidate for loop.
  142. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  143. return SUCCESS;
  144. }
  145. Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  146. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  147. const auto in_x = task_context.GetInput(0); // x
  148. GE_CHECK_NOTNULL(in_x);
  149. GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(0, *in_x)); // y
  150. const auto &node_state = task_context.GetNodeState();
  151. if (kNextIterationOpTypes.count(node_state->GetType()) > 0) {
  152. node_state->RunNextIteration();
  153. }
  154. if (done_callback) {
  155. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  156. }
  157. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  158. return SUCCESS;
  159. }
  160. Status LabelSetNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  161. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  162. if (done_callback) {
  163. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  164. }
  165. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  166. return UNSUPPORTED;
  167. }
  168. Status LabelGotoNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  169. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  170. if (done_callback) {
  171. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  172. }
  173. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  174. return UNSUPPORTED;
  175. }
  176. Status LabelSwitchNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  177. GELOGD("[%s] Start to execute.", task_context.GetNodeName());
  178. if (done_callback) {
  179. GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  180. }
  181. GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  182. return UNSUPPORTED;
  183. }
  184. } // namespace hybrid
  185. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。