You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

mem_rw_conflict_optimize.cc 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include <vector>
  18. #include "common/ge/ge_util.h"
  19. #include "graph/common/omg_util.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/optimize/graph_optimize.h"
  22. #include "graph/utils/graph_utils.h"
  23. #include "graph/utils/node_utils.h"
  24. namespace {
  25. using namespace ge;
  26. const int kIdentityAnchorIndex = 0;
  // Read/write type of an op input.
  // The numeric values are written out because code in this file casts these
  // enumerators to int (for summing and for table/array indexing).
  enum class InputRWType {
    kReadOnly = 0,        // Normal op input, only read
    kWriteable = 1,       // Op like Assign/ApplyMomentum
    kScopeWriteable = 2,  // Op like hcom_allreduce: it modifies its input, but the change is not expected to
                          // take effect on the previous output
    kInvalidRWType = 3
  };
  // Read/write type of an op output.
  // The numeric values are written out because code in this file casts these
  // enumerators to int (for table/array indexing).
  enum class OutputRWType {
    kReadOnly = 0,   // 1. const output  2. not a ref output, but has several peer outputs
    kSoftRead = 1,   // not a ref output, and has only one output node
    kWriteable = 2,  // ref output, like Assign/ApplyMomentum
    kInvalidRWType = 3
  };
  41. // Input and output rw_type of one node. In each map the key is the anchor index and the value is that anchor's rw_type.
  42. struct NodeInputOutputRWType {
  43. map<uint32_t, InputRWType> input_rw_type_map;
  44. map<uint32_t, OutputRWType> output_rw_type_map;
  45. };
  46. // Per-thread table of the rw_type entries recorded for nodes of the graph being processed, keyed by node name.
  47. thread_local map<string, NodeInputOutputRWType> node_rwtype_map_;
  48. ///
  49. /// @brief Convert input rw_type enum to string. For log print.
  50. /// @param rw_type
  51. /// @return rw_type_name
  52. ///
  53. static std::string InputRWTypeToSerialString(InputRWType rw_type) {
  54. const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
  55. return names[static_cast<int>(rw_type)];
  56. }
  57. ///
  58. /// @brief Convert output rw_type enum to string. For log print.
  59. /// @param rw_type
  60. /// @return rw_type_name
  61. ///
  62. static std::string OutputRWTypeToSerialString(OutputRWType rw_type) {
  63. const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
  64. return names[static_cast<int>(rw_type)];
  65. }
  66. OutputRWType GetSingleNodeOutputRWTypeByIndex(const Node &node, uint32_t index) {
  67. auto op_desc = node.GetOpDesc();
  68. if (op_desc == nullptr) {
  69. return OutputRWType::kInvalidRWType;
  70. }
  71. if (op_desc->GetType() == VARIABLE) {
  72. return OutputRWType::kWriteable;
  73. }
  74. // check if it is ref output
  75. auto input_names = op_desc->GetAllInputName();
  76. for (auto &input_name_2_idx : input_names) {
  77. if (op_desc->GetOutputNameByIndex(index) == input_name_2_idx.first) {
  78. return OutputRWType::kWriteable;
  79. }
  80. }
  81. // check if it is ref switch
  82. std::string type;
  83. if ((node.GetType() == FRAMEWORK_OP_TYPE) && AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type) &&
  84. (type == REFSWITCH)) {
  85. return OutputRWType::kWriteable;
  86. }
  87. if (op_desc->GetType() == CONSTANT || op_desc->GetType() == CONSTANTOP) {
  88. return OutputRWType::kReadOnly;
  89. }
  90. auto out_data_anchor = node.GetOutDataAnchor(index);
  91. if (out_data_anchor == nullptr) {
  92. return OutputRWType::kInvalidRWType;
  93. }
  94. if (out_data_anchor->GetPeerInDataNodesSize() > 1) {
  95. return OutputRWType::kReadOnly;
  96. } else {
  97. return OutputRWType::kSoftRead;
  98. }
  99. }
  100. ///
  101. /// @brief Get input rw_type of one node with sub graph. It will return rw_type after solve conflict scene.
  102. /// @param rw_type_set
  103. /// @return
  104. ///
  105. InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
  106. // for input rw type calc
  107. int total_rw_type = 0;
  108. for (const auto rw : rw_type_set) {
  109. total_rw_type += rw;
  110. }
  111. switch (total_rw_type) {
  112. case 0:
  113. return InputRWType::kReadOnly; // all input rw type is readonly
  114. case 2:
  115. return InputRWType::kScopeWriteable; // readonly 2 scope_writeable
  116. case 3:
  117. return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable
  118. case 5:
  119. return InputRWType::kInvalidRWType; // writeable 2 scope_writeable
  120. default:
  121. return InputRWType::kInvalidRWType;
  122. }
  123. }
  124. NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
  125. if (src_node.GetOpDesc() == nullptr) {
  126. return nullptr;
  127. }
  128. static std::atomic_long identity_num(0);
  129. auto next_num = identity_num.fetch_add(1);
  130. // 1. create new identity op desc
  131. string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num);
  132. auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY);
  133. if (identity_opdesc == nullptr) {
  134. GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str());
  135. return nullptr;
  136. }
  137. auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx);
  138. // 2. add input_desc & output_desc for new identity
  139. Status ret = identity_opdesc->AddInputDesc("x", data_desc);
  140. if (ret != SUCCESS) {
  141. GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str());
  142. return nullptr;
  143. }
  144. ret = identity_opdesc->AddOutputDesc("y", data_desc);
  145. if (ret != SUCCESS) {
  146. GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str());
  147. return nullptr;
  148. }
  149. GELOGI("Insert new Identity node %s.", identity_name.c_str());
  150. auto graph = src_node.GetOwnerComputeGraph();
  151. if (graph == nullptr) {
  152. GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str());
  153. return nullptr;
  154. }
  155. return graph->AddNode(identity_opdesc);
  156. }
  157. OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) {
  158. auto op_desc = node.GetOpDesc();
  159. if (op_desc == nullptr) {
  160. return OutputRWType::kInvalidRWType;
  161. }
  162. if (op_desc->GetType() == WHILE) {
  163. return OutputRWType::kSoftRead;
  164. }
  165. vector<string> subgraph_names = op_desc->GetSubgraphInstanceNames();
  166. if (subgraph_names.empty()) {
  167. // single node without sub graph
  168. return GetSingleNodeOutputRWTypeByIndex(node, index);
  169. } else {
  170. // node with sub graph
  171. auto output_node_vec = NodeUtils::GetSubgraphOutputNodes(node);
  172. auto output_rw_type = OutputRWType::kInvalidRWType;
  173. if (output_node_vec.size() == 1) {
  174. // find rw type from map.
  175. auto iter = node_rwtype_map_.find(output_node_vec.at(0)->GetName());
  176. if (iter == node_rwtype_map_.end()) {
  177. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  178. output_node_vec.at(0)->GetName().c_str());
  179. return OutputRWType::kInvalidRWType;
  180. }
  181. auto index_2_output_rw_type = iter->second.output_rw_type_map.find(index);
  182. if (index_2_output_rw_type == iter->second.output_rw_type_map.end()) {
  183. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  184. output_node_vec.at(0)->GetName().c_str());
  185. return OutputRWType::kInvalidRWType;
  186. }
  187. output_rw_type = index_2_output_rw_type->second;
  188. } else {
  189. output_rw_type = OutputRWType::kSoftRead;
  190. }
  191. // check peer input
  192. auto out_data_anchor = node.GetOutDataAnchor(index);
  193. if (out_data_anchor == nullptr) {
  194. return OutputRWType::kInvalidRWType;
  195. }
  196. if (out_data_anchor->GetPeerInDataNodesSize() > 1) {
  197. return OutputRWType::kReadOnly;
  198. } else {
  199. return output_rw_type;
  200. }
  201. }
  202. }
  203. InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) {
  204. auto op_desc = node.GetOpDesc();
  205. if (op_desc == nullptr) {
  206. return InputRWType::kInvalidRWType;
  207. }
  208. if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER ||
  209. op_desc->GetType() == HCOMREDUCESCATTER) {
  210. return InputRWType::kScopeWriteable;
  211. }
  212. // check if it is ref input
  213. auto output_names = op_desc->GetAllOutputName();
  214. for (auto &output_name_2_idx : output_names) {
  215. if (op_desc->GetInputNameByIndex(index) == output_name_2_idx.first) {
  216. return InputRWType::kWriteable;
  217. }
  218. }
  219. // check if it is ref switch
  220. std::string type;
  221. if ((node.GetType() == FRAMEWORK_OP_TYPE) && (AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) &&
  222. (type == REFSWITCH) && (index == 0)) {
  223. return InputRWType::kWriteable;
  224. }
  225. return InputRWType::kReadOnly;
  226. }
  227. InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
  228. auto op_desc = node.GetOpDesc();
  229. if (op_desc == nullptr) {
  230. return InputRWType::kInvalidRWType;
  231. }
  232. if (op_desc->GetType() == WHILE) {
  233. return InputRWType::kScopeWriteable;
  234. }
  235. vector<string> subgraph_names = op_desc->GetSubgraphInstanceNames();
  236. if (subgraph_names.empty()) {
  237. // single node without sub graph
  238. return GetSingleNodeInputRWTypeByIndex(node, index);
  239. } else {
  240. // node with sub graph
  241. std::set<int> node_rw_type_set;
  242. auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index);
  243. // get all input data node in subgraph
  244. std::set<int> anchor_rw_type_set;
  245. for (const auto &data_node : data_node_vec) {
  246. // Data only has 1 out data anchor. Here just take first out data anchor. And index 0 is valid.
  247. auto out_data_anchor = data_node->GetOutDataAnchor(0);
  248. if (out_data_anchor == nullptr) {
  249. continue;
  250. }
  251. auto data_op_desc = data_node->GetOpDesc();
  252. if (data_op_desc == nullptr) {
  253. continue;
  254. }
  255. // find rw type from map.
  256. auto iter = node_rwtype_map_.find(data_op_desc->GetName());
  257. if (iter == node_rwtype_map_.end()) {
  258. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  259. data_op_desc->GetName().c_str());
  260. return InputRWType::kInvalidRWType;
  261. }
  262. auto input_rw_type = iter->second.input_rw_type_map.find(out_data_anchor->GetIdx());
  263. if (input_rw_type == iter->second.input_rw_type_map.end()) {
  264. GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
  265. data_op_desc->GetName().c_str());
  266. return InputRWType::kInvalidRWType;
  267. }
  268. anchor_rw_type_set.emplace(static_cast<int>(input_rw_type->second));
  269. }
  270. return GetInputRwTypeInConflict(anchor_rw_type_set);
  271. }
  272. }
  273. Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
  274. for (const auto &node : sub_graph->GetDirectNode()) {
  275. GE_CHECK_NOTNULL(node);
  276. GE_CHECK_NOTNULL(node->GetOpDesc());
  277. std::set<int> anchor_rw_type_set;
  278. if (node->GetType() == DATA) {
  279. // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid.
  280. auto anchor_2_node_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, 0);
  281. for (const auto anchor_2_node_pair : anchor_2_node_vec) {
  282. auto input_rw_type = GetInputRWTypeByIndex(*anchor_2_node_pair.second, anchor_2_node_pair.first->GetIdx());
  283. GELOGD("Input rw type of Node %s %dth input anchor is %s", anchor_2_node_pair.second->GetName().c_str(),
  284. anchor_2_node_pair.first->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str());
  285. anchor_rw_type_set.emplace(static_cast<int>(input_rw_type));
  286. }
  287. auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set);
  288. GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(),
  289. InputRWTypeToSerialString(anchor_rw_type).c_str());
  290. map<uint32_t, InputRWType> input_rw_type_map{std::make_pair(0, anchor_rw_type)};
  291. NodeInputOutputRWType data_rw_type{input_rw_type_map};
  292. node_rwtype_map_.emplace(std::make_pair(node->GetName(), data_rw_type));
  293. }
  294. if (node->GetType() == NETOUTPUT) {
  295. // calc all output_rw_type of peer input , as output_rw_type of DATA
  296. map<uint32_t, OutputRWType> output_rw_type_map;
  297. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  298. GE_CHECK_NOTNULL(in_data_anchor);
  299. auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
  300. GE_CHECK_NOTNULL(pre_out_anchor);
  301. auto pre_node = pre_out_anchor->GetOwnerNode();
  302. GE_CHECK_NOTNULL(pre_node);
  303. auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
  304. GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(),
  305. pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str());
  306. if (pre_output_rw_type == OutputRWType::kWriteable) {
  307. // insert identity
  308. auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
  309. GE_CHECK_NOTNULL(identity_node);
  310. auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
  311. if (ret != SUCCESS) {
  312. GELOGE(ret, "Fail to insert identity");
  313. return ret;
  314. }
  315. GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
  316. pre_node->GetName().c_str(), node->GetName().c_str());
  317. }
  318. output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead));
  319. }
  320. NodeInputOutputRWType output_rw_type{{}, output_rw_type_map};
  321. node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type));
  322. }
  323. }
  324. return SUCCESS;
  325. }
  326. ///
  327. /// @brief Reverse traversal all subgraph and mark rw_type for Data/Netoutput.
  328. /// @param sub_graph_vecgs
  329. ///
  330. Status MarkRWTypeForAllSubgraph(const vector<ComputeGraphPtr> &sub_graph_vec) {
  331. for (auto iter = sub_graph_vec.rbegin(); iter != sub_graph_vec.rend(); ++iter) {
  332. auto parent_node = (*iter)->GetParentNode();
  333. if (parent_node == nullptr) {
  334. GELOGD("Current sub graph has no parent node. Ignore it.");
  335. continue;
  336. }
  337. if (parent_node->GetType() == WHILE) {
  338. continue;
  339. }
  340. auto ret = MarkRWTypeForSubgraph(*iter);
  341. if (ret != SUCCESS) {
  342. return ret;
  343. }
  344. }
  345. return SUCCESS;
  346. }
  347. ///
  348. /// @brief Check identity is near subgraph.
  349. /// Eg. As output of Data node in subgraph
  350. /// or as input of Netoutput of subgraph
  351. /// or as input of one node with subgraph
  352. /// or as output of one node with subgraph
  353. /// @param node
  354. /// @return is_near_subgraph
  355. ///
  356. bool CheckIdentityIsNearSubgraph(const Node &node) {
  357. for (const auto &in_node : node.GetInDataNodes()) {
  358. auto in_node_opdesc = in_node->GetOpDesc();
  359. if (in_node_opdesc == nullptr) {
  360. continue;
  361. }
  362. // near entrance of subgraph
  363. if (in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) {
  364. return true;
  365. }
  366. // near subgraph
  367. if (!in_node_opdesc->GetSubgraphInstanceNames().empty()) {
  368. return true;
  369. }
  370. }
  371. for (const auto &out_node : node.GetOutDataNodes()) {
  372. auto out_node_opdesc = out_node->GetOpDesc();
  373. if (out_node_opdesc == nullptr) {
  374. continue;
  375. }
  376. // near output of subgraph
  377. if (out_node->GetType() == NETOUTPUT && NodeUtils::IsSubgraphOutput(out_node)) {
  378. return true;
  379. }
  380. // near subgraph
  381. if (!out_node_opdesc->GetSubgraphInstanceNames().empty()) {
  382. return true;
  383. }
  384. }
  385. return false;
  386. }
  // Action to take for an edge, decided from the producer's output rw_type and the consumer's input rw_type.
  enum ConflictResult { DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY };
  // Lookup table indexed as [output rw_type][input rw_type].
  //   rows    : OutputRWType kReadOnly, kSoftRead, kWriteable
  //   columns : InputRWType  kReadOnly, kWriteable, kScopeWriteable
  // The invalid rw_types have no row/column here; callers must handle them before indexing.
  // Declared const so the table cannot be changed after initialization.
  const std::vector<std::vector<ConflictResult>> output_2_input_rwtype = {
      {DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY},
      {DO_NOTHING, WRONG_GRAPH, DO_NOTHING},
      {DO_NOTHING, DO_NOTHING, INSERT_IDENTITY}};
  391. ConflictResult GetConflictResultBetweenNode(const OutputRWType output_rw_type, const InputRWType input_rw_type) {
  392. if (output_rw_type == OutputRWType::kInvalidRWType || input_rw_type == InputRWType::kInvalidRWType) {
  393. return WRONG_GRAPH;
  394. }
  395. auto n = static_cast<int>(output_rw_type);
  396. auto m = static_cast<int>(input_rw_type);
  397. // no need to check index or container, because container and index is all defined.
  398. return output_2_input_rwtype[n][m];
  399. }
  400. ///
  401. /// @brief Keep identity_node which near subgraph or has multi output
  402. /// @param node
  403. /// @return
  404. ///
  405. Status RemoveNoUseIdentity(const NodePtr &node) {
  406. if (node->GetInDataNodes().empty() || node->GetOutDataNodesSize() > 1) {
  407. return SUCCESS;
  408. }
  409. if (node->GetOutDataNodesSize() == 1 && node->GetOutDataNodes().at(0)->GetType() == STREAMMERGE) {
  410. return SUCCESS;
  411. }
  412. if (CheckIdentityIsNearSubgraph(*node)) {
  413. return SUCCESS;
  414. }
  415. GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
  416. auto pre_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
  417. GE_CHECK_NOTNULL(pre_out_anchor);
  418. auto pre_node = pre_out_anchor->GetOwnerNode();
  419. auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
  420. auto anchor_2_outnode_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kIdentityAnchorIndex);
  421. ConflictResult conflict_result = WRONG_GRAPH;
  422. if (!anchor_2_outnode_vec.empty()) {
  423. auto anchor_2_outnode = anchor_2_outnode_vec.at(0);
  424. auto peer_input_rw_type = GetInputRWTypeByIndex(*anchor_2_outnode.second, anchor_2_outnode.first->GetIdx());
  425. GELOGD("Pre Node %s %dth output rw type is %s, peer node %s %dth input rw type is %s.", pre_node->GetName().c_str(),
  426. pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
  427. anchor_2_outnode.second->GetName().c_str(), anchor_2_outnode.first->GetIdx(),
  428. InputRWTypeToSerialString(peer_input_rw_type).c_str());
  429. conflict_result = GetConflictResultBetweenNode(pre_output_rw_type, peer_input_rw_type);
  430. } else {
  431. // identity node has no out data node, it can be removed
  432. conflict_result = DO_NOTHING;
  433. }
  434. if (conflict_result != DO_NOTHING) {
  435. return SUCCESS;
  436. }
  437. GELOGI("No need insert Identity. Node %s need to remove.", node->GetName().c_str());
  438. auto ret = GraphUtils::IsolateNode(node, {0});
  439. if (ret != SUCCESS) {
  440. GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
  441. return ret;
  442. }
  443. ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node);
  444. if (ret != SUCCESS) {
  445. GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
  446. return ret;
  447. }
  448. GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.",
  449. pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
  450. node->GetName().c_str());
  451. return SUCCESS;
  452. }
  453. Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &peer_in_data_anchor,
  454. const OutDataAnchorPtr &pre_out_data_anchor, NodePtr &pre_node) {
  455. // 1.check peer in node RW type.
  456. GE_CHECK_NOTNULL(peer_in_data_anchor);
  457. auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
  458. GE_CHECK_NOTNULL(peer_in_data_node);
  459. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx());
  460. auto ret = out_data_anchor->Unlink(peer_in_data_anchor);
  461. auto old_identity = out_data_anchor->GetOwnerNode();
  462. if (ret != SUCCESS) {
  463. GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(),
  464. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  465. return ret;
  466. }
  467. if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) {
  468. auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx());
  469. GE_CHECK_NOTNULL(new_identity);
  470. if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS ||
  471. GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) {
  472. GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s",
  473. pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  474. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  475. return INTERNAL_ERROR;
  476. }
  477. // 2. copy in-control-edge from dst to Identity
  478. if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) {
  479. GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(),
  480. new_identity->GetName().c_str());
  481. return INTERNAL_ERROR;
  482. }
  483. GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(),
  484. InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  485. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  486. } else {
  487. // copy control edge to pre and peer node
  488. if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS ||
  489. GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) {
  490. GELOGW("Fail to copy control edge from node %s.", old_identity->GetName().c_str());
  491. return FAILED;
  492. }
  493. // link identity pre node to next node directly
  494. if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) {
  495. GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
  496. peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  497. return FAILED;
  498. }
  499. GELOGI("Node %s input rw type is %s, link data edge from Identity input node %s to out node %s directly.",
  500. peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(),
  501. pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str());
  502. }
  503. return SUCCESS;
  504. }
  505. Status SplitIdentity(const NodePtr &node) {
  506. GE_CHECK_NOTNULL(node);
  507. auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex);
  508. GE_CHECK_NOTNULL(out_data_anchor);
  509. if (out_data_anchor->GetPeerInDataNodesSize() <= 1) {
  510. return SUCCESS;
  511. }
  512. // get pre node and next node of identity
  513. GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
  514. auto pre_out_data_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
  515. GE_CHECK_NOTNULL(pre_out_data_anchor);
  516. auto pre_node = pre_out_data_anchor->GetOwnerNode();
  517. GE_CHECK_NOTNULL(pre_node);
  518. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  519. Status ret = SplitIdentityAlongAnchor(out_data_anchor, peer_in_data_anchor, pre_out_data_anchor, pre_node);
  520. if (ret != SUCCESS) {
  521. GELOGE(ret, "Split identity node along anchor failed.");
  522. return ret;
  523. }
  524. }
  525. // 2.isolate Identity node with no data output
  526. if (node->GetOutDataNodesSize() == 0) {
  527. Status ret = GraphUtils::IsolateNode(node, {});
  528. if (ret != SUCCESS) {
  529. GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
  530. return FAILED;
  531. }
  532. ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node);
  533. if (ret != SUCCESS) {
  534. GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
  535. return FAILED;
  536. }
  537. GELOGI("IsolateAndDelete identity node %s.", node->GetName().c_str());
  538. }
  539. return SUCCESS;
  540. }
  541. Status InsertIdentityAsNeeded(const NodePtr &node) {
  542. auto op_desc = node->GetOpDesc();
  543. GE_CHECK_NOTNULL(op_desc);
  544. if (node->GetOutDataNodesSize() == 0) {
  545. return SUCCESS;
  546. }
  547. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  548. GE_CHECK_NOTNULL(out_data_anchor);
  549. auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
  550. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  551. GE_CHECK_NOTNULL(peer_in_data_anchor);
  552. auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
  553. GE_CHECK_NOTNULL(peer_in_node);
  554. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
  555. GELOGD("Node %s output rw type is %s, Node %s input rw type is %s", node->GetName().c_str(),
  556. OutputRWTypeToSerialString(output_rw_type).c_str(), peer_in_node->GetName().c_str(),
  557. InputRWTypeToSerialString(input_rw_type).c_str());
  558. auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
  559. switch (conflict_result) {
  560. case DO_NOTHING:
  561. case WRONG_GRAPH:
  562. GELOGD("No need insert Identity.");
  563. continue;
  564. case INSERT_IDENTITY:
  565. auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx());
  566. if (identity_node == nullptr) {
  567. GELOGE(FAILED, "Create identity node failed.");
  568. return FAILED;
  569. }
  570. auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node);
  571. if (ret != GRAPH_SUCCESS) {
  572. GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(),
  573. peer_in_node->GetName().c_str());
  574. return INTERNAL_ERROR;
  575. }
  576. GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(),
  577. peer_in_node->GetName().c_str());
  578. continue;
  579. }
  580. }
  581. }
  582. return SUCCESS;
  583. }
  584. } // namespace
  585. namespace ge {
  586. Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_conflict) {
  587. node_rwtype_map_.clear();
  588. auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  589. if (sub_graph_vec.empty()) {
  590. GELOGD("No sub graph here. Ignore memory conflict handle.");
  591. return SUCCESS;
  592. }
  593. // 1.loop all subgraph, mark rw type from inside to outside
  594. Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  595. if (ret != SUCCESS) {
  596. GELOGE(ret, "Fail to mark rw type for subgraph.");
  597. return ret;
  598. }
  599. has_conflict = false;
  600. for (const auto &node : compute_graph->GetAllNodes()) {
  601. auto op_desc = node->GetOpDesc();
  602. GE_CHECK_NOTNULL(op_desc);
  603. if (node->GetOutDataNodesSize() == 0) {
  604. return SUCCESS;
  605. }
  606. if (node->GetType() == WHILE) {
  607. return SUCCESS;
  608. }
  609. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  610. GE_CHECK_NOTNULL(out_data_anchor);
  611. auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
  612. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  613. GE_CHECK_NOTNULL(peer_in_data_anchor);
  614. auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
  615. GE_CHECK_NOTNULL(peer_in_node);
  616. if (peer_in_node->GetType() == WHILE) {
  617. return SUCCESS;
  618. }
  619. auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
  620. auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
  621. switch (conflict_result) {
  622. case DO_NOTHING:
  623. GELOGD("No rw conflict.");
  624. continue;
  625. case WRONG_GRAPH:
  626. has_conflict = true;
  627. GELOGI("Node %s output rw type is %s, next node %s input_rw_type is %s.It is wrong graph.",
  628. node->GetName().c_str(), OutputRWTypeToSerialString(output_rw_type).c_str(),
  629. peer_in_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str());
  630. return SUCCESS;
  631. case INSERT_IDENTITY:
  632. GELOGD("There is rw conflict. It will handle later.");
  633. continue;
  634. }
  635. }
  636. }
  637. }
  638. return SUCCESS;
  639. }
  640. Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
  641. node_rwtype_map_.clear();
  642. auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  643. if (sub_graph_vec.empty()) {
  644. GELOGD("No sub graph here. Ignore memory conflict handle.");
  645. return SUCCESS;
  646. }
  647. GE_DUMP(compute_graph, "BeforeHandleMemConflict");
  648. // 1.loop all subgraph, mark rw type from inside to outside
  649. Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  650. if (ret != SUCCESS) {
  651. GELOGE(ret, "Fail to mark rw type for subgraph.");
  652. return ret;
  653. }
  654. // 2.loop all node, including node in subgraph and handle memory rw conflict
  655. for (auto &node : compute_graph->GetAllNodes()) {
  656. // ignore while subgraph node
  657. const auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
  658. if ((parent_node != nullptr) && (kWhileOpTypes.count(parent_node->GetType()) > 0)) {
  659. continue;
  660. }
  661. // ignore data / netoutput of subgraph
  662. if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
  663. continue;
  664. }
  665. if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
  666. continue;
  667. }
  668. if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) {
  669. // split identity
  670. ret = SplitIdentity(node);
  671. if (ret != SUCCESS) {
  672. GELOGE(ret, "Fail to split identity node %s.", node->GetName().c_str());
  673. return ret;
  674. }
  675. // remove no use identity
  676. ret = RemoveNoUseIdentity(node);
  677. if (ret != SUCCESS) {
  678. GELOGE(ret, "Fail to remove useless identity node %s.", node->GetName().c_str());
  679. return ret;
  680. }
  681. }
  682. // insert Identity
  683. ret = InsertIdentityAsNeeded(node);
  684. if (ret != SUCCESS) {
  685. GELOGE(ret, "Fail to insert Identity node.");
  686. return ret;
  687. }
  688. }
  689. GE_DUMP(compute_graph, "AfterHandleMemConflict");
  690. return SUCCESS;
  691. }
  692. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示