You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

for_pass.cc 26 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  namespace {
  // Slots of the fixed inputs on the generated While node, in the order BuildWhileInfo appends them:
  // loop counter i, |delta|, |limit - start|, start, delta. Forwarded For data inputs begin at slot 5.
  constexpr uint32_t kWhileIInputIndex = 0;
  constexpr uint32_t kWhileAbsDeltaInputIndex = 1;
  constexpr uint32_t kWhileRangeInputIndex = 2;
  constexpr uint32_t kWhileStartInputIndex = 3;
  constexpr uint32_t kWhileDeltaInputIndex = 4;
  constexpr uint32_t kWhileDataInputIndex = 5;

  // Slots on the subgraph-call node inside the while body: input 0 receives the computed loop variable,
  // forwarded data inputs start at slot 1.
  constexpr uint32_t kSubgraphLoopVarInputIndex = 0;
  constexpr uint32_t kSubgraphInputIndex = 1;

  // First While output that is wired to the consumers recorded in data_outputs.
  constexpr uint32_t kWhileOutputIndex = 5;

  // Operator type string used for the Abs nodes created in CreateLoopInput.
  const std::string kAbs = "Abs";
  }  // namespace
  40. namespace ge {
  41. Status ForPass::Run(NodePtr &node) {
  42. if (node->GetType() != FOR) {
  43. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  44. return SUCCESS;
  45. }
  46. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  47. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  48. GE_CHECK_NOTNULL(graph);
  49. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  50. GE_CHECK_NOTNULL(root_graph);
  51. ForInfo for_info;
  52. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
  53. "Build ForInfo failed, node:%s.", node->GetName().c_str());
  54. WhileInfo while_info;
  55. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
  56. "Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());
  57. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  58. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  59. GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
  60. return FAILED;
  61. }
  62. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  63. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  64. GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
  65. return FAILED;
  66. }
  67. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
  68. "Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());
  69. // for node has and only has one subgraph
  70. GE_CHECK_NOTNULL(node->GetOpDesc());
  71. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  72. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  73. return IsolateAndDeleteNode(node, std::vector<int>());
  74. }
  75. ///
  76. /// @brief Build for_info
  77. /// @param [in] root_graph
  78. /// @param [in] node
  79. /// @param [out] for_info
  80. /// @return Status
  81. ///
  82. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  83. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  84. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  85. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  86. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  87. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  88. GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
  89. return FAILED;
  90. }
  91. std::vector<OutDataAnchorPtr> data_inputs;
  92. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  93. std::vector<OutControlAnchorPtr> ctrl_inputs;
  94. std::vector<InControlAnchorPtr> ctrl_outputs;
  95. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  96. GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
  97. return FAILED;
  98. }
  99. NodeUtils::UnlinkAll(*node);
  100. OpDescPtr op_desc = node->GetOpDesc();
  101. GE_CHECK_NOTNULL(op_desc);
  102. // For node has and only has one sub_graph
  103. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  104. if (for_body_name.empty()) {
  105. GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
  106. return FAILED;
  107. }
  108. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  109. if (for_body == nullptr) {
  110. GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
  111. return FAILED;
  112. }
  113. for_info.for_node = node;
  114. for_info.start = start;
  115. for_info.limit = limit;
  116. for_info.delta = delta;
  117. for_info.body_name = for_body_name;
  118. for_info.for_body = for_body;
  119. for_info.data_inputs = std::move(data_inputs);
  120. for_info.data_outputs = std::move(data_outputs);
  121. for_info.ctrl_inputs = std::move(ctrl_inputs);
  122. for_info.ctrl_outputs = std::move(ctrl_outputs);
  123. GELOGI("Build for_info for node %s success.", node->GetName().c_str());
  124. return SUCCESS;
  125. }
  126. ///
  127. /// @brief Find input with index for For node
  128. /// @param [in] node
  129. /// @param [in] index
  130. /// @return OutDataAnchorPtr
  131. ///
  132. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  133. if (node == nullptr) {
  134. GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
  135. return nullptr;
  136. }
  137. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  138. if (in_data_anchor == nullptr) {
  139. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  140. return nullptr;
  141. }
  142. return in_data_anchor->GetPeerOutAnchor();
  143. }
  144. ///
  145. /// @brief Find inputs / outputs for for node
  146. /// @param [in] node
  147. /// @param [out] data_inputs
  148. /// @param [out] data_outputs
  149. /// @param [out] ctrl_inputs
  150. /// @param [out] ctrl_outputs
  151. /// @return Status
  152. ///
  153. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  154. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  155. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  156. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  157. GE_CHECK_NOTNULL(node);
  158. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  159. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  160. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  161. GE_CHECK_NOTNULL(in_data_anchor);
  162. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  163. }
  164. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  165. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  166. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  167. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  168. }
  169. data_outputs.emplace_back(peer_in_data_anchors);
  170. }
  171. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  172. GE_CHECK_NOTNULL(in_ctrl_anchor);
  173. for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  174. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  175. }
  176. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  177. GE_CHECK_NOTNULL(out_ctrl_anchor);
  178. for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  179. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  180. }
  181. return SUCCESS;
  182. }
  183. ///
  184. /// @brief Transfer while_info from for_info
  185. /// @param [in] graph
  186. /// @param [in] for_info
  187. /// @param [out] while_info
  188. /// @return Status
  189. ///
  190. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  191. std::string for_name = for_info.for_node->GetName();
  192. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  193. std::string i_name = for_name + "_i";
  194. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  195. if (i_node == nullptr) {
  196. GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
  197. return FAILED;
  198. }
  199. AddRePassNode(i_node);
  200. std::string identity_name = i_name + "_Identity";
  201. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  202. // Const node has and only has one output, Identity node has and only has one input
  203. if ((identity_node == nullptr) ||
  204. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  205. GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
  206. return FAILED;
  207. }
  208. AddRePassNode(identity_node);
  209. // Identity node has and only has one output
  210. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  211. if (i_input == nullptr) {
  212. GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
  213. return FAILED;
  214. }
  215. OutDataAnchorPtr range_input = nullptr;
  216. OutDataAnchorPtr abs_delta_input = nullptr;
  217. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  218. GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
  219. return FAILED;
  220. }
  221. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  222. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  223. GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
  224. return FAILED;
  225. }
  226. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
  227. for_name.c_str(), while_info.while_node->GetName().c_str());
  228. return SUCCESS;
  229. }
  230. ///
  231. /// @brief Create const op_desc
  232. /// @param [in] name
  233. /// @param [in] value
  234. /// @return OpDescPtr
  235. ///
  236. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  237. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  238. if (const_op_desc == nullptr) {
  239. GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
  240. return nullptr;
  241. }
  242. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  243. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  244. if (const_value == nullptr) {
  245. GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
  246. return nullptr;
  247. }
  248. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  249. GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
  250. return nullptr;
  251. }
  252. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  253. GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
  254. return nullptr;
  255. }
  256. return const_op_desc;
  257. }
  258. ///
  259. /// @brief Create loop node
  260. /// @param [in] graph
  261. /// @param [in] for_info
  262. /// @param [out] range_input
  263. /// @param [out] abs_delta_input
  264. /// @return Status
  265. ///
  266. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
  267. OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  268. std::string for_name = for_info.for_node->GetName();
  269. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  270. OutDataAnchorPtr start = for_info.start;
  271. OutDataAnchorPtr limit = for_info.limit;
  272. OutDataAnchorPtr delta = for_info.delta;
  273. std::string sub_name_0 = for_name + "_Sub_0";
  274. std::string abs_name_0 = for_name + "_Abs_0";
  275. std::string abs_name_1 = for_name + "_Abs_1";
  276. // i * |delta| < |limit-start|
  277. PartialGraphBuilder graph_builder;
  278. graph_builder.SetOwnerGraph(graph)
  279. .AddExistNode(for_info.start->GetOwnerNode())
  280. .AddExistNode(for_info.limit->GetOwnerNode())
  281. .AddExistNode(for_info.delta->GetOwnerNode())
  282. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  283. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  284. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  285. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  286. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  287. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  288. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  289. graphStatus error_code = GRAPH_SUCCESS;
  290. std::string error_msg;
  291. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  292. GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  293. return FAILED;
  294. }
  295. // Add repass_nodes
  296. for (auto &node : graph_builder.GetAllNodes()) {
  297. AddRePassNode(node);
  298. }
  299. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  300. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  301. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  302. GELOGE(FAILED, "Create loop node failed: node is NULL.");
  303. return FAILED;
  304. }
  305. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  306. // abs_node has and only has one output
  307. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  308. range_input = loop_count_node->GetOutDataAnchor(0);
  309. return SUCCESS;
  310. }
  311. ///
  312. /// @brief Create op_desc
  313. /// @param [in] name
  314. /// @param [in] type
  315. /// @param [in] io_equal_flag
  316. /// @return OpDescPtr
  317. ///
  318. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  319. OpDescBuilder op_desc_builder(name, type);
  320. if (io_equal_flag) {
  321. op_desc_builder.AddInput("x")
  322. .AddOutput("y");
  323. } else {
  324. op_desc_builder.AddInput("x1")
  325. .AddInput("x2")
  326. .AddOutput("y");
  327. }
  328. return op_desc_builder.Build();
  329. }
  330. ///
  331. /// @brief Build while-info
  332. /// @param [in] for_info
  333. /// @param [in] i_input
  334. /// @param [in] range_input
  335. /// @param [in] abs_delta_input
  336. /// @param [out] while_info
  337. /// @return void
  338. ///
  339. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  340. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  341. WhileInfo &while_info) {
  342. while_info.i = i_input;
  343. while_info.abs_delta = abs_delta_input;
  344. while_info.range = range_input;
  345. while_info.start = for_info.start;
  346. while_info.delta = for_info.delta;
  347. while_info.for_body_name = for_info.body_name;
  348. while_info.for_body = for_info.for_body;
  349. while_info.data_inputs.emplace_back(while_info.i);
  350. while_info.data_inputs.emplace_back(while_info.abs_delta);
  351. while_info.data_inputs.emplace_back(while_info.range);
  352. while_info.data_inputs.emplace_back(while_info.start);
  353. while_info.data_inputs.emplace_back(while_info.delta);
  354. for (auto &item : for_info.data_inputs) {
  355. while_info.data_inputs.emplace_back(item);
  356. }
  357. for (auto &item : for_info.data_outputs) {
  358. while_info.data_outputs.emplace_back(item);
  359. }
  360. for (auto &item : for_info.ctrl_inputs) {
  361. while_info.ctrl_inputs.emplace_back(item);
  362. }
  363. for (auto &item : for_info.ctrl_outputs) {
  364. while_info.ctrl_outputs.emplace_back(item);
  365. }
  366. }
  367. ///
  368. /// @brief Insert while_node
  369. /// @param [in] graph
  370. /// @param [in] name
  371. /// @param [in&out] while_info
  372. /// @return Status
  373. ///
  374. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  375. GELOGD("Begin to create while node, name:%s.", name.c_str());
  376. size_t arg_num = while_info.data_inputs.size();
  377. OpDescBuilder op_desc_builder(name, WHILE);
  378. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  379. if (op_desc == nullptr) {
  380. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  381. return FAILED;
  382. }
  383. NodePtr while_node = graph->AddNode(op_desc);
  384. if (while_node == nullptr) {
  385. GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
  386. return FAILED;
  387. }
  388. AddRePassNode(while_node);
  389. while_info.while_node = while_node;
  390. if (BuildWhileLink(while_info) != SUCCESS) {
  391. GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
  392. return FAILED;
  393. }
  394. GELOGD("Create while node succ, name:%s.", name.c_str());
  395. return SUCCESS;
  396. }
  397. ///
  398. /// @brief Build while link-edge
  399. /// @param [in] while_info
  400. /// @return Status
  401. ///
  402. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  403. NodePtr while_node = while_info.while_node;
  404. GE_CHECK_NOTNULL(while_node);
  405. size_t input_num = while_info.data_inputs.size();
  406. for (size_t i = 0; i < input_num; i++) {
  407. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  408. GE_CHECK_NOTNULL(in_data_anchor);
  409. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  410. if (peer_out_anchor == nullptr) {
  411. continue;
  412. }
  413. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
  414. "Add data-edge %s:%d->%s:%d failed.",
  415. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  416. while_node->GetName().c_str(), i);
  417. }
  418. size_t output_num = while_info.data_outputs.size();
  419. for (size_t i = 0; i < output_num; i++) {
  420. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  421. GE_CHECK_NOTNULL(out_data_anchor);
  422. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  423. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  424. "Add data-edge %s:%d->%s:%d failed.",
  425. while_node->GetName().c_str(), i + kWhileOutputIndex,
  426. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  427. }
  428. }
  429. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  430. GE_CHECK_NOTNULL(in_ctrl_anchor);
  431. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  432. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
  433. "Add ctrl-edge %s->%s failed.",
  434. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  435. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  436. }
  437. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  438. GE_CHECK_NOTNULL(out_ctrl_anchor);
  439. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  440. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
  441. "Add ctrl-edge %s->%s failed.",
  442. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  443. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  444. }
  445. return SUCCESS;
  446. }
  447. ///
  448. /// @brief Build cond_graph for while_node
  449. /// @param [in&out] while_info
  450. /// @return ComputeGraphPtr
  451. ///
  452. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  453. std::string cond_name = while_info.for_body_name + "_Cond";
  454. CompleteGraphBuilder graph_builder(cond_name);
  455. // Add parent node
  456. graph_builder.SetParentNode(while_info.while_node);
  457. // Add Node
  458. const std::string mul_name = "Mul";
  459. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  460. const std::string less_name = "Less";
  461. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  462. // Set Input
  463. graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
  464. .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
  465. .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
  466. .SetUselessInput(kWhileStartInputIndex)
  467. .SetUselessInput(kWhileDeltaInputIndex);
  468. size_t input_num = while_info.data_inputs.size();
  469. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  470. graph_builder.SetUselessInput(i);
  471. }
  472. // Add Output
  473. graph_builder.AddOutput(less_name, 0);
  474. // Add Edges
  475. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  476. // Add Input-Mapping
  477. std::map<uint32_t, uint32_t> input_mapping;
  478. for (size_t i = 0; i < input_num; i++) {
  479. input_mapping[i] = i;
  480. }
  481. graph_builder.SetInputMapping(input_mapping);
  482. graphStatus error_code = GRAPH_SUCCESS;
  483. std::string error_msg;
  484. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  485. if (cond_graph == nullptr) {
  486. GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  487. return nullptr;
  488. }
  489. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  490. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  491. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  492. while_info.while_cond = cond_graph;
  493. return cond_graph;
  494. }
  495. ///
  496. /// @brief Build body_graph for while_node
  497. /// @param [in&out] while_info
  498. /// @return ComputeGraphPtr
  499. ///
  500. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  501. std::string body_name = while_info.for_body_name + "_Body";
  502. CompleteGraphBuilder graph_builder(body_name);
  503. // Add parent node
  504. graph_builder.SetParentNode(while_info.while_node);
  505. // Add calculation nodes
  506. std::string const_name = "Const";
  507. std::string add_name_0 = "Add_0";
  508. std::string mul_name = "Mul";
  509. std::string add_name_1 = "Add_1";
  510. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  511. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  512. .AddNode(CreateOpDesc(mul_name, MUL, false))
  513. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  514. // Add Subgraph node
  515. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  516. std::string sub_graph_node_name = while_info.for_body_name;
  517. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  518. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  519. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  520. // Set Input
  521. graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
  522. .SetUselessInput(kWhileAbsDeltaInputIndex)
  523. .SetUselessInput(kWhileRangeInputIndex)
  524. .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
  525. .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  526. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  527. graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  528. }
  529. // Add Outputs
  530. graph_builder.AddOutput(add_name_0, 0);
  531. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  532. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  533. }
  534. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  535. graph_builder.AddOutput(sub_graph_node_name, i);
  536. }
  537. // Add Edges
  538. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  539. .AddDataLink(mul_name, 0, add_name_1, 1)
  540. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  541. // Add Input-Mapping
  542. std::map<uint32_t, uint32_t> input_mapping;
  543. for (size_t i = 0; i < input_num; i++) {
  544. input_mapping[i] = i;
  545. }
  546. graph_builder.SetInputMapping(input_mapping);
  547. // Add outputMapping
  548. std::map<uint32_t, uint32_t> output_mapping;
  549. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  550. output_mapping[i] = i;
  551. }
  552. graph_builder.SetOutputMapping(output_mapping);
  553. graphStatus error_code = GRAPH_SUCCESS;
  554. std::string error_msg;
  555. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  556. if (body_graph == nullptr) {
  557. GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  558. return nullptr;
  559. }
  560. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  561. if (sub_graph_node == nullptr) {
  562. GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
  563. return nullptr;
  564. }
  565. while_info.sub_graph_node = sub_graph_node;
  566. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  567. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  568. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  569. while_info.while_body = body_graph;
  570. return body_graph;
  571. }
  572. ///
  573. /// @brief Create op_desc for subgraph node
  574. /// @param [in] name
  575. /// @param [in] input_num
  576. /// @param [in] output_num
  577. /// @return OpDescPtr
  578. ///
  579. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  580. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  581. op_desc_builder.AddDynamicInput("args", input_num)
  582. .AddDynamicOutput("output", output_num);
  583. OpDescPtr op_desc = op_desc_builder.Build();
  584. if (op_desc == nullptr) {
  585. GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
  586. return nullptr;
  587. }
  588. size_t index = op_desc->GetSubgraphInstanceNames().size();
  589. op_desc->AddSubgraphName("f");
  590. op_desc->SetSubgraphInstanceName(index, name);
  591. return op_desc;
  592. }
  593. ///
  594. /// @brief Update InputMapping for for-body-graph
  595. /// @param [in] while_info
  596. /// @return Status
  597. ///
  598. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  599. ComputeGraphPtr for_body = while_info.for_body;
  600. GE_CHECK_NOTNULL(for_body);
  601. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  602. std::map<uint32_t, uint32_t> input_mapping;
  603. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  604. for (size_t i = 0; i < input_num; i++) {
  605. if (i == FOR_START_INPUT) {
  606. input_mapping[i] = i;
  607. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  608. continue;
  609. } else {
  610. input_mapping[i] = i - 2;
  611. }
  612. }
  613. for_body->UpdateInputMapping(input_mapping);
  614. for_body->SetParentNode(while_info.sub_graph_node);
  615. for_body->SetParentGraph(while_info.while_body);
  616. return SUCCESS;
  617. }
  618. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: