You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

label_maker.cc 16 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/label/label_maker.h"
  17. #include "framework/common/util.h"
  18. #include "framework/common/ge_inner_error_codes.h"
  19. #include "framework/common/types.h"
  20. #include "framework/common/op/ge_op_utils.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/graph_utils.h"
  23. namespace ge {
  24. /**
  25. * @ingroup ge
  26. * @brief Link node to graph head.
  27. * @param [in] graph: graph for add node.
  28. * @param [in] node: Node add to graph head.
  29. * @return: void
  30. */
  31. void LabelMaker::LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node) {
  32. static const std::set<std::string> non_calc_types = { DATA, CONSTANT, CONSTANTOP, VARIABLE };
  33. for (auto &n : graph->GetDirectNode()) {
  34. if (non_calc_types.count(n->GetType()) > 0) {
  35. continue;
  36. }
  37. const auto nodes = n->GetInDataNodes();
  38. if (nodes.empty()) {
  39. continue;
  40. }
  41. bool is_head_node = true;
  42. for (auto &in_node : nodes) {
  43. if (non_calc_types.count(in_node->GetType()) == 0) {
  44. is_head_node = false;
  45. break;
  46. }
  47. }
  48. if (!is_head_node) {
  49. continue;
  50. }
  51. if (GraphUtils::AddEdge(node->GetOutControlAnchor(), n->GetInControlAnchor()) != SUCCESS) {
  52. REPORT_CALL_ERROR("E19999", "Add ctrl edge from %s to %s in graph:%s fail", node->GetName().c_str(),
  53. n->GetName().c_str(), graph->GetName().c_str());
  54. GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] from %s to %s failed.", node->GetName().c_str(), n->GetName().c_str());
  55. }
  56. }
  57. }
  58. /**
  59. * @ingroup ge
  60. * @brief Link node to graph tail.
  61. * @param [in] graph: graph for add node.
  62. * @param [in] node: Node add to graph tail.
  63. * @return: void
  64. */
  65. void LabelMaker::LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node) {
  66. auto tail = graph->FindFirstNodeMatchType(NETOUTPUT);
  67. while (tail != nullptr) {
  68. auto nodes = tail->GetOutControlNodes();
  69. if (!nodes.empty()) {
  70. tail = nodes.at(0);
  71. continue;
  72. }
  73. if (GraphUtils::AddEdge(tail->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) {
  74. REPORT_CALL_ERROR("E19999", "Add ctrl edge from %s to %s in graph:%s fail", tail->GetName().c_str(),
  75. node->GetName().c_str(), graph->GetName().c_str());
  76. GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] from %s to %s failed.", tail->GetName().c_str(), node->GetName().c_str());
  77. }
  78. return;
  79. }
  80. }
  81. /**
  82. * @ingroup ge
  83. * @brief Add StreamActive node at graph front.
  84. * @param [in] graph: graph for add node.
  85. * @param [in] name: stream active node name.
  86. * @return: NodePtr for success / nullptr for fail
  87. */
  88. NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::string &name) {
  89. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  90. const auto &node_list = graph->GetDirectNode();
  91. if (node_list.empty()) {
  92. REPORT_INNER_ERROR("E19999", "Check param graph:%s has no node", graph->GetName().c_str());
  93. GELOGE(INTERNAL_ERROR, "[Check][Param] LabelSet: Graph %s node is empty.", graph->GetName().c_str());
  94. return nullptr;
  95. }
  96. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE);
  97. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  98. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  99. GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str());
  100. vector<uint32_t> active_streams;
  101. (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName());
  102. (void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams);
  103. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, true);
  104. NodePtr stream_active = graph->AddNodeFront(op_desc);
  105. GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr);
  106. LinkToGraphHead(graph, stream_active);
  107. return stream_active;
  108. }
  109. /**
  110. * @ingroup ge
  111. * @brief Add LabelSet node at graph front.
  112. * @param [in] graph: graph for add node.
  113. * @param [in] name: label set node name.
  114. * @param [in] index: label id for set.
  115. * @return: NodePtr for success / nullptr for fail
  116. */
  117. NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index,
  118. NodePtr &stream_active) {
  119. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  120. GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr);
  121. const auto &node_list = graph->GetDirectNode();
  122. if (node_list.empty()) {
  123. REPORT_INNER_ERROR("E19999", "Check param graph:%s has no node", graph->GetName().c_str());
  124. GELOGE(INTERNAL_ERROR, "[Check][Param] LabelSet: Graph %s node is empty.", graph->GetName().c_str());
  125. return nullptr;
  126. }
  127. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET);
  128. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  129. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  130. GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str());
  131. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  132. NodePtr label_set = graph->AddNodeFront(op_desc);
  133. GE_CHECK_NOTNULL_EXEC(label_set, return nullptr);
  134. if (GraphUtils::AddEdge(label_set->GetOutControlAnchor(), stream_active->GetInControlAnchor()) != SUCCESS) {
  135. REPORT_CALL_ERROR("E19999", "Add ctrl edge from %s to %s in graph:%s fail", label_set->GetName().c_str(),
  136. stream_active->GetName().c_str(), graph->GetName().c_str());
  137. GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] from %s to %s failed.", label_set->GetName().c_str(),
  138. stream_active->GetName().c_str());
  139. return nullptr;
  140. }
  141. return label_set;
  142. }
  143. /**
  144. * @ingroup ge
  145. * @brief Add LabelSet node at graph back.
  146. * @param [in] graph: graph for add node.
  147. * @param [in] name: label set node name.
  148. * @param [in] index: label id for set.
  149. * @return: NodePtr for success / nullptr for fail
  150. */
  151. NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  152. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  153. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET);
  154. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  155. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  156. GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str());
  157. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  158. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true);
  159. NodePtr label_set = graph->AddNode(op_desc);
  160. GE_CHECK_NOTNULL_EXEC(label_set, return nullptr);
  161. // Link control edge to graph tail.
  162. LinkToGraphTail(graph, label_set);
  163. return label_set;
  164. }
  165. /**
  166. * @ingroup ge
  167. * @brief Add LabelGoto node at graph front.
  168. * @param [in] graph: graph for add node.
  169. * @param [in] name: label goto node name.
  170. * @param [in] index: label id for goto.
  171. * @return: NodePtr for success / nullptr for fail
  172. */
  173. NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  174. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  175. const auto &node_list = graph->GetDirectNode();
  176. auto it = node_list.begin();
  177. if (it == node_list.end()) {
  178. REPORT_INNER_ERROR("E19999", "Check param graph:%s has no node", graph->GetName().c_str());
  179. GELOGE(INTERNAL_ERROR, "[Check][Param] LabelGoto: Graph %s node is empty.", graph->GetName().c_str());
  180. return nullptr;
  181. }
  182. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX);
  183. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  184. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  185. GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str());
  186. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  187. NodePtr label_goto = graph->AddNodeFront(op_desc);
  188. if (label_goto == nullptr) {
  189. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s fail",
  190. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  191. GELOGE(INTERNAL_ERROR, "[Add][Node] to graph %s failed.", graph->GetName().c_str());
  192. return nullptr;
  193. }
  194. return label_goto;
  195. }
  196. /**
  197. * @ingroup ge
  198. * @brief Add LabelGoto node at graph back.
  199. * @param [in] graph: graph for add node.
  200. * @param [in] name: label goto node name.
  201. * @param [in] index: label id for goto.
  202. * @return: NodePtr for success / nullptr for fail
  203. */
  204. NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) {
  205. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  206. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX);
  207. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  208. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  209. GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str());
  210. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index);
  211. NodePtr label_goto = graph->AddNode(op_desc);
  212. GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr);
  213. // Link control edge to graph tail.
  214. LinkToGraphTail(graph, label_goto);
  215. return label_goto;
  216. }
  217. /**
  218. * @ingroup ge
  219. * @brief Add LabelSwitch node at graph front.
  220. * @param [in] graph: graph for add node.
  221. * @param [in] name: label switch node name.
  222. * @param [in] desc: label index data desc.
  223. * @param [in] labels: label id for switch.
  224. * @return: NodePtr for success / nullptr for fail
  225. */
  226. NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  227. const std::vector<uint32_t> &labels) {
  228. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  229. const auto &node_list = graph->GetDirectNode();
  230. auto it = node_list.begin();
  231. if (it == node_list.end()) {
  232. REPORT_INNER_ERROR("E19999", "Check param graph:%s has no node", graph->GetName().c_str());
  233. GELOGE(INTERNAL_ERROR, "[Check][Param] LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str());
  234. return nullptr;
  235. }
  236. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX);
  237. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  238. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  239. GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str());
  240. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  241. REPORT_CALL_ERROR("E19999", "Add input desc into node:%s(%s) in graph:%s fail",
  242. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  243. GELOGE(INTERNAL_ERROR, "[Add][InputDesc] failed.");
  244. return nullptr;
  245. }
  246. if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, labels)) {
  247. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_LABEL_SWITCH_LIST.c_str(),
  248. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  249. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s failed.", ATTR_NAME_LABEL_SWITCH_INDEX.c_str());
  250. return nullptr;
  251. }
  252. NodePtr label_switch = graph->AddNodeFront(op_desc);
  253. if (label_switch == nullptr) {
  254. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s ahead fail",
  255. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  256. GELOGE(INTERNAL_ERROR, "[Add][Node] to graph %s failed.", graph->GetName().c_str());
  257. return nullptr;
  258. }
  259. return label_switch;
  260. }
  261. /**
  262. * @ingroup ge
  263. * @brief Add LabelSwitch node at graph back.
  264. * @param [in] graph: graph for add node.
  265. * @param [in] name: label switch node name.
  266. * @param [in] desc: label index data desc.
  267. * @param [in] labels: label id for switch.
  268. * @return: NodePtr for success / nullptr for fail
  269. */
  270. NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  271. const std::vector<uint32_t> &labels) {
  272. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  273. OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX);
  274. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  275. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true);
  276. GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str());
  277. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  278. REPORT_CALL_ERROR("E19999", "Add input desc into node:%s(%s) in graph:%s fail",
  279. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  280. GELOGE(INTERNAL_ERROR, "[Add][InputDesc] into node:%s(%s) in graph:%s fail",
  281. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  282. return nullptr;
  283. }
  284. if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, labels)) {
  285. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_LABEL_SWITCH_LIST.c_str(),
  286. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  287. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s failed.", ATTR_NAME_LABEL_SWITCH_INDEX.c_str());
  288. return nullptr;
  289. }
  290. NodePtr label_switch = graph->AddNode(op_desc);
  291. GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr);
  292. // Link control edge to graph tail.
  293. LinkToGraphTail(graph, label_switch);
  294. return label_switch;
  295. }
  296. /**
  297. * @ingroup ge
  298. * @brief Add Data node at graph front for switch input.
  299. * @param [in] graph: graph for add node.
  300. * @param [in] name: label switch node name.
  301. * @param [in] desc: label index data desc.
  302. * @param [in] sw_node: switch node for add input.
  303. * @param [in] parent_index: index for parent node.
  304. * @return: NodePtr for success / nullptr for fail
  305. */
  306. NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc,
  307. const NodePtr &sw_node, uint32_t parent_index) {
  308. GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  309. OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA);
  310. GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr);
  311. GELOGI("Data: Create node %s.", op_desc->GetName().c_str());
  312. if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) {
  313. REPORT_CALL_ERROR("E19999", "Add input desc into node:%s(%s) in graph:%s fail",
  314. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  315. GELOGE(INTERNAL_ERROR, "[Add][InputDesc] into node:%s(%s) in graph:%s fail",
  316. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  317. return nullptr;
  318. }
  319. if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) {
  320. REPORT_CALL_ERROR("E19999", "Add output desc into node:%s(%s) in graph:%s fail",
  321. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  322. GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] into node:%s(%s) in graph:%s fail",
  323. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  324. return nullptr;
  325. }
  326. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
  327. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
  328. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  329. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
  330. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  331. return nullptr;
  332. }
  333. NodePtr op_data = graph->AddNodeFront(op_desc);
  334. GE_CHECK_NOTNULL_EXEC(op_data, return nullptr);
  335. GE_CHECK_NOTNULL_EXEC(graph->AddInputNode(op_data), return nullptr); // take as input node for memory assign.
  336. // Link control edge to graph head.
  337. if (GraphUtils::AddEdge(op_data->GetOutDataAnchor(0), sw_node->GetInDataAnchor(0)) != SUCCESS) {
  338. REPORT_CALL_ERROR("E19999", "Add ctrl edge from %s to %s in graph:%s fail", op_data->GetName().c_str(),
  339. sw_node->GetName().c_str(), graph->GetName().c_str());
  340. GELOGE(INTERNAL_ERROR, "[Add][CtrlEdge] to %s failed.", op_data->GetName().c_str());
  341. return nullptr;
  342. }
  343. return op_data;
  344. }
  345. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：