You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

case_label_maker.cc 5.4 kB

Lines 1–135
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "case_label_maker.h"
  17. #include "common/util.h"
  18. #include "common/ge_inner_error_codes.h"
  19. #include "framework/common/types.h"
  20. #include "framework/common/op/ge_op_utils.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  23. namespace ge {
  24. constexpr uint32_t kCasePredIndex = 0;
  25. constexpr uint32_t kMinCaseBranch = 1;
  26. constexpr uint32_t kMaxCaseBranch = 0x7fffffff;
  27. /**
  28. * @ingroup ge
  29. * @brief Make label node to functional call.
  30. * @param [in/out] label_index: serial id for whole graph.
  31. * @return: 0 for success / others for fail
  32. */
  33. Status CaseOpLabelMaker::Run(uint32_t &label_index) {
  34. GE_CHECK_NOTNULL(parent_node_);
  35. GE_CHECK_NOTNULL(parent_graph_);
  36. OpDescPtr case_desc = parent_node_->GetOpDesc();
  37. GE_CHECK_NOTNULL(case_desc);
  38. const auto graph_names = case_desc->GetSubgraphInstanceNames();
  39. if (graph_names.empty() || graph_names.size() > kMaxCaseBranch) {
  40. GELOGE(INTERNAL_ERROR, "Node: %s has invalid subgraph, graph size: %zu.", case_desc->GetName().c_str(),
  41. graph_names.size());
  42. return FAILED;
  43. }
  44. // One branch, no need label.
  45. const uint32_t graph_num = static_cast<uint32_t>(graph_names.size());
  46. if (graph_num == kMinCaseBranch) {
  47. GELOGI("Node: %s just one subgraph.", case_desc->GetName().c_str());
  48. return SUCCESS;
  49. }
  50. NodePtr first_label = nullptr;
  51. ComputeGraphPtr first_graph = nullptr;
  52. std::vector<uint32_t> switch_labels;
  53. uint32_t last_label_index = label_index++;
  54. for (uint32_t index = 0; index < graph_num; ++index) {
  55. ComputeGraphPtr graph = parent_graph_->GetSubgraph(graph_names[index]);
  56. GE_CHECK_NOTNULL(graph);
  57. // all branch, add label and stream active nodes to head.
  58. std::string stream_active_name =
  59. parent_node_->GetName() + "/StreamActive_" + std::to_string(index); // rtStreamActive
  60. NodePtr stream_active = AddStreamActive(graph, stream_active_name);
  61. if (stream_active == nullptr) {
  62. GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active failed.", graph->GetName().c_str());
  63. return FAILED;
  64. }
  65. uint32_t curr_label_index = label_index++;
  66. std::string label_set_name = parent_node_->GetName() + "/LabelSet_" + std::to_string(index); // rtLabelSet
  67. NodePtr label = AddLabelSetEnter(graph, label_set_name, curr_label_index, stream_active);
  68. if (label == nullptr) {
  69. GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", graph->GetName().c_str());
  70. return FAILED;
  71. }
  72. switch_labels.emplace_back(curr_label_index);
  73. if (index == 0) { // save first subgraph node for switch.
  74. first_label = label;
  75. first_graph = graph;
  76. }
  77. if (index + 1 < graph_num) {
  78. // middle node, add goto node to tail.
  79. std::string label_goto_name = parent_node_->GetName() + "/LabelGoto_" + std::to_string(index); // rtLabelGoto
  80. if (AddLabelGotoLeave(graph, label_goto_name, last_label_index) == nullptr) {
  81. GELOGE(INTERNAL_ERROR, "Subgraph: %s add label goto failed.", graph->GetName().c_str());
  82. return FAILED;
  83. }
  84. } else {
  85. // last node, add label node to tail.
  86. std::string last_label_name = parent_node_->GetName() + "/LabelSet_Last"; // rtLabelSet
  87. if (AddLabelSetLeave(graph, last_label_name, last_label_index) == nullptr) {
  88. GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", graph->GetName().c_str());
  89. return FAILED;
  90. }
  91. }
  92. }
  93. // Add Switch node for first branch.
  94. GE_CHECK_NOTNULL(first_label);
  95. GE_CHECK_NOTNULL(first_graph);
  96. // first case, add switch node to head.
  97. const std::string label_switch_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex
  98. const GeTensorDesc &pred_desc = case_desc->GetInputDesc(kCasePredIndex);
  99. NodePtr switch_node = AddLabelSwitchEnter(first_graph, label_switch_name, pred_desc, switch_labels);
  100. if (switch_node == nullptr) {
  101. GELOGE(INTERNAL_ERROR, "Subgraph: %s add label switch failed.", first_graph->GetName().c_str());
  102. return FAILED;
  103. }
  104. // Link control edge to then branch head.
  105. if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), first_label->GetInControlAnchor()) != SUCCESS) {
  106. GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", first_label->GetName().c_str());
  107. return FAILED;
  108. }
  109. uint32_t parent_index = 0; // Case cond input is first.
  110. const std::string data_name = parent_node_->GetName() + "/SwitchIndexData";
  111. if (AddLabelSwitchIndex(first_graph, data_name, pred_desc, switch_node, parent_index) == nullptr) {
  112. GELOGE(INTERNAL_ERROR, "Subgraph: %s add switch input failed.", first_graph->GetName().c_str());
  113. return FAILED;
  114. }
  115. GELOGI("Node: %s assign label success.", case_desc->GetName().c_str());
  116. return SUCCESS;
  117. }
  118. REGISTER_LABEL_MAKER(CASE, CaseOpLabelMaker);
  119. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：