You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

for_pass.cc 27 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/for_pass.h"
  17. #include "common/ge/ge_util.h"
  18. #include "common/op/ge_op_utils.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "framework/common/debug/log.h"
  21. #include "framework/common/ge_inner_error_codes.h"
  22. #include "framework/common/types.h"
  23. #include "graph/debug/ge_attr_define.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/type_utils.h"
  26. #include "graph/utils/node_utils.h"
  27. #include "graph/utils/op_desc_utils.h"
  namespace {
  // Input slots of the generated While node. Slots 0..4 carry the loop
  // bookkeeping values; the For node's own data inputs start at slot 5.
  constexpr uint32_t kWhileIInputIndex = 0;         // loop counter i
  constexpr uint32_t kWhileAbsDeltaInputIndex = 1;  // |delta|
  constexpr uint32_t kWhileRangeInputIndex = 2;     // |limit - start|
  constexpr uint32_t kWhileStartInputIndex = 3;     // start
  constexpr uint32_t kWhileDeltaInputIndex = 4;     // delta
  constexpr uint32_t kWhileDataInputIndex = 5;      // first forwarded data input
  // Input slots of the PartitionedCall node that wraps the original for-body.
  constexpr uint32_t kSubgraphLoopVarInputIndex = 0;  // start + i * delta
  constexpr uint32_t kSubgraphInputIndex = 1;         // first forwarded data input
  // First While output that corresponds to an output of the original For node.
  constexpr uint32_t kWhileOutputIndex = 5;
  // Op type used for the absolute-value helper nodes.
  const std::string kAbs = "Abs";
  }  // namespace
  40. namespace ge {
  41. Status ForPass::Run(NodePtr &node) {
  42. if (node->GetType() != FOR) {
  43. GELOGD("no need for_pass for node %s.", node->GetName().c_str());
  44. return SUCCESS;
  45. }
  46. GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
  47. ComputeGraphPtr graph = node->GetOwnerComputeGraph();
  48. GE_CHECK_NOTNULL(graph);
  49. ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
  50. GE_CHECK_NOTNULL(root_graph);
  51. ForInfo for_info;
  52. GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
  53. "Build ForInfo failed, node:%s.", node->GetName().c_str());
  54. WhileInfo while_info;
  55. GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
  56. "Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());
  57. ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
  58. if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
  59. GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
  60. return FAILED;
  61. }
  62. ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
  63. if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
  64. GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
  65. return FAILED;
  66. }
  67. GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
  68. "Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());
  69. // for node has and only has one subgraph
  70. GE_CHECK_NOTNULL(node->GetOpDesc());
  71. node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
  72. GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
  73. return IsolateAndDeleteNode(node, std::vector<int>());
  74. }
  75. ///
  76. /// @brief Build for_info
  77. /// @param [in] root_graph
  78. /// @param [in] node
  79. /// @param [out] for_info
  80. /// @return Status
  81. ///
  82. Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
  83. GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
  84. OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
  85. OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
  86. OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
  87. if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
  88. GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
  89. return FAILED;
  90. }
  91. std::vector<OutDataAnchorPtr> data_inputs;
  92. std::vector<std::vector<InDataAnchorPtr>> data_outputs;
  93. std::vector<OutControlAnchorPtr> ctrl_inputs;
  94. std::vector<InControlAnchorPtr> ctrl_outputs;
  95. if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
  96. GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
  97. return FAILED;
  98. }
  99. NodeUtils::UnlinkAll(*node);
  100. OpDescPtr op_desc = node->GetOpDesc();
  101. GE_CHECK_NOTNULL(op_desc);
  102. // For node has and only has one sub_graph
  103. std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
  104. if (for_body_name.empty()) {
  105. GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
  106. return FAILED;
  107. }
  108. ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
  109. if (for_body == nullptr) {
  110. GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
  111. return FAILED;
  112. }
  113. for_info.for_node = node;
  114. for_info.start = start;
  115. for_info.limit = limit;
  116. for_info.delta = delta;
  117. for_info.body_name = for_body_name;
  118. for_info.for_body = for_body;
  119. for_info.data_inputs = std::move(data_inputs);
  120. for_info.data_outputs = std::move(data_outputs);
  121. for_info.ctrl_inputs = std::move(ctrl_inputs);
  122. for_info.ctrl_outputs = std::move(ctrl_outputs);
  123. GELOGI("Build for_info for node %s succ.", node->GetName().c_str());
  124. return SUCCESS;
  125. }
  126. ///
  127. /// @brief Find input with index for For node
  128. /// @param [in] node
  129. /// @param [in] index
  130. /// @return OutDataAnchorPtr
  131. ///
  132. OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
  133. if (node == nullptr) {
  134. GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
  135. return nullptr;
  136. }
  137. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  138. if (in_data_anchor == nullptr) {
  139. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  140. return nullptr;
  141. }
  142. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  143. if (peer_out_anchor == nullptr) {
  144. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index);
  145. return nullptr;
  146. }
  147. return peer_out_anchor;
  148. }
  149. ///
  150. /// @brief Find inputs / outputs for for node
  151. /// @param [in] node
  152. /// @param [out] data_inputs
  153. /// @param [out] data_outputs
  154. /// @param [out] ctrl_inputs
  155. /// @param [out] ctrl_outputs
  156. /// @return Status
  157. ///
  158. Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
  159. std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
  160. std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
  161. std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
  162. GE_CHECK_NOTNULL(node);
  163. uint32_t input_data_num = node->GetAllInDataAnchorsSize();
  164. for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
  165. InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
  166. if (in_data_anchor == nullptr) {
  167. GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
  168. return FAILED;
  169. }
  170. GE_IF_BOOL_EXEC(in_data_anchor->GetPeerOutAnchor() == nullptr,
  171. GELOGW("Get null input by index %d from node %s ",
  172. in_data_anchor->GetIdx(), node->GetName().c_str());
  173. continue);
  174. data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
  175. }
  176. for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  177. std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
  178. for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  179. peer_in_data_anchors.emplace_back(peer_in_data_anchor);
  180. }
  181. data_outputs.emplace_back(peer_in_data_anchors);
  182. }
  183. InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
  184. GE_CHECK_NOTNULL(in_ctrl_anchor);
  185. for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
  186. ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
  187. }
  188. OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
  189. GE_CHECK_NOTNULL(out_ctrl_anchor);
  190. for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
  191. ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
  192. }
  193. return SUCCESS;
  194. }
  195. ///
  196. /// @brief Transfer while_info from for_info
  197. /// @param [in] graph
  198. /// @param [in] for_info
  199. /// @param [out] while_info
  200. /// @return Status
  201. ///
  202. Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
  203. std::string for_name = for_info.for_node->GetName();
  204. GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
  205. std::string i_name = for_name + "_i";
  206. NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
  207. if (i_node == nullptr) {
  208. GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
  209. return FAILED;
  210. }
  211. AddRePassNode(i_node);
  212. std::string identity_name = i_name + "_Identity";
  213. NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
  214. // Const node has and only has one output, Identity node has and only has one input
  215. if ((identity_node == nullptr) ||
  216. (GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
  217. GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
  218. return FAILED;
  219. }
  220. AddRePassNode(identity_node);
  221. // Identity node has and only has one output
  222. OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
  223. if (i_input == nullptr) {
  224. GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
  225. return FAILED;
  226. }
  227. OutDataAnchorPtr range_input = nullptr;
  228. OutDataAnchorPtr abs_delta_input = nullptr;
  229. if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
  230. GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
  231. return FAILED;
  232. }
  233. BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
  234. if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
  235. GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
  236. return FAILED;
  237. }
  238. GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
  239. for_name.c_str(), while_info.while_node->GetName().c_str());
  240. return SUCCESS;
  241. }
  242. ///
  243. /// @brief Create const op_desc
  244. /// @param [in] name
  245. /// @param [in] value
  246. /// @return OpDescPtr
  247. ///
  248. OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
  249. OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
  250. if (const_op_desc == nullptr) {
  251. GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
  252. return nullptr;
  253. }
  254. GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
  255. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  256. if (const_value == nullptr) {
  257. GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
  258. return nullptr;
  259. }
  260. if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  261. GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
  262. return nullptr;
  263. }
  264. if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
  265. GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
  266. return nullptr;
  267. }
  268. return const_op_desc;
  269. }
  270. ///
  271. /// @brief Create loop node
  272. /// @param [in] graph
  273. /// @param [in] for_info
  274. /// @param [out] range_input
  275. /// @param [out] abs_delta_input
  276. /// @return Status
  277. ///
  278. Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
  279. OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
  280. std::string for_name = for_info.for_node->GetName();
  281. GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
  282. OutDataAnchorPtr start = for_info.start;
  283. OutDataAnchorPtr limit = for_info.limit;
  284. OutDataAnchorPtr delta = for_info.delta;
  285. std::string sub_name_0 = for_name + "_Sub_0";
  286. std::string abs_name_0 = for_name + "_Abs_0";
  287. std::string abs_name_1 = for_name + "_Abs_1";
  288. // i * |delta| < |limit-start|
  289. PartialGraphBuilder graph_builder;
  290. graph_builder.SetOwnerGraph(graph)
  291. .AddExistNode(for_info.start->GetOwnerNode())
  292. .AddExistNode(for_info.limit->GetOwnerNode())
  293. .AddExistNode(for_info.delta->GetOwnerNode())
  294. .AddNode(CreateOpDesc(sub_name_0, SUB, false))
  295. .AddNode(CreateOpDesc(abs_name_0, kAbs, true))
  296. .AddNode(CreateOpDesc(abs_name_1, kAbs, true))
  297. .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
  298. .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
  299. .AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
  300. .AddDataLink(sub_name_0, 0, abs_name_1, 0);
  301. graphStatus error_code = GRAPH_SUCCESS;
  302. std::string error_msg;
  303. if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
  304. GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  305. return FAILED;
  306. }
  307. // Add repass_nodes
  308. for (auto &node : graph_builder.GetAllNodes()) {
  309. AddRePassNode(node);
  310. }
  311. NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
  312. NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
  313. if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
  314. GELOGE(FAILED, "Create loop node failed: node is NULL.");
  315. return FAILED;
  316. }
  317. GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
  318. // abs_node has and only has one output
  319. abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
  320. range_input = loop_count_node->GetOutDataAnchor(0);
  321. return SUCCESS;
  322. }
  323. ///
  324. /// @brief Create op_desc
  325. /// @param [in] name
  326. /// @param [in] type
  327. /// @param [in] io_equal_flag
  328. /// @return OpDescPtr
  329. ///
  330. OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
  331. OpDescBuilder op_desc_builder(name, type);
  332. if (io_equal_flag) {
  333. op_desc_builder.AddInput("x")
  334. .AddOutput("y");
  335. } else {
  336. op_desc_builder.AddInput("x1")
  337. .AddInput("x2")
  338. .AddOutput("y");
  339. }
  340. return op_desc_builder.Build();
  341. }
  342. ///
  343. /// @brief Build while-info
  344. /// @param [in] for_info
  345. /// @param [in] i_input
  346. /// @param [in] range_input
  347. /// @param [in] abs_delta_input
  348. /// @param [out] while_info
  349. /// @return void
  350. ///
  351. void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
  352. const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
  353. WhileInfo &while_info) {
  354. while_info.i = i_input;
  355. while_info.abs_delta = abs_delta_input;
  356. while_info.range = range_input;
  357. while_info.start = for_info.start;
  358. while_info.delta = for_info.delta;
  359. while_info.for_body_name = for_info.body_name;
  360. while_info.for_body = for_info.for_body;
  361. while_info.data_inputs.emplace_back(while_info.i);
  362. while_info.data_inputs.emplace_back(while_info.abs_delta);
  363. while_info.data_inputs.emplace_back(while_info.range);
  364. while_info.data_inputs.emplace_back(while_info.start);
  365. while_info.data_inputs.emplace_back(while_info.delta);
  366. for (auto &item : for_info.data_inputs) {
  367. while_info.data_inputs.emplace_back(item);
  368. }
  369. for (auto &item : for_info.data_outputs) {
  370. while_info.data_outputs.emplace_back(item);
  371. }
  372. for (auto &item : for_info.ctrl_inputs) {
  373. while_info.ctrl_inputs.emplace_back(item);
  374. }
  375. for (auto &item : for_info.ctrl_outputs) {
  376. while_info.ctrl_outputs.emplace_back(item);
  377. }
  378. }
  379. ///
  380. /// @brief Insert while_node
  381. /// @param [in] graph
  382. /// @param [in] name
  383. /// @param [in&out] while_info
  384. /// @return Status
  385. ///
  386. Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
  387. GELOGD("Begin to create while node, name:%s.", name.c_str());
  388. size_t arg_num = while_info.data_inputs.size();
  389. OpDescBuilder op_desc_builder(name, WHILE);
  390. OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
  391. if (op_desc == nullptr) {
  392. GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
  393. return FAILED;
  394. }
  395. NodePtr while_node = graph->AddNode(op_desc);
  396. if (while_node == nullptr) {
  397. GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
  398. return FAILED;
  399. }
  400. AddRePassNode(while_node);
  401. while_info.while_node = while_node;
  402. if (BuildWhileLink(while_info) != SUCCESS) {
  403. GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
  404. return FAILED;
  405. }
  406. GELOGD("Create while node succ, name:%s.", name.c_str());
  407. return SUCCESS;
  408. }
  409. ///
  410. /// @brief Build while link-edge
  411. /// @param [in] while_info
  412. /// @return Status
  413. ///
  414. Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
  415. NodePtr while_node = while_info.while_node;
  416. GE_CHECK_NOTNULL(while_node);
  417. size_t input_num = while_info.data_inputs.size();
  418. for (size_t i = 0; i < input_num; i++) {
  419. InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
  420. GE_CHECK_NOTNULL(in_data_anchor);
  421. OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
  422. if (peer_out_anchor == nullptr) {
  423. continue;
  424. }
  425. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
  426. "Add data-edge %s:%d->%s:%d failed.",
  427. peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
  428. while_node->GetName().c_str(), i);
  429. }
  430. size_t output_num = while_info.data_outputs.size();
  431. for (size_t i = 0; i < output_num; i++) {
  432. OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
  433. GE_CHECK_NOTNULL(out_data_anchor);
  434. for (auto &peer_in_anchor : while_info.data_outputs[i]) {
  435. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
  436. "Add data-edge %s:%d->%s:%d failed.",
  437. while_node->GetName().c_str(), i + kWhileOutputIndex,
  438. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  439. }
  440. }
  441. InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
  442. GE_CHECK_NOTNULL(in_ctrl_anchor);
  443. for (auto &peer_out_anchor : while_info.ctrl_inputs) {
  444. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
  445. "Add ctrl-edge %s->%s failed.",
  446. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  447. in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
  448. }
  449. OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
  450. GE_CHECK_NOTNULL(out_ctrl_anchor);
  451. for (auto &peer_in_anchor : while_info.ctrl_outputs) {
  452. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
  453. "Add ctrl-edge %s->%s failed.",
  454. out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  455. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  456. }
  457. return SUCCESS;
  458. }
  459. ///
  460. /// @brief Build cond_graph for while_node
  461. /// @param [in&out] while_info
  462. /// @return ComputeGraphPtr
  463. ///
  464. ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
  465. std::string cond_name = while_info.for_body_name + "_Cond";
  466. CompleteGraphBuilder graph_builder(cond_name);
  467. // Add parent node
  468. graph_builder.SetParentNode(while_info.while_node);
  469. // Add Node
  470. const std::string mul_name = "Mul";
  471. graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
  472. const std::string less_name = "Less";
  473. graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
  474. // Set Input
  475. graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
  476. .SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
  477. .SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
  478. .SetUselessInput(kWhileStartInputIndex)
  479. .SetUselessInput(kWhileDeltaInputIndex);
  480. size_t input_num = while_info.data_inputs.size();
  481. for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
  482. graph_builder.SetUselessInput(i);
  483. }
  484. // Add Output
  485. graph_builder.AddOutput(less_name, 0);
  486. // Add Edges
  487. graph_builder.AddDataLink(mul_name, 0, less_name, 0);
  488. // Add Input-Mapping
  489. std::map<uint32_t, uint32_t> input_mapping;
  490. for (size_t i = 0; i < input_num; i++) {
  491. input_mapping[i] = i;
  492. }
  493. graph_builder.SetInputMapping(input_mapping);
  494. graphStatus error_code = GRAPH_SUCCESS;
  495. std::string error_msg;
  496. ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
  497. if (cond_graph == nullptr) {
  498. GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  499. return nullptr;
  500. }
  501. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  502. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
  503. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
  504. while_info.while_cond = cond_graph;
  505. return cond_graph;
  506. }
  507. ///
  508. /// @brief Build body_graph for while_node
  509. /// @param [in&out] while_info
  510. /// @return ComputeGraphPtr
  511. ///
  512. ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
  513. std::string body_name = while_info.for_body_name + "_Body";
  514. CompleteGraphBuilder graph_builder(body_name);
  515. // Add parent node
  516. graph_builder.SetParentNode(while_info.while_node);
  517. // Add calculation nodes
  518. std::string const_name = "Const";
  519. std::string add_name_0 = "Add_0";
  520. std::string mul_name = "Mul";
  521. std::string add_name_1 = "Add_1";
  522. graph_builder.AddNode(CreateConstDesc(const_name, 1))
  523. .AddNode(CreateOpDesc(add_name_0, ADD, false))
  524. .AddNode(CreateOpDesc(mul_name, MUL, false))
  525. .AddNode(CreateOpDesc(add_name_1, ADD, false));
  526. // Add Subgraph node
  527. auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
  528. std::string sub_graph_node_name = while_info.for_body_name;
  529. uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
  530. auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
  531. graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
  532. // Set Input
  533. graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
  534. .SetUselessInput(kWhileAbsDeltaInputIndex)
  535. .SetUselessInput(kWhileRangeInputIndex)
  536. .SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
  537. .SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
  538. for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
  539. graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
  540. }
  541. // Add Outputs
  542. graph_builder.AddOutput(add_name_0, 0);
  543. for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
  544. graph_builder.AddOutput("Data_" + std::to_string(i), 0);
  545. }
  546. for (uint32_t i = 0; i < sub_graph_output_num; i++) {
  547. graph_builder.AddOutput(sub_graph_node_name, i);
  548. }
  549. // Add Edges
  550. graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
  551. .AddDataLink(mul_name, 0, add_name_1, 1)
  552. .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
  553. // Add Input-Mapping
  554. std::map<uint32_t, uint32_t> input_mapping;
  555. for (size_t i = 0; i < input_num; i++) {
  556. input_mapping[i] = i;
  557. }
  558. graph_builder.SetInputMapping(input_mapping);
  559. // Add outputMapping
  560. std::map<uint32_t, uint32_t> output_mapping;
  561. for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
  562. output_mapping[i] = i;
  563. }
  564. graph_builder.SetOutputMapping(output_mapping);
  565. graphStatus error_code = GRAPH_SUCCESS;
  566. std::string error_msg;
  567. ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
  568. if (body_graph == nullptr) {
  569. GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
  570. return nullptr;
  571. }
  572. NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
  573. if (sub_graph_node == nullptr) {
  574. GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
  575. return nullptr;
  576. }
  577. while_info.sub_graph_node = sub_graph_node;
  578. size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  579. while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
  580. while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
  581. while_info.while_body = body_graph;
  582. return body_graph;
  583. }
  584. ///
  585. /// @brief Create op_desc for subgraph node
  586. /// @param [in] name
  587. /// @param [in] input_num
  588. /// @param [in] output_num
  589. /// @return OpDescPtr
  590. ///
  591. OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
  592. OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
  593. op_desc_builder.AddDynamicInput("args", input_num)
  594. .AddDynamicOutput("output", output_num);
  595. OpDescPtr op_desc = op_desc_builder.Build();
  596. if (op_desc == nullptr) {
  597. GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
  598. return nullptr;
  599. }
  600. size_t index = op_desc->GetSubgraphInstanceNames().size();
  601. op_desc->AddSubgraphName("f");
  602. op_desc->SetSubgraphInstanceName(index, name);
  603. return op_desc;
  604. }
  605. ///
  606. /// @brief Update InputMapping for for-body-graph
  607. /// @param [in] while_info
  608. /// @return Status
  609. ///
  610. Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
  611. ComputeGraphPtr for_body = while_info.for_body;
  612. GE_CHECK_NOTNULL(for_body);
  613. // index_of_cur_graph_node_input -> index_of_new_graph_node_input
  614. std::map<uint32_t, uint32_t> input_mapping;
  615. size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
  616. for (size_t i = 0; i < input_num; i++) {
  617. if (i == FOR_START_INPUT) {
  618. input_mapping[i] = i;
  619. } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
  620. continue;
  621. } else {
  622. input_mapping[i] = i - 2;
  623. }
  624. }
  625. for_body->UpdateInputMapping(input_mapping);
  626. for_body->SetParentNode(while_info.sub_graph_node);
  627. for_body->SetParentGraph(while_info.while_body);
  628. return SUCCESS;
  629. }
  630. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：