You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

case_label_maker.cc 6.8 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
Line numbers 1–155
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/label/case_label_maker.h"
  17. #include "framework/common/util.h"
  18. #include "framework/common/ge_inner_error_codes.h"
  19. #include "framework/common/types.h"
  20. #include "framework/common/op/ge_op_utils.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  23. namespace ge {
  24. constexpr uint32_t kCasePredIndex = 0;
  25. constexpr uint32_t kMinCaseBranch = 1;
  26. constexpr uint32_t kMaxCaseBranch = 0x7fffffff;
  27. /**
  28. * @ingroup ge
  29. * @brief Make label node to functional call.
  30. * @param [in/out] label_index: serial id for whole graph.
  31. * @return: 0 for success / others for fail
  32. */
  33. Status CaseOpLabelMaker::Run(uint32_t &label_index) {
  34. GE_CHECK_NOTNULL(parent_node_);
  35. GE_CHECK_NOTNULL(parent_graph_);
  36. OpDescPtr case_desc = parent_node_->GetOpDesc();
  37. GE_CHECK_NOTNULL(case_desc);
  38. const auto graph_names = case_desc->GetSubgraphInstanceNames();
  39. if (graph_names.empty() || graph_names.size() > kMaxCaseBranch) {
  40. REPORT_INNER_ERROR("E19999", "Node:%s(%s) subgraph size: %zu, check invalid", case_desc->GetName().c_str(),
  41. case_desc->GetType().c_str(), graph_names.size());
  42. GELOGE(INTERNAL_ERROR, "[Check][Param] Node: %s has invalid subgraph, graph size: %zu.",
  43. case_desc->GetName().c_str(), graph_names.size());
  44. return FAILED;
  45. }
  46. // One branch, no need label.
  47. const uint32_t graph_num = static_cast<uint32_t>(graph_names.size());
  48. if (graph_num == kMinCaseBranch) {
  49. GELOGI("Node: %s just one subgraph.", case_desc->GetName().c_str());
  50. return SUCCESS;
  51. }
  52. NodePtr first_label = nullptr;
  53. ComputeGraphPtr first_graph = nullptr;
  54. std::vector<uint32_t> switch_labels;
  55. uint32_t last_label_index = label_index++;
  56. for (uint32_t index = 0; index < graph_num; ++index) {
  57. ComputeGraphPtr graph = parent_graph_->GetSubgraph(graph_names[index]);
  58. GE_CHECK_NOTNULL(graph);
  59. // all branch, add label and stream active nodes to head.
  60. std::string stream_active_name =
  61. parent_node_->GetName() + "/StreamActive_" + std::to_string(index); // rtStreamActive
  62. NodePtr stream_active = AddStreamActive(graph, stream_active_name);
  63. if (stream_active == nullptr) {
  64. REPORT_CALL_ERROR("E19999", "Add StreamActive node in graph:%s fail",
  65. graph->GetName().c_str());
  66. GELOGE(INTERNAL_ERROR, "[Add][StreamActive] in Subgraph: %s failed.", graph->GetName().c_str());
  67. return FAILED;
  68. }
  69. uint32_t curr_label_index = label_index++;
  70. std::string label_set_name = parent_node_->GetName() + "/LabelSet_" + std::to_string(index); // rtLabelSet
  71. NodePtr label = AddLabelSetEnter(graph, label_set_name, curr_label_index, stream_active);
  72. if (label == nullptr) {
  73. REPORT_CALL_ERROR("E19999", "Add LabelSetEnter node in graph:%s fail",
  74. graph->GetName().c_str());
  75. GELOGE(INTERNAL_ERROR, "[Call][AddLabelSetEnter] Subgraph: %s add label set failed.", graph->GetName().c_str());
  76. return FAILED;
  77. }
  78. switch_labels.emplace_back(curr_label_index);
  79. if (index == 0) { // save first subgraph node for switch.
  80. first_label = label;
  81. first_graph = graph;
  82. }
  83. if (index + 1 < graph_num) {
  84. // middle node, add goto node to tail.
  85. std::string label_goto_name = parent_node_->GetName() + "/LabelGoto_" + std::to_string(index); // rtLabelGoto
  86. if (AddLabelGotoLeave(graph, label_goto_name, last_label_index) == nullptr) {
  87. REPORT_CALL_ERROR("E19999", "Add LabelGotoLeave node in graph:%s fail",
  88. graph->GetName().c_str());
  89. GELOGE(INTERNAL_ERROR, "[Call][AddLabelGotoLeave] Subgraph: %s add label goto failed.",
  90. graph->GetName().c_str());
  91. return FAILED;
  92. }
  93. } else {
  94. // last node, add label node to tail.
  95. std::string last_label_name = parent_node_->GetName() + "/LabelSet_Last"; // rtLabelSet
  96. if (AddLabelSetLeave(graph, last_label_name, last_label_index) == nullptr) {
  97. REPORT_CALL_ERROR("E19999", "Add LabelSetLeave node in graph:%s fail",
  98. graph->GetName().c_str());
  99. GELOGE(INTERNAL_ERROR, "[Call][AddLabelSetLeave] Subgraph: %s add label set failed.",
  100. graph->GetName().c_str());
  101. return FAILED;
  102. }
  103. }
  104. }
  105. // Add Switch node for first branch.
  106. GE_CHECK_NOTNULL(first_label);
  107. GE_CHECK_NOTNULL(first_graph);
  108. // first case, add switch node to head.
  109. const std::string label_switch_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex
  110. const GeTensorDesc &pred_desc = case_desc->GetInputDesc(kCasePredIndex);
  111. NodePtr switch_node = AddLabelSwitchEnter(first_graph, label_switch_name, pred_desc, switch_labels);
  112. if (switch_node == nullptr) {
  113. REPORT_CALL_ERROR("E19999", "Add LabelSwitchEnter node in graph:%s fail",
  114. first_graph->GetName().c_str());
  115. GELOGE(INTERNAL_ERROR, "[Call][AddLabelSwitchEnter] Subgraph: %s add label switch failed.",
  116. first_graph->GetName().c_str());
  117. return FAILED;
  118. }
  119. // Link control edge to then branch head.
  120. if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), first_label->GetInControlAnchor()) != SUCCESS) {
  121. REPORT_CALL_ERROR("E19999", "Add ctrl edge from %s to %s in graph:%s fail", switch_node->GetName().c_str(),
  122. first_label->GetName().c_str(), first_graph->GetName().c_str());
  123. GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] to %s failed.", first_label->GetName().c_str());
  124. return FAILED;
  125. }
  126. uint32_t parent_index = 0; // Case cond input is first.
  127. const std::string data_name = parent_node_->GetName() + "/SwitchIndexData";
  128. if (AddLabelSwitchIndex(first_graph, data_name, pred_desc, switch_node, parent_index) == nullptr) {
  129. REPORT_CALL_ERROR("E19999", "Add LabelSwitchIndex node in graph:%s fail",
  130. first_graph->GetName().c_str());
  131. GELOGE(INTERNAL_ERROR, "[Call][AddLabelSwitchIndex] Subgraph: %s add switch input failed.",
  132. first_graph->GetName().c_str());
  133. return FAILED;
  134. }
  135. GELOGI("Node: %s assign label success.", case_desc->GetName().c_str());
  136. return SUCCESS;
  137. }
  138. REGISTER_LABEL_MAKER(CASE, CaseOpLabelMaker);
  139. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：