You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

transop_without_reshape_fusion_pass.cc 46 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/transop_without_reshape_fusion_pass.h"
  17. #include <algorithm>
  18. #include <memory>
  19. #include <sstream>
  20. #include <string>
  21. #include <atomic>
  22. #include "common/ge/ge_util.h"
  23. #include "common/ge_inner_error_codes.h"
  24. #include "common/types.h"
  25. #include "graph/common/transop_util.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_tensor.h"
  29. #include "graph/op_desc.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/node_utils.h"
  32. #include "graph/utils/op_desc_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  namespace {
  // Extended-attribute key that SetRemainNode attaches (value `true`) to trans ops of a
  // chain that is not fused.
  constexpr const char *kRemainNode = "node_remain";
  // Sentinel stored in a fusion-op counter when a created op fails OpAccuracyAbilityCheck
  // (or cannot be created), i.e. the fusion cannot be carried out.
  constexpr int kInvalidFusionOpCount = -1;
  // String attributes holding the serialized source / destination format of a generated
  // TransData op.
  constexpr const char *kAttrNameSrcFormat = "src_format";
  constexpr const char *kAttrNameDstFormat = "dst_format";
  }  // namespace
  41. namespace ge {
  42. void TransOpWithoutReshapeFusionPass::SetRemainNode(
  43. const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor) {
  44. auto iter = nodes_anchor.begin();
  45. while (iter != nodes_anchor.end()) {
  46. auto in_anchor = iter->second;
  47. if (in_anchor == nullptr) {
  48. return;
  49. }
  50. auto in_node = in_anchor->GetOwnerNode();
  51. ++iter;
  52. if (in_node == nullptr) {
  53. return;
  54. }
  55. if (!IsTransOp(in_node)) {
  56. continue;
  57. }
  58. auto op_desc = in_node->GetOpDesc();
  59. if (op_desc == nullptr) {
  60. continue;
  61. }
  62. GELOGI("SetRemainNode node is %s", op_desc->GetName().c_str());
  63. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
  64. }
  65. }
  66. bool TransOpWithoutReshapeFusionPass::FormatContinuousCheck(const OutDataAnchorPtr &out_anchor,
  67. const InDataAnchorPtr &in_anchor) {
  68. if (out_anchor == nullptr || in_anchor == nullptr || in_anchor->GetOwnerNode() == nullptr ||
  69. out_anchor->GetOwnerNode() == nullptr) {
  70. return false;
  71. }
  72. auto in_node = in_anchor->GetOwnerNode();
  73. GE_IF_BOOL_EXEC(in_node == nullptr, GELOGE(INTERNAL_ERROR, "in_node is null"); return false);
  74. auto in_op = in_node->GetOpDesc();
  75. auto out_owner_node = out_anchor->GetOwnerNode();
  76. GE_IF_BOOL_EXEC(out_owner_node == nullptr, GELOGE(INTERNAL_ERROR, "out_owner_node is null"); return false);
  77. auto out_op = out_owner_node->GetOpDesc();
  78. GE_IF_BOOL_EXEC(in_op == nullptr, GELOGE(INTERNAL_ERROR, "in_op is null"); return false);
  79. GE_IF_BOOL_EXEC(out_op == nullptr, GELOGE(INTERNAL_ERROR, "out_op is null"); return false);
  80. auto in_op_desc = in_op->GetInputDescPtr(in_anchor->GetIdx());
  81. auto out_op_desc = out_op->GetOutputDescPtr(out_anchor->GetIdx());
  82. GE_IF_BOOL_EXEC(in_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_op_desc is null"); return false);
  83. GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_op_desc is null"); return false);
  84. if (!ShapeEqualCheck(in_op_desc->GetShape(), out_op_desc->GetShape())) {
  85. return false;
  86. }
  87. if (in_op->GetType() == CAST || out_op->GetType() == CAST) {
  88. return TransOpUtil::CheckPrecisionLoss(in_node);
  89. }
  90. if (in_op_desc->GetFormat() == FORMAT_ND) {
  91. return false;
  92. }
  93. if (out_op_desc->GetFormat() == FORMAT_ND) {
  94. return false;
  95. }
  96. if (in_op_desc->GetFormat() != out_op_desc->GetFormat()) {
  97. return false;
  98. }
  99. return FusionFormatSupport(in_op_desc->GetFormat());
  100. }
  101. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() {
  102. vector<bool> sub_graph_has_reshape_node(sub_graph_anchors_.size(), false);
  103. vector<int> transop_num_count(sub_graph_anchors_.size(), 0);
  104. vector<vector<NodePtr>> sub_graph_nodes(sub_graph_anchors_.size());
  105. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  106. auto nodes_anchor = sub_graph_anchors_[i];
  107. vector<NodePtr> nodes_tmp;
  108. auto iter = nodes_anchor.begin();
  109. auto first_out_anchor = iter->first;
  110. if (first_out_anchor == nullptr) {
  111. continue;
  112. }
  113. nodes_tmp.push_back(first_out_anchor->GetOwnerNode());
  114. while (iter != nodes_anchor.end()) {
  115. auto in_anchor = iter->second;
  116. GE_CHECK_NOTNULL(in_anchor);
  117. auto in_node = in_anchor->GetOwnerNode();
  118. GE_CHECK_NOTNULL(in_node);
  119. if (in_node->GetType() == RESHAPE) {
  120. sub_graph_has_reshape_node[i] = true;
  121. break;
  122. }
  123. if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) {
  124. auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat();
  125. auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat();
  126. if (input_format == output_format) {
  127. sub_graph_has_reshape_node[i] = true;
  128. break;
  129. }
  130. }
  131. auto out_anchor = iter->first;
  132. GE_CHECK_NOTNULL(out_anchor);
  133. if (!FormatContinuousCheck(out_anchor, in_anchor)) {
  134. sub_graph_has_reshape_node[i] = true;
  135. break;
  136. }
  137. nodes_tmp.push_back(in_node);
  138. if (IsTransOp(in_node)) {
  139. // count transop num
  140. transop_num_count[i]++;
  141. }
  142. ++iter;
  143. }
  144. sub_graph_nodes[i].swap(nodes_tmp);
  145. if (sub_graph_has_reshape_node[i]) {
  146. SetRemainNode(nodes_anchor);
  147. }
  148. }
  149. sub_graph_has_reshape_node_.swap(sub_graph_has_reshape_node);
  150. transop_num_count_.swap(transop_num_count);
  151. sub_graph_nodes_.swap(sub_graph_nodes);
  152. return GRAPH_SUCCESS;
  153. }
  154. void TransOpWithoutReshapeFusionPass::GetOutDataPeerInControlAnchors(
  155. const size_t index, vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors) {
  156. // The caller guarantees that the index is legal.
  157. for (size_t j = 1; j < sub_graph_anchors_[index].size(); ++j) {
  158. auto nodes_anchor = sub_graph_anchors_[index][j];
  159. auto out_data_anchor = nodes_anchor.first;
  160. GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
  161. for (const auto &peer_in_control_anchor : out_data_anchor->GetPeerInControlAnchors()) {
  162. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_control_anchor);
  163. auto peer_node = peer_in_control_anchor->GetOwnerNode();
  164. if (peer_node == nullptr) {
  165. continue;
  166. }
  167. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  168. if (iter == sub_graph_nodes_[index].end()) {
  169. out_data_peer_in_control_anchors[index].push_back(peer_in_control_anchor);
  170. } else {
  171. sub_graph_has_out_data_peer_in_control_edge_[index] = true;
  172. }
  173. }
  174. }
  175. }
  176. void TransOpWithoutReshapeFusionPass::GetInControlPeerOutControlAnchors(
  177. const size_t index, vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors) {
  178. // The caller guarantees that the index is legal.
  179. for (size_t j = 1; j < (sub_graph_nodes_[index].size() - 1); ++j) {
  180. auto node = sub_graph_nodes_[index][j];
  181. GE_CHECK_NOTNULL_JUST_RETURN(node);
  182. auto in_control_anchor = node->GetInControlAnchor();
  183. if (in_control_anchor == nullptr) {
  184. continue;
  185. }
  186. for (const auto &peer_out_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  187. GE_CHECK_NOTNULL_JUST_RETURN(peer_out_anchor);
  188. auto peer_node = peer_out_anchor->GetOwnerNode();
  189. if (peer_node == nullptr) {
  190. continue;
  191. }
  192. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  193. if (iter == sub_graph_nodes_[index].end()) {
  194. in_control_peer_out_control_anchors[index].push_back(peer_out_anchor);
  195. } else {
  196. sub_graph_has_control_edge_[index] = true;
  197. }
  198. }
  199. }
  200. }
  201. void TransOpWithoutReshapeFusionPass::GetOutControlPeerAnchors(
  202. const size_t index, vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
  203. vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors) {
  204. for (size_t j = 0; j < sub_graph_nodes_[index].size() - 1; ++j) {
  205. auto node = sub_graph_nodes_[index][j];
  206. GE_CHECK_NOTNULL_JUST_RETURN(node);
  207. auto out_control_anchor = node->GetOutControlAnchor();
  208. GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
  209. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  210. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  211. auto peer_node = peer_in_anchor->GetOwnerNode();
  212. if (peer_node == nullptr) {
  213. continue;
  214. }
  215. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  216. if (iter == sub_graph_nodes_[index].end()) {
  217. if (j > 0) {
  218. out_control_peer_in_control_anchors[index].push_back(peer_in_anchor);
  219. }
  220. } else {
  221. sub_graph_has_control_edge_[index] = true;
  222. }
  223. }
  224. for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) {
  225. GE_CHECK_NOTNULL_JUST_RETURN(peer_in_anchor);
  226. auto peer_node = peer_in_anchor->GetOwnerNode();
  227. if (peer_node == nullptr) {
  228. continue;
  229. }
  230. auto iter = std::find(sub_graph_nodes_[index].begin(), sub_graph_nodes_[index].end(), peer_node);
  231. if (iter == sub_graph_nodes_[index].end()) {
  232. if (j > 0) {
  233. out_control_peer_in_data_anchors[index].push_back(peer_in_anchor);
  234. }
  235. } else {
  236. sub_graph_has_control_edge_[index] = true;
  237. }
  238. }
  239. }
  240. }
  241. void TransOpWithoutReshapeFusionPass::GetControlAnchors() {
  242. vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors(sub_graph_nodes_.size());
  243. vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors(sub_graph_nodes_.size());
  244. vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors(sub_graph_nodes_.size());
  245. vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors(sub_graph_nodes_.size());
  246. vector<bool> sub_graph_has_control_edge(sub_graph_nodes_.size(), false);
  247. sub_graph_has_control_edge_.swap(sub_graph_has_control_edge);
  248. vector<bool> sub_graph_has_out_data_peer_in_control_edge(sub_graph_nodes_.size(), false);
  249. sub_graph_has_out_data_peer_in_control_edge_.swap(sub_graph_has_out_data_peer_in_control_edge);
  250. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  251. if (sub_graph_has_reshape_node_[i]) {
  252. continue;
  253. }
  254. GetOutDataPeerInControlAnchors(i, out_data_peer_in_control_anchors);
  255. GetInControlPeerOutControlAnchors(i, in_control_peer_out_control_anchors);
  256. GetOutControlPeerAnchors(i, out_control_peer_in_control_anchors, out_control_peer_in_data_anchors);
  257. }
  258. in_control_peer_out_control_anchors_.swap(in_control_peer_out_control_anchors);
  259. out_control_peer_in_control_anchors_.swap(out_control_peer_in_control_anchors);
  260. out_control_peer_in_data_anchors_.swap(out_control_peer_in_data_anchors);
  261. out_data_peer_in_control_anchors_.swap(out_data_peer_in_control_anchors);
  262. }
  263. void TransOpWithoutReshapeFusionPass::EraseInvalidAnchorsPair() {
  264. auto sub_graph_iter = sub_graph_anchors_.begin();
  265. while (sub_graph_iter != sub_graph_anchors_.end()) {
  266. if (sub_graph_iter->size() <= 1) {
  267. sub_graph_iter = sub_graph_anchors_.erase(sub_graph_iter);
  268. } else {
  269. ++sub_graph_iter;
  270. }
  271. }
  272. }
  273. void TransOpWithoutReshapeFusionPass::UpdateOutputName(const OutDataAnchorPtr &out_anchor,
  274. const InDataAnchorPtr &old_peer_in_anchor,
  275. const NodePtr &in_owner_node) {
  276. if (out_anchor == nullptr || old_peer_in_anchor == nullptr || in_owner_node == nullptr) {
  277. GELOGI("out_anchor or old_peer_in_anchor or in_owner_node is nullptr");
  278. return;
  279. }
  280. auto out_owner_node = out_anchor->GetOwnerNode();
  281. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  282. GE_CHECK_NOTNULL_JUST_RETURN(old_peer_in_anchor->GetOwnerNode());
  283. auto old_peer_in_name = old_peer_in_anchor->GetOwnerNode()->GetName();
  284. auto output_op = out_owner_node->GetOpDesc();
  285. GE_CHECK_NOTNULL_JUST_RETURN(output_op);
  286. auto output_names = output_op->GetAllOutputName();
  287. auto old_peer_in_name_iter = output_names.find(old_peer_in_name);
  288. if (old_peer_in_name_iter != output_names.end()) {
  289. output_names.erase(old_peer_in_name_iter);
  290. }
  291. output_names[in_owner_node->GetName()] = out_anchor->GetIdx();
  292. if (!output_op->UpdateOutputName(output_names)) {
  293. GELOGW("output_op UpdateOutputName failed");
  294. }
  295. }
  296. void TransOpWithoutReshapeFusionPass::UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor,
  297. const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node) {
  298. if (old_peer_out_anchor == nullptr || in_anchor == nullptr || out_owner_node == nullptr) {
  299. GELOGI("old_peer_out_anchor or in_anchor or out_owner_node is nullptr");
  300. return;
  301. }
  302. auto old_node = old_peer_out_anchor->GetOwnerNode();
  303. GE_CHECK_NOTNULL_JUST_RETURN(old_node);
  304. auto old_peer_out_name = old_node->GetName();
  305. auto in_owner_node = in_anchor->GetOwnerNode();
  306. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  307. auto input_op = in_owner_node->GetOpDesc();
  308. GE_CHECK_NOTNULL_JUST_RETURN(input_op);
  309. auto input_names = input_op->GetAllInputName();
  310. auto old_peer_out_name_iter = input_names.find(old_peer_out_name);
  311. if (old_peer_out_name_iter != input_names.end()) {
  312. input_names.erase(old_peer_out_name_iter);
  313. }
  314. input_names[out_owner_node->GetName()] = in_anchor->GetIdx();
  315. input_op->UpdateInputName(input_names);
  316. }
  317. graphStatus TransOpWithoutReshapeFusionPass::RelinkSubGraphControlEdges(
  318. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  319. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  320. auto out_anchor = begin_anchors_pair.first;
  321. GE_CHECK_NOTNULL(out_anchor);
  322. auto out_owner_node = out_anchor->GetOwnerNode();
  323. GE_CHECK_NOTNULL(out_owner_node);
  324. auto in_anchor = end_anchors_pair.second;
  325. GE_CHECK_NOTNULL(in_anchor);
  326. auto in_owner_node = in_anchor->GetOwnerNode();
  327. GE_CHECK_NOTNULL(in_owner_node);
  328. if (sub_graph_has_control_edge_[index]) {
  329. GELOGI("add control edge.src:%s, dst:%s", out_owner_node->GetName().c_str(), in_owner_node->GetName().c_str());
  330. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), in_owner_node->GetInControlAnchor()) !=
  331. GRAPH_SUCCESS) {
  332. return GRAPH_FAILED;
  333. }
  334. }
  335. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  336. GELOGI("add out data 2 in contorl edge.src:%s, dst:%s", out_owner_node->GetName().c_str(),
  337. in_owner_node->GetName().c_str());
  338. if (GraphUtils::AddEdge(out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  339. return GRAPH_FAILED;
  340. }
  341. }
  342. return GRAPH_SUCCESS;
  343. }
  344. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdgesWhenDescNotChanged(
  345. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  346. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  347. if (RelinkSubGraphControlEdges(begin_anchors_pair, end_anchors_pair, index) != GRAPH_SUCCESS) {
  348. return GRAPH_FAILED;
  349. }
  350. auto out_anchor = begin_anchors_pair.first;
  351. GE_CHECK_NOTNULL(out_anchor);
  352. auto out_owner_node = out_anchor->GetOwnerNode();
  353. GE_CHECK_NOTNULL(out_owner_node);
  354. auto in_anchor = end_anchors_pair.second;
  355. GE_CHECK_NOTNULL(in_anchor);
  356. auto in_owner_node = in_anchor->GetOwnerNode();
  357. GE_CHECK_NOTNULL(in_owner_node);
  358. // can not remove old control edge
  359. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  360. GE_CHECK_NOTNULL(peer_in_anchor);
  361. GELOGI("add control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  362. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  363. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  364. return GRAPH_FAILED;
  365. }
  366. }
  367. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  368. GE_CHECK_NOTNULL(peer_out_anchor);
  369. GELOGI("add control edge.src:%s, src idx:%d, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  370. peer_out_anchor->GetIdx(), in_owner_node->GetName().c_str());
  371. if (GraphUtils::AddEdge(peer_out_anchor, in_owner_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
  372. return GRAPH_FAILED;
  373. }
  374. }
  375. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  376. GE_CHECK_NOTNULL(peer_in_anchor);
  377. GELOGI("add out control 2 in data edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  378. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  379. if (GraphUtils::AddEdge(out_owner_node->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  380. return GRAPH_FAILED;
  381. }
  382. }
  383. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  384. GE_CHECK_NOTNULL(peer_in_anchor);
  385. GELOGI("add out data 2 in control edge.src:%s, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  386. peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
  387. if (GraphUtils::AddEdge(out_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  388. return GRAPH_FAILED;
  389. }
  390. }
  391. return GRAPH_SUCCESS;
  392. }
  393. graphStatus TransOpWithoutReshapeFusionPass::RelinkNodesWhenDescNotChanged(
  394. const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
  395. const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair, const int index) {
  396. auto out_anchor = begin_anchors_pair.first;
  397. GE_CHECK_NOTNULL(out_anchor);
  398. auto out_owner_node = out_anchor->GetOwnerNode();
  399. GE_CHECK_NOTNULL(out_owner_node);
  400. auto in_anchor = end_anchors_pair.second;
  401. GE_CHECK_NOTNULL(in_anchor);
  402. auto in_owner_node = in_anchor->GetOwnerNode();
  403. GE_CHECK_NOTNULL(in_owner_node);
  404. GELOGI("remove edge.src %s, src idx:%d, dst:%s, dst idx:%d",
  405. end_anchors_pair.first->GetOwnerNode()->GetName().c_str(), end_anchors_pair.first->GetIdx(),
  406. in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  407. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_anchors_pair.first, in_anchor), "remove edge failed");
  408. GELOGI("relink node.src node:%s, src idx:%d, dst node:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  409. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  410. if (GraphUtils::AddEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) {
  411. GELOGE(GRAPH_FAILED, "add edge failed!src:%s, src idx:%d, dst:%s, dst idx:%d", out_owner_node->GetName().c_str(),
  412. out_anchor->GetIdx(), in_owner_node->GetName().c_str(), in_anchor->GetIdx());
  413. return GRAPH_FAILED;
  414. } else {
  415. auto old_peer_in_anchor = begin_anchors_pair.second;
  416. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  417. auto old_peer_out_anchor = end_anchors_pair.first;
  418. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  419. }
  420. return RelinkControlEdgesWhenDescNotChanged(begin_anchors_pair, end_anchors_pair, index);
  421. }
  422. OpDescPtr TransOpWithoutReshapeFusionPass::GetFormatTransferOp(const GeTensorDesc &format_trans_input_desc,
  423. const GeTensorDesc &format_trans_output_desc) {
  424. static std::atomic_long atomic_fusion_format_transfer_op_count(1);
  425. auto fusion_format_transfer_op_count = atomic_fusion_format_transfer_op_count.fetch_add(1);
  426. std::stringstream format_transfer_op_name;
  427. format_transfer_op_name << "fusion_format_transfer_" << fusion_format_transfer_op_count;
  428. OpDescPtr format_transfer_op = MakeShared<OpDesc>(format_transfer_op_name.str().c_str(), TRANSDATA);
  429. if (format_transfer_op == nullptr) {
  430. GELOGE(INTERNAL_ERROR, "new format transfer op failed!");
  431. return nullptr;
  432. }
  433. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_INPUT_FORMAT,
  434. static_cast<int64_t>(format_trans_input_desc.GetFormat())),
  435. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_INPUT_FORMAT failed");
  436. return nullptr);
  437. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(format_transfer_op, ATTR_NAME_OUTPUT_FORMAT,
  438. static_cast<int64_t>(format_trans_output_desc.GetFormat())),
  439. GELOGE(INTERNAL_ERROR, "set ATTR_NAME_OUTPUT_FORMAT failed");
  440. return nullptr);
  441. string src_format = TypeUtils::FormatToSerialString(format_trans_input_desc.GetFormat());
  442. string dst_format = TypeUtils::FormatToSerialString(format_trans_output_desc.GetFormat());
  443. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameSrcFormat, src_format),
  444. GELOGE(INTERNAL_ERROR, "set kAttrNameSrcFormat failed");
  445. return nullptr);
  446. GE_IF_BOOL_EXEC(!AttrUtils::SetStr(format_transfer_op, kAttrNameDstFormat, dst_format),
  447. GELOGE(INTERNAL_ERROR, "set kAttrNameDstFormat failed");
  448. return nullptr);
  449. GE_IF_BOOL_EXEC(format_transfer_op->AddInputDesc(format_trans_input_desc) != GRAPH_SUCCESS,
  450. GELOGE(INTERNAL_ERROR, "add input desc failed");
  451. return nullptr);
  452. GE_IF_BOOL_EXEC(format_transfer_op->AddOutputDesc(format_trans_output_desc) != GRAPH_SUCCESS,
  453. GELOGE(INTERNAL_ERROR, "add output desc failed");
  454. return nullptr);
  455. GE_IF_BOOL_EXEC(!ge::AttrUtils::SetBool(format_transfer_op, ATTR_NEED_COMPILE, true),
  456. GELOGE(INTERNAL_ERROR, "set ext attr failed");
  457. return nullptr);
  458. return format_transfer_op;
  459. }
  460. OpDescPtr TransOpWithoutReshapeFusionPass::GetCastOp(const GeTensorDesc &cast_input_desc,
  461. const GeTensorDesc &cast_output_desc) {
  462. static std::atomic_long atomic_fusion_cast_op_count(1);
  463. auto fusion_cast_op_count = atomic_fusion_cast_op_count.fetch_add(1);
  464. std::stringstream cast_op_name;
  465. cast_op_name << "fusion_cast_op_" << fusion_cast_op_count;
  466. auto node_op = ge::OperatorFactory::CreateOperator(cast_op_name.str(), CAST);
  467. auto cast_op = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  468. node_op.BreakConnect();
  469. if (cast_op == nullptr) {
  470. GELOGE(INTERNAL_ERROR, "new cast op failed!");
  471. return nullptr;
  472. }
  473. const int default_input_index = 0;
  474. const int default_output_index = 0;
  475. if (cast_op->GetInputsSize() == 0) {
  476. GE_IF_BOOL_EXEC(cast_op->AddInputDesc(cast_input_desc) != GRAPH_SUCCESS,
  477. GELOGE(INTERNAL_ERROR, "add input desc failed");
  478. return nullptr);
  479. } else {
  480. GE_IF_BOOL_EXEC(cast_op->UpdateInputDesc(default_input_index, cast_input_desc) != GRAPH_SUCCESS,
  481. GELOGE(INTERNAL_ERROR, "update input desc failed");
  482. return nullptr);
  483. }
  484. if (cast_op->GetOutputsSize() == 0) {
  485. GE_IF_BOOL_EXEC(cast_op->AddOutputDesc(cast_output_desc) != GRAPH_SUCCESS,
  486. GELOGE(INTERNAL_ERROR, "add output desc failed");
  487. return nullptr);
  488. } else {
  489. GE_IF_BOOL_EXEC(cast_op->UpdateOutputDesc(default_output_index, cast_output_desc) != GRAPH_SUCCESS,
  490. GELOGE(INTERNAL_ERROR, "update output desc failed");
  491. return nullptr);
  492. }
  493. if (!AttrUtils::SetInt(cast_op, CAST_ATTR_DST_TYPE, static_cast<int64_t>(cast_output_desc.GetDataType()))) {
  494. GELOGE(INTERNAL_ERROR, "set dst_type attr failed");
  495. return nullptr;
  496. }
  497. if (!AttrUtils::SetBool(cast_op, ATTR_NEED_COMPILE, true)) {
  498. GELOGE(INTERNAL_ERROR, "set need_compile attr failed");
  499. return nullptr;
  500. }
  501. return cast_op;
  502. }
  503. bool TransOpWithoutReshapeFusionPass::InsertCastFirstCheck(const GeTensorDesc &out_desc,
  504. const GeTensorDesc &in_desc) const {
  505. return out_desc.GetDataType() != in_desc.GetDataType() && out_desc.GetDataType() != DT_FLOAT16 &&
  506. in_desc.GetDataType() == DT_FLOAT16;
  507. }
  508. void TransOpWithoutReshapeFusionPass::GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  509. GeTensorDesc &format_transfer_input,
  510. GeTensorDesc &format_transfer_output) {
  511. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  512. if (insert_cast_first) {
  513. format_transfer_input = out_desc;
  514. format_transfer_input.SetDataType(in_desc.GetDataType());
  515. format_transfer_output = in_desc;
  516. } else {
  517. format_transfer_input = out_desc;
  518. format_transfer_output = in_desc;
  519. format_transfer_output.SetDataType(out_desc.GetDataType());
  520. }
  521. }
  522. void TransOpWithoutReshapeFusionPass::GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
  523. GeTensorDesc &cast_input, GeTensorDesc &cast_output) {
  524. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  525. if (insert_cast_first) {
  526. cast_input = out_desc;
  527. cast_output = out_desc;
  528. cast_output.SetDataType(in_desc.GetDataType());
  529. } else {
  530. cast_input = in_desc;
  531. cast_input.SetDataType(out_desc.GetDataType());
  532. cast_output = in_desc;
  533. }
  534. }
  535. void TransOpWithoutReshapeFusionPass::GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc,
  536. GeTensorDesc &in_desc) {
  537. auto nodes_anchor = sub_graph_anchors_[index];
  538. auto out_peer_anchor = nodes_anchor.front().second;
  539. GE_CHECK_NOTNULL_JUST_RETURN(out_peer_anchor);
  540. auto out_owner_node = out_peer_anchor->GetOwnerNode();
  541. GE_CHECK_NOTNULL_JUST_RETURN(out_owner_node);
  542. auto out_peer_op_desc = out_owner_node->GetOpDesc();
  543. GE_IF_BOOL_EXEC(out_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "out_peer_op_desc is nullptr"); return );
  544. out_desc = out_peer_op_desc->GetInputDesc(out_peer_anchor->GetIdx());
  545. auto in_peer_anchor = nodes_anchor.back().first;
  546. GE_CHECK_NOTNULL_JUST_RETURN(in_peer_anchor);
  547. auto in_owner_node = in_peer_anchor->GetOwnerNode();
  548. GE_CHECK_NOTNULL_JUST_RETURN(in_owner_node);
  549. auto in_peer_op_desc = in_owner_node->GetOpDesc();
  550. GE_IF_BOOL_EXEC(in_peer_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "in_peer_op_desc is nullptr"); return );
  551. in_desc = in_peer_op_desc->GetOutputDesc(in_peer_anchor->GetIdx());
  552. }
  553. graphStatus TransOpWithoutReshapeFusionPass::FormatFusion(const int index, OpDescPtr &format_transfer_op,
  554. int32_t &fusion_op_count, bool &fusion_continue) {
  555. GeTensorDesc out_desc;
  556. GeTensorDesc in_desc;
  557. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  558. GeTensorDesc format_transfer_input;
  559. GeTensorDesc format_transfer_output;
  560. GetFormatTransferDesc(out_desc, in_desc, format_transfer_input, format_transfer_output);
  561. if (out_desc.GetFormat() == in_desc.GetFormat() &&
  562. (!ShapeEqualCheck(out_desc.GetShape(), in_desc.GetShape()) ||
  563. !ShapeEqualCheck(out_desc.GetOriginShape(), in_desc.GetOriginShape()))) {
  564. SetRemainNode(sub_graph_anchors_[index]);
  565. return GRAPH_SUCCESS;
  566. }
  567. if (out_desc.GetFormat() != in_desc.GetFormat() && FusionFormatSupport(out_desc.GetFormat()) &&
  568. FusionFormatSupport(in_desc.GetFormat())) {
  569. // create format transop
  570. format_transfer_op = GetFormatTransferOp(format_transfer_input, format_transfer_output);
  571. if (format_transfer_op == nullptr) {
  572. return GRAPH_FAILED;
  573. }
  574. if (OpAccuracyAbilityCheck(format_transfer_op)) {
  575. ++fusion_op_count;
  576. GELOGI("support format transfer op %s", format_transfer_op->GetName().c_str());
  577. } else {
  578. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  579. format_transfer_input.GetFormat(), format_transfer_input.GetDataType(), format_transfer_output.GetFormat(),
  580. format_transfer_output.GetDataType());
  581. fusion_op_count = kInvalidFusionOpCount;
  582. }
  583. } else if (out_desc.GetFormat() != in_desc.GetFormat()) {
  584. SetRemainNode(sub_graph_anchors_[index]);
  585. return GRAPH_SUCCESS;
  586. }
  587. fusion_continue = true;
  588. return GRAPH_SUCCESS;
  589. }
  590. graphStatus TransOpWithoutReshapeFusionPass::DataTypeFusion(const int index, OpDescPtr &cast_op,
  591. int32_t &fusion_op_count) {
  592. GeTensorDesc out_desc;
  593. GeTensorDesc in_desc;
  594. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  595. GeTensorDesc cast_input;
  596. GeTensorDesc cast_output;
  597. GetCastOpDesc(out_desc, in_desc, cast_input, cast_output);
  598. if (fusion_op_count != kInvalidFusionOpCount && out_desc.GetDataType() != in_desc.GetDataType()) {
  599. // create cast op
  600. cast_op = GetCastOp(cast_input, cast_output);
  601. if (cast_op == nullptr) {
  602. fusion_op_count = kInvalidFusionOpCount;
  603. return GRAPH_FAILED;
  604. }
  605. if (OpAccuracyAbilityCheck(cast_op)) {
  606. ++fusion_op_count;
  607. GELOGI("support cast op %s. src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  608. cast_op->GetName().c_str(), cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(),
  609. cast_output.GetDataType());
  610. } else {
  611. GELOGW("ability not support.src format:%d, src datatype:%d, dst format:%d, dst datatype:%d",
  612. cast_input.GetFormat(), cast_input.GetDataType(), cast_output.GetFormat(), cast_output.GetDataType());
  613. fusion_op_count = kInvalidFusionOpCount;
  614. }
  615. }
  616. return GRAPH_SUCCESS;
  617. }
  618. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuseHandle(const ComputeGraphPtr &graph, const int index) {
  619. bool fusion_continue = false;
  620. OpDescPtr format_transfer_op = nullptr;
  621. int32_t fusion_op_count = 0;
  622. auto fortmat_fusion_ret = FormatFusion(index, format_transfer_op, fusion_op_count, fusion_continue);
  623. if (fortmat_fusion_ret != GRAPH_SUCCESS || !fusion_continue) {
  624. SetRemainNode(sub_graph_anchors_[index]);
  625. return GRAPH_SUCCESS;
  626. }
  627. OpDescPtr cast_op = nullptr;
  628. if (DataTypeFusion(index, cast_op, fusion_op_count) != GRAPH_SUCCESS) {
  629. SetRemainNode(sub_graph_anchors_[index]);
  630. return GRAPH_SUCCESS;
  631. }
  632. if (fusion_op_count != kInvalidFusionOpCount && fusion_op_count < transop_num_count_[index]) {
  633. GeTensorDesc out_desc;
  634. GeTensorDesc in_desc;
  635. GetBeginOutDescAndEndInDesc(index, out_desc, in_desc);
  636. bool insert_cast_first = InsertCastFirstCheck(out_desc, in_desc);
  637. if (InsertNewTransOp(graph, cast_op, format_transfer_op, index, insert_cast_first) != GRAPH_SUCCESS) {
  638. return GRAPH_FAILED;
  639. }
  640. } else {
  641. // remain all nodes
  642. SetRemainNode(sub_graph_anchors_[index]);
  643. }
  644. return GRAPH_SUCCESS;
  645. }
  646. void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &graph) {
  647. if (graph == nullptr) {
  648. return;
  649. }
  650. for (size_t i = 0; i < sub_graph_nodes_.size(); ++i) {
  651. if (sub_graph_has_reshape_node_[i]) {
  652. continue;
  653. }
  654. for (const auto &node : sub_graph_nodes_[i]) {
  655. GE_CHECK_NOTNULL_JUST_RETURN(node);
  656. // remove nodes
  657. if (!IsTransOp(node)) {
  658. continue;
  659. }
  660. auto op_desc = node->GetOpDesc();
  661. GE_CHECK_NOTNULL_JUST_RETURN(op_desc);
  662. bool node_remain_flag = op_desc->TryGetExtAttr(kRemainNode, false);
  663. if (node_remain_flag) {
  664. continue;
  665. }
  666. GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(kRemainNode, true), GELOGE(INTERNAL_ERROR, "set ext attr failed"); return );
  667. GELOGI("remove node:%s", node->GetName().c_str());
  668. if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
  669. GELOGW("remove node failed!node:%s", node->GetName().c_str());
  670. continue;
  671. }
  672. }
  673. }
  674. }
  675. graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) {
  676. GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin.");
  677. if (graph == nullptr) {
  678. return GRAPH_SUCCESS;
  679. }
  680. for (const auto &node : graph->GetDirectNode()) {
  681. GE_CHECK_NOTNULL(node);
  682. if (IsTransOp(node)) {
  683. continue;
  684. }
  685. bool is_unknown = false;
  686. auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
  687. if (ret != GRAPH_SUCCESS) {
  688. GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(),
  689. node->GetType().c_str());
  690. continue;
  691. }
  692. if (is_unknown) {
  693. GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(),
  694. node->GetType().c_str());
  695. continue;
  696. }
  697. GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str());
  698. for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
  699. GE_CHECK_NOTNULL(out_anchor);
  700. vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors;
  701. vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> nodes_list;
  702. if (GetSubGraphsBetweenNormalNode(out_anchor, sub_graph_anchors, nodes_list) != GRAPH_SUCCESS) {
  703. GELOGW("get transops failed!");
  704. continue;
  705. }
  706. sub_graph_anchors_.swap(sub_graph_anchors);
  707. EraseInvalidAnchorsPair();
  708. if (sub_graph_anchors_.empty()) {
  709. continue;
  710. }
  711. // check reshape node
  712. if (GetSubGraphNodesInfo() != GRAPH_SUCCESS) {
  713. continue;
  714. }
  715. // save control edge
  716. GetControlAnchors();
  717. if (TransOpFuse(graph) != GRAPH_SUCCESS) {
  718. return GRAPH_FAILED;
  719. }
  720. }
  721. }
  722. GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end.");
  723. return GRAPH_SUCCESS;
  724. }
  725. bool TransOpWithoutReshapeFusionPass::DescEqualCheck(ConstGeTensorDescPtr &desc_src,
  726. ConstGeTensorDescPtr &desc_dst) const {
  727. if (desc_src == nullptr || desc_dst == nullptr) {
  728. return false;
  729. }
  730. if (desc_src->GetFormat() != desc_dst->GetFormat() || desc_src->GetDataType() != desc_dst->GetDataType()) {
  731. return false;
  732. }
  733. if (!ShapeEqualCheck(desc_src->GetShape(), desc_dst->GetShape())) {
  734. return false;
  735. }
  736. return ShapeEqualCheck(desc_src->GetOriginShape(), desc_dst->GetOriginShape());
  737. }
  738. bool TransOpWithoutReshapeFusionPass::ShapeEqualCheck(const GeShape &src, const GeShape &dst) const {
  739. if (src.GetDims().size() != dst.GetDims().size()) {
  740. return false;
  741. }
  742. for (size_t i = 0; i < src.GetDims().size(); ++i) {
  743. if (src.GetDim(i) != dst.GetDim(i)) {
  744. return false;
  745. }
  746. }
  747. return true;
  748. }
  749. graphStatus TransOpWithoutReshapeFusionPass::TransOpFuse(const ComputeGraphPtr &graph) {
  750. for (size_t i = 0; i < sub_graph_anchors_.size(); ++i) {
  751. if (sub_graph_has_reshape_node_[i]) {
  752. continue;
  753. }
  754. auto nodes_anchor = sub_graph_anchors_[i];
  755. auto out_anchor = nodes_anchor.front().first;
  756. GE_CHECK_NOTNULL(out_anchor);
  757. auto out_op_desc = out_anchor->GetOwnerNode()->GetOpDesc();
  758. GE_CHECK_NOTNULL(out_op_desc);
  759. auto out_desc = out_op_desc->GetOutputDescPtr(out_anchor->GetIdx());
  760. GE_CHECK_NOTNULL(out_desc);
  761. auto in_anchor = nodes_anchor.back().second;
  762. GE_CHECK_NOTNULL(in_anchor);
  763. auto in_op_desc = in_anchor->GetOwnerNode()->GetOpDesc();
  764. GE_CHECK_NOTNULL(in_op_desc);
  765. auto in_desc = in_op_desc->GetInputDescPtr(in_anchor->GetIdx());
  766. GE_CHECK_NOTNULL(in_desc);
  767. if (FusionFormatSupport(out_desc->GetFormat()) && DescEqualCheck(out_desc, in_desc)) {
  768. // relink begin_out to end_in
  769. if (RelinkNodesWhenDescNotChanged(nodes_anchor.front(), nodes_anchor.back(), static_cast<int>(i)) !=
  770. GRAPH_SUCCESS) {
  771. return GRAPH_FAILED;
  772. }
  773. } else {
  774. if (TransOpFuseHandle(graph, static_cast<int>(i)) != GRAPH_SUCCESS) {
  775. return GRAPH_FAILED;
  776. }
  777. }
  778. }
  779. RemoveNousedNodes(graph);
  780. return GRAPH_SUCCESS;
  781. }
  782. graphStatus TransOpWithoutReshapeFusionPass::AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop,
  783. NodePtr &trans_node) {
  784. if (graph == nullptr) {
  785. return GRAPH_SUCCESS;
  786. }
  787. if (transop == nullptr) {
  788. return GRAPH_SUCCESS;
  789. }
  790. trans_node = graph->AddNode(transop);
  791. if (trans_node == nullptr) {
  792. GELOGE(GRAPH_FAILED, "add node failed!");
  793. return GRAPH_FAILED;
  794. }
  795. return GRAPH_SUCCESS;
  796. }
  797. graphStatus TransOpWithoutReshapeFusionPass::GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  798. const OpDescPtr &format_transfer_op,
  799. const bool insert_cast_first,
  800. std::vector<NodePtr> &new_trans_nodes) {
  801. NodePtr format_transfer_node;
  802. if (AddTransNode(graph, format_transfer_op, format_transfer_node) != GRAPH_SUCCESS) {
  803. return GRAPH_FAILED;
  804. }
  805. NodePtr cast_node;
  806. if (AddTransNode(graph, cast_op, cast_node) != GRAPH_SUCCESS) {
  807. return GRAPH_FAILED;
  808. }
  809. if (insert_cast_first) {
  810. if (cast_node != nullptr) {
  811. new_trans_nodes.push_back(cast_node);
  812. }
  813. if (format_transfer_node != nullptr) {
  814. new_trans_nodes.push_back(format_transfer_node);
  815. }
  816. } else {
  817. if (format_transfer_node != nullptr) {
  818. new_trans_nodes.push_back(format_transfer_node);
  819. }
  820. if (cast_node != nullptr) {
  821. new_trans_nodes.push_back(cast_node);
  822. }
  823. }
  824. return GRAPH_SUCCESS;
  825. }
  826. graphStatus TransOpWithoutReshapeFusionPass::InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
  827. const OpDescPtr &format_transfer_op, const int index,
  828. const bool insert_cast_first) {
  829. std::vector<NodePtr> new_trans_nodes;
  830. if (GetTransNode(graph, cast_op, format_transfer_op, insert_cast_first, new_trans_nodes) != GRAPH_SUCCESS) {
  831. return GRAPH_FAILED;
  832. }
  833. if (new_trans_nodes.empty()) {
  834. GELOGI("No new trans node. Do not need insert new transop.");
  835. return GRAPH_SUCCESS;
  836. }
  837. pair<OutDataAnchorPtr, InDataAnchorPtr> begin_out = sub_graph_anchors_[index].front();
  838. pair<OutDataAnchorPtr, InDataAnchorPtr> end_in = sub_graph_anchors_[index].back();
  839. auto out_anchor = begin_out.first;
  840. GE_CHECK_NOTNULL(out_anchor);
  841. auto out_owner_node = out_anchor->GetOwnerNode();
  842. GE_CHECK_NOTNULL(out_owner_node);
  843. auto in_anchor = end_in.second;
  844. GE_CHECK_NOTNULL(in_anchor);
  845. auto in_owner_node = in_anchor->GetOwnerNode();
  846. GE_CHECK_NOTNULL(in_owner_node);
  847. GELOGI("remove edge.src:%s, src idx:%d, dst:%s, dst idx:%d", end_in.first->GetOwnerNode()->GetName().c_str(),
  848. end_in.first->GetIdx(), in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  849. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(end_in.first, in_anchor), "remove edge failed");
  850. GELOGI("add edge.src:%s, src idx:%d, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(), out_anchor->GetIdx(),
  851. new_trans_nodes.front()->GetName().c_str());
  852. if (GraphUtils::AddEdge(out_anchor, new_trans_nodes.front()->GetInAnchor(0)) != GRAPH_SUCCESS) {
  853. return GRAPH_FAILED;
  854. } else {
  855. auto old_peer_in_anchor = begin_out.second;
  856. GE_CHECK_NOTNULL(old_peer_in_anchor);
  857. UpdateOutputName(out_anchor, old_peer_in_anchor, in_owner_node);
  858. }
  859. if (new_trans_nodes.size() > 1) {
  860. GELOGI("add edge.src:%s, dst:%s", new_trans_nodes.front()->GetName().c_str(),
  861. new_trans_nodes.back()->GetName().c_str());
  862. if (GraphUtils::AddEdge(new_trans_nodes.front()->GetOutAnchor(0), new_trans_nodes.back()->GetInAnchor(0)) !=
  863. GRAPH_SUCCESS) {
  864. return GRAPH_FAILED;
  865. } else {
  866. auto old_peer_out_anchor = end_in.first;
  867. GE_CHECK_NOTNULL(old_peer_out_anchor);
  868. UpdateInputName(old_peer_out_anchor, in_anchor, out_owner_node);
  869. }
  870. }
  871. GELOGI("add edge.src:%s, dst:%s, dst idx:%d", new_trans_nodes.back()->GetName().c_str(),
  872. in_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetIdx());
  873. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutAnchor(0), in_anchor) != GRAPH_SUCCESS) {
  874. return GRAPH_FAILED;
  875. }
  876. return RelinkControlEdge(index, out_anchor, new_trans_nodes);
  877. }
  878. graphStatus TransOpWithoutReshapeFusionPass::RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
  879. const vector<NodePtr> &new_trans_nodes) {
  880. GE_CHECK_NOTNULL(out_anchor);
  881. if (new_trans_nodes.front() == nullptr || new_trans_nodes.back() == nullptr) {
  882. return GRAPH_FAILED;
  883. }
  884. if (sub_graph_has_control_edge_[index]) {
  885. GELOGI("add control edge.src:%s, dst:%s", out_anchor->GetOwnerNode()->GetName().c_str(),
  886. new_trans_nodes.front()->GetName().c_str());
  887. if (GraphUtils::AddEdge(out_anchor->GetOwnerNode()->GetOutControlAnchor(),
  888. new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  889. return GRAPH_FAILED;
  890. }
  891. }
  892. for (const auto &peer_in_anchor : out_control_peer_in_control_anchors_[index]) {
  893. GE_CHECK_NOTNULL(peer_in_anchor);
  894. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  895. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  896. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  897. return GRAPH_FAILED;
  898. }
  899. }
  900. for (const auto &peer_out_anchor : in_control_peer_out_control_anchors_[index]) {
  901. GE_CHECK_NOTNULL(peer_out_anchor);
  902. GELOGI("add control edge.src:%s, dst:%s", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  903. new_trans_nodes.front()->GetName().c_str());
  904. if (GraphUtils::AddEdge(peer_out_anchor, new_trans_nodes.front()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  905. return GRAPH_FAILED;
  906. }
  907. }
  908. for (const auto &peer_in_anchor : out_control_peer_in_data_anchors_[index]) {
  909. GE_CHECK_NOTNULL(peer_in_anchor);
  910. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  911. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  912. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutControlAnchor(), peer_in_anchor) != GRAPH_SUCCESS) {
  913. return GRAPH_FAILED;
  914. }
  915. }
  916. for (const auto &peer_in_anchor : out_data_peer_in_control_anchors_[index]) {
  917. GE_CHECK_NOTNULL(peer_in_anchor);
  918. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  919. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  920. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0), peer_in_anchor) != GRAPH_SUCCESS) {
  921. return GRAPH_FAILED;
  922. }
  923. }
  924. if (sub_graph_has_out_data_peer_in_control_edge_[index]) {
  925. auto in_anchor = sub_graph_anchors_[index].back().second;
  926. GELOGI("add control edge.src:%s, dst:%s", new_trans_nodes.back()->GetName().c_str(),
  927. in_anchor->GetOwnerNode()->GetName().c_str());
  928. if (GraphUtils::AddEdge(new_trans_nodes.back()->GetOutDataAnchor(0),
  929. in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  930. return GRAPH_FAILED;
  931. }
  932. }
  933. return GRAPH_SUCCESS;
  934. }
  935. bool TransOpWithoutReshapeFusionPass::OpAccuracyAbilityCheck(const OpDescPtr &op_desc) {
  936. auto instance = GELib::GetInstance();
  937. if ((instance == nullptr) || (!instance->InitFlag())) {
  938. GELOGW("GELib is not initialized!");
  939. return false;
  940. }
  941. if (op_desc == nullptr) {
  942. return false;
  943. }
  944. OpsKernelManager &ops_kernel_manager = instance->OpsKernelManagerObj();
  945. vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  946. if (op_infos.empty()) {
  947. GELOGI("Can not get op info by op type:%s", op_desc->GetType().c_str());
  948. return false;
  949. }
  950. std::string unsupported_reason;
  951. for (const auto &it : op_infos) {
  952. auto kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  953. auto &kernel_name = it.opKernelLib;
  954. auto kernel_info_store = kernel_map.find(kernel_name);
  955. if (kernel_info_store != kernel_map.end()) {
  956. if (kernel_info_store->second != nullptr &&
  957. kernel_info_store->second->CheckAccuracySupported(op_desc, unsupported_reason)) {
  958. op_desc->SetOpEngineName(it.engine);
  959. op_desc->SetOpKernelLibName(kernel_name);
  960. GELOGI("Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(),
  961. op_desc->GetName().c_str());
  962. return true;
  963. }
  964. }
  965. }
  966. GELOGI("op %s CheckAccuracySupported failed!reason:%s", op_desc->GetType().c_str(), unsupported_reason.c_str());
  967. return false;
  968. }
  969. bool TransOpWithoutReshapeFusionPass::FusionFormatSupport(Format format) {
  970. return format == FORMAT_NCHW || format == FORMAT_NHWC || format == FORMAT_FRACTAL_Z || format == FORMAT_NC1HWC0;
  971. }
  972. graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphsBetweenNormalNode(
  973. const OutDataAnchorPtr &out_anchor, std::vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
  974. vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list) {
  975. graphStatus ret = GRAPH_SUCCESS;
  976. if (out_anchor == nullptr) {
  977. return GRAPH_FAILED;
  978. }
  979. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  980. if (peer_in_anchor == nullptr || peer_in_anchor->GetOwnerNode() == nullptr ||
  981. peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  982. continue;
  983. }
  984. nodes_list.emplace_back(out_anchor, peer_in_anchor);
  985. auto peer_in_node = peer_in_anchor->GetOwnerNode();
  986. GE_CHECK_NOTNULL(peer_in_node);
  987. if (!IsTransOp(peer_in_node)) {
  988. sub_graphs_out.push_back(nodes_list);
  989. nodes_list.pop_back();
  990. } else {
  991. for (const auto &peer_out_anchor : peer_in_node->GetAllOutDataAnchors()) {
  992. ret = GetSubGraphsBetweenNormalNode(peer_out_anchor, sub_graphs_out, nodes_list);
  993. if (ret != GRAPH_SUCCESS) {
  994. GELOGE(GRAPH_FAILED, "get all transops between normal node failed!node:%s", peer_in_node->GetName().c_str());
  995. return GRAPH_FAILED;
  996. }
  997. }
  998. nodes_list.pop_back();
  999. }
  1000. }
  1001. return GRAPH_SUCCESS;
  1002. }
  1003. bool TransOpWithoutReshapeFusionPass::IsTransOp(const NodePtr &node) {
  1004. // The caller guarantees that the pointer is not null.
  1005. return node->GetType() == CAST || node->GetType() == RESHAPE || node->GetType() == TRANSPOSE ||
  1006. node->GetType() == TRANSPOSED || node->GetType() == TRANSDATA;
  1007. }
  1008. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：