You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

variable_op_pass_bak.cc 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/passes/variable_op_pass.h"

  #include <algorithm>
  #include <map>
  #include <set>
  #include <sstream>
  #include <stack>
  #include <string>
  #include <utility>
  #include <vector>

  #include "common/formats/formats.h"
  #include "common/formats/utils/formats_trans_utils.h"
  #include "graph/ge_context.h"
  #include "graph/graph.h"
  #include "graph/manager/graph_var_manager.h"
  #include "graph/utils/graph_utils.h"
  #include "graph/utils/tensor_utils.h"
  #include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  29. const int kTransOpOutIndex = 0;
  30. Status ByPassTransNode(NodePtr &front_node, NodePtr &back_node) {
  31. GE_CHECK_NOTNULL(front_node);
  32. GE_CHECK_NOTNULL(back_node);
  33. GELOGD("Begin to bypass trans node %s", front_node->GetName().c_str());
  34. auto ret = GraphUtils::CopyInCtrlEdges(front_node, back_node);
  35. if (ret != GRAPH_SUCCESS) {
  36. GELOGE(INTERNAL_ERROR,
  37. "Failed to move control edges from trans "
  38. "node %s to var-ref %s",
  39. front_node->GetName().c_str(), back_node->GetName().c_str());
  40. return INTERNAL_ERROR;
  41. }
  42. auto back_node_in_anchor = back_node->GetInDataAnchor(0);
  43. if (back_node_in_anchor == nullptr) {
  44. GELOGE(INTERNAL_ERROR,
  45. "The back node %s does not have an "
  46. "input anchor",
  47. back_node->GetName().c_str());
  48. return INTERNAL_ERROR;
  49. }
  50. back_node_in_anchor->UnlinkAll();
  51. auto trans_in_anchor = front_node->GetInDataAnchor(0);
  52. if (trans_in_anchor == nullptr) {
  53. GELOGE(INTERNAL_ERROR,
  54. "Failed to get the in data anchor from trans"
  55. " node %s type %s",
  56. front_node->GetName().c_str(), front_node->GetType().c_str());
  57. return INTERNAL_ERROR;
  58. }
  59. auto prev_trans_node_out_anchor = trans_in_anchor->GetPeerOutAnchor();
  60. if (prev_trans_node_out_anchor == nullptr) {
  61. GELOGW(
  62. "The trans node %s does not have an input, so the ref node %s does"
  63. " not have any inputs after bypass",
  64. front_node->GetName().c_str(), front_node->GetName().c_str());
  65. } else {
  66. ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, back_node_in_anchor);
  67. if (ret != GRAPH_SUCCESS) {
  68. GELOGE(INTERNAL_ERROR,
  69. "Failed to add edge between ref node %s "
  70. "and the prev node of trans node %s",
  71. back_node->GetName().c_str(), front_node->GetName().c_str());
  72. return INTERNAL_ERROR;
  73. }
  74. }
  75. return SUCCESS;
  76. }
  77. bool IsTransSupport(const TransNodeInfo &trans_info) {
  78. if (trans_info.output.GetShape().IsUnknownShape()) {
  79. return false;
  80. }
  81. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  82. return true;
  83. } else if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  84. formats::TransArgs args{nullptr,
  85. trans_info.input.GetFormat(),
  86. trans_info.output.GetFormat(),
  87. trans_info.input.GetShape().GetDims(),
  88. trans_info.output.GetShape().GetDims(),
  89. trans_info.input.GetDataType()};
  90. return formats::IsTransFormatSupport(args);
  91. } else if (trans_info.node_type == CAST) {
  92. formats::CastArgs datatype_args{nullptr, static_cast<size_t>(trans_info.input.GetShape().GetShapeSize()),
  93. trans_info.input.GetDataType(), trans_info.output.GetDataType()};
  94. return formats::IsTransDataTypeSupport(datatype_args);
  95. } else {
  96. return false;
  97. }
  98. }
  99. std::string GetInAndOutDecsDiff(NodePtr &trans_node, bool reverse = false) {
  100. int tran_in_index = TransOpUtil::GetTransOpDataIndex(trans_node->GetType());
  101. auto op_desc = trans_node->GetOpDesc();
  102. GeTensorDesc input_desc = op_desc->GetInputDesc(tran_in_index);
  103. GeTensorDesc output_desc = op_desc->GetOutputDesc(kTransOpOutIndex);
  104. if (reverse) {
  105. GeTensorDesc tmp_desc = input_desc;
  106. input_desc = output_desc;
  107. output_desc = tmp_desc;
  108. }
  109. auto input_format = input_desc.GetFormat();
  110. auto input_type = input_desc.GetDataType();
  111. auto input_shape = input_desc.GetShape();
  112. auto output_format = output_desc.GetFormat();
  113. auto output_type = output_desc.GetDataType();
  114. auto output_shape = output_desc.GetShape();
  115. std::stringstream diff_key;
  116. diff_key.str("");
  117. if (input_format != output_format) {
  118. diff_key << static_cast<int>(input_format) << '-' << static_cast<int>(output_format) << '-';
  119. } else {
  120. diff_key << "*-";
  121. }
  122. if (input_type != output_type) {
  123. diff_key << static_cast<int>(input_type) << '-' << static_cast<int>(output_type) << '-';
  124. } else {
  125. diff_key << "*-";
  126. }
  127. if (!ge::formats::IsShapeEqual(input_shape, output_shape)) {
  128. for (auto dim : input_shape.GetDims()) {
  129. diff_key << dim << '-';
  130. }
  131. for (auto dim : output_shape.GetDims()) {
  132. diff_key << dim << '-';
  133. }
  134. } else {
  135. diff_key << "*";
  136. }
  137. return diff_key.str();
  138. }
  139. } // namespace
  140. Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
  141. if (graph == nullptr) {
  142. GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph");
  143. return INTERNAL_ERROR;
  144. }
  145. GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
  146. GetContext().SessionId(), graph->GetGraphID());
  147. if (var_accelerate_ctrl_ == nullptr) {
  148. GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null");
  149. return INTERNAL_ERROR;
  150. }
  151. GELOGD("Begin to generate ref map for variable and refs, graph name:%s.", graph->GetName().c_str());
  152. if (RenewVarDesc(graph) != SUCCESS) {
  153. GELOGE(INTERNAL_ERROR, "Failed to renew var desc on graph");
  154. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  155. }
  156. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  157. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  158. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  159. }
  160. GELOGD("Begin to fusion variables and trans nodes");
  161. for (auto &var_to_refs : var_and_var_ref_map_) {
  162. auto &node = var_to_refs.first;
  163. GE_CHECK_NOTNULL(node);
  164. GE_CHECK_NOTNULL(var_accelerate_ctrl_);
  165. if (!var_accelerate_ctrl_->IsVarPermitToChangeFormats(node->GetName())) {
  166. GELOGD("The var %s does not permit to change formats, skip it", node->GetName().c_str());
  167. continue;
  168. }
  169. VarTransRoad fusion_road;
  170. auto ret = FusionIfNeed(node, fusion_road);
  171. if (ret != SUCCESS) {
  172. return ret;
  173. }
  174. if (fusion_road.empty()) {
  175. GELOGD("No need to fusion variable %s because it's fusion road is empty", node->GetName().c_str());
  176. continue;
  177. }
  178. ret = RenewTransRoadDesc(node, fusion_road);
  179. if (ret != SUCCESS) {
  180. GELOGE(INTERNAL_ERROR, "Failed to renew description fusion road for var %s", node->GetName().c_str());
  181. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  182. }
  183. auto start_iter = fusion_road.begin();
  184. auto end_iter = fusion_road.rbegin();
  185. GELOGD(
  186. "Trans variable data for %s from format %s to %s, shape %s to %s "
  187. "data-type %s to %s, path len %zu success",
  188. node->GetName().c_str(), TypeUtils::FormatToSerialString(start_iter->input.GetFormat()).c_str(),
  189. TypeUtils::FormatToSerialString(end_iter->output.GetFormat()).c_str(),
  190. formats::ShapeToString(start_iter->input.GetShape().GetDims()).c_str(),
  191. formats::ShapeToString(end_iter->output.GetShape().GetDims()).c_str(),
  192. TypeUtils::DataTypeToSerialString(start_iter->input.GetDataType()).c_str(),
  193. TypeUtils::DataTypeToSerialString(end_iter->output.GetDataType()).c_str(), fusion_road.size());
  194. ret = VarManager::Instance(graph->GetSessionID())->SetTransRoad(node->GetName(), fusion_road);
  195. if (ret != SUCCESS) {
  196. GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str());
  197. return INTERNAL_ERROR;
  198. }
  199. ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph->GetGraphID());
  200. if (ret != SUCCESS) {
  201. GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str());
  202. return INTERNAL_ERROR;
  203. }
  204. var_accelerate_ctrl_->SetVarChanged(node->GetName());
  205. GELOGD("Begin to update format info for var %s.", node->GetName().c_str());
  206. std::set<ge::NodePtr> node_set({node});
  207. if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) {
  208. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  209. }
  210. // renew var desc if the trans_road is all reshape or reformat
  211. ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road);
  212. if (ret != SUCCESS) {
  213. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  214. return FAILED;
  215. }
  216. }
  217. return SUCCESS;
  218. }
  219. Status VariableOpPass::RenewTransRoadDesc(const NodePtr &var, VarTransRoad &fusion_road) {
  220. auto var_desc = var->GetOpDesc();
  221. GE_CHECK_NOTNULL(var_desc);
  222. TransNodeInfo prev_node_info;
  223. prev_node_info.node_type = var->GetType();
  224. prev_node_info.output = var_desc->GetOutputDesc(0);
  225. // two cases
  226. // fisrt Var->cast->transdata which transdata in fusion road
  227. // the input of transdata is not equal with output of var
  228. // case 1 : suppose input dtype of transdata equal with out dtype
  229. // but not equal with var
  230. // so we make input dtype and output dytpe of transroad equal with var
  231. // case 2: suppose input format of transdata not equal with out format
  232. // and input format not equal with var
  233. // so we make input format equal with var
  234. for (auto &cur_trans : fusion_road) {
  235. if (cur_trans.input.GetFormat() == cur_trans.output.GetFormat()) {
  236. cur_trans.output.SetFormat(prev_node_info.output.GetFormat());
  237. }
  238. if (cur_trans.input.GetDataType() == cur_trans.output.GetDataType()) {
  239. cur_trans.output.SetDataType(prev_node_info.output.GetDataType());
  240. }
  241. if (ge::formats::IsShapeEqual(cur_trans.input.GetShape(), cur_trans.output.GetShape())) {
  242. cur_trans.output.SetShape(prev_node_info.output.GetShape());
  243. }
  244. cur_trans.input = prev_node_info.output;
  245. prev_node_info.output = cur_trans.output;
  246. }
  247. return SUCCESS;
  248. }
  249. Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_road) {
  250. bool can_fusion = false;
  251. while (true) {
  252. map<string, vector<NodePtr>> trans_type_to_trans_ops ;
  253. map<string, pair<string, bool>> trans_type_to_changed_desc;
  254. // record the order of trans op in first path
  255. vector<string> first_path_trans_order;
  256. auto ret = CheckIfCouldBeOptimized(var, first_path_trans_order, trans_type_to_changed_desc,
  257. trans_type_to_trans_ops, can_fusion);
  258. if (ret != SUCCESS) {
  259. GELOGE(FAILED, "Check trans ops after vatiable could be optimized or not failed");
  260. return ret;
  261. }
  262. if (!can_fusion) {
  263. break;
  264. }
  265. vector<pair<NodePtr, NodePtr>> delete_var_ref_trans_nodes;
  266. ret = GetAndCheckTransOpOfVarRef(var, can_fusion, trans_type_to_changed_desc, delete_var_ref_trans_nodes);
  267. if (ret != SUCCESS) {
  268. GELOGE(FAILED, "get and check trans op of varref failed");
  269. return ret;
  270. }
  271. if (!can_fusion) {
  272. break;
  273. }
  274. ret = UpdateTransRoad(fusion_road, first_path_trans_order,
  275. trans_type_to_changed_desc, trans_type_to_trans_ops);
  276. if (ret != SUCCESS) {
  277. GELOGE(FAILED, "Update trans road failed");
  278. return ret;
  279. }
  280. if (fusion_road.empty()) {
  281. return SUCCESS;
  282. }
  283. ret = DealFusion(var, fusion_road, trans_type_to_changed_desc,
  284. trans_type_to_trans_ops, delete_var_ref_trans_nodes);
  285. if (ret != SUCCESS) {
  286. return ret;
  287. }
  288. }
  289. return SUCCESS;
  290. }
  291. Status VariableOpPass::UpdateTransRoad(VarTransRoad &fusion_road, vector<std::string> &first_path_trans_order,
  292. map<std::string,std::pair<std::string, bool>> &trans_type_to_changed_desc,
  293. map<std::string,vector<NodePtr>> &trans_type_to_trans_ops){
  294. vector<std::string> delete_trans_type;
  295. for (auto &trans_type : first_path_trans_order) {
  296. if (trans_type_to_changed_desc.find(trans_type) == trans_type_to_changed_desc.end()) {
  297. continue;
  298. }
  299. bool delete_flag = false;
  300. for (auto &trans_node : trans_type_to_trans_ops[trans_type]) {
  301. int tran_in_index = TransOpUtil::GetTransOpDataIndex(trans_node->GetType());
  302. auto out_op_desc = trans_node->GetOpDesc();
  303. GE_CHECK_NOTNULL(out_op_desc);
  304. TransNodeInfo trans_node_info;
  305. trans_node_info.node_type = trans_node->GetType();
  306. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  307. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  308. if (!IsTransSupport(trans_node_info)) {
  309. delete_flag = true;
  310. GELOGD("The trans node %s does not support, skip the variable accelerating", trans_node_info.node_type.c_str());
  311. break;
  312. }
  313. }
  314. if (delete_flag) {
  315. delete_trans_type.push_back(trans_type);
  316. } else {
  317. auto &trans_node = *trans_type_to_trans_ops[trans_type].begin();
  318. auto out_op_desc = trans_node->GetOpDesc();
  319. int tran_in_index = TransOpUtil::GetTransOpDataIndex(trans_node->GetType());
  320. TransNodeInfo trans_node_info;
  321. trans_node_info.node_type = trans_node->GetType();
  322. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  323. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  324. fusion_road.emplace_back(trans_node_info);
  325. }
  326. }
  327. for (auto &trans_type : delete_trans_type) {
  328. trans_type_to_changed_desc.erase(trans_type);
  329. }
  330. return SUCCESS;
  331. }
  332. Status VariableOpPass::DealFusion(const ge::NodePtr &var_node, VarTransRoad &fusion_road,
  333. map<std::string, std::pair<std::string, bool>> trans_type_to_changed_desc,
  334. map<std::string, vector<NodePtr>> trans_type_to_trans_ops,
  335. vector<pair<NodePtr, NodePtr>> &delete_trans_nodes) {
  336. GE_CHECK_NOTNULL(var_node);
  337. GELOGD("Begin to fusion var %s with trans", var_node->GetName().c_str());
  338. auto graph = var_node->GetOwnerComputeGraph();
  339. for (auto &trans_type : trans_type_to_changed_desc) {
  340. for (auto &trans_node : trans_type_to_trans_ops[trans_type.first]) {
  341. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  342. trans_node->GetType().c_str(), var_node->GetName().c_str());
  343. if (RenewTransOpDesc(trans_node, true) != SUCCESS) {
  344. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  345. }
  346. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  347. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  348. }
  349. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  350. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  351. }
  352. }
  353. }
  354. // Iterate delete_trans_nodes backward, eg a->b->c, delete_trans_nodes:{{b,c},{a,b}}
  355. // we should delete {a,b} first , then b->c,then we can delete {b,c}
  356. // if we delete {b,c} first, then a->c, then we can not get b when we delete {a,b}
  357. for (auto iter = delete_trans_nodes.rbegin(); iter != delete_trans_nodes.rend(); ++iter) {
  358. auto front_node = iter->first;
  359. auto back_node = iter->second;
  360. if (RenewTransOpDesc(front_node, false) != SUCCESS) {
  361. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  362. }
  363. if (front_node->GetOutDataNodes().size() > 1) {
  364. GELOGD("The trans node %s type %s connecting with var-ref %s has more"
  365. " than one output data nodes, unlink the edge between them",
  366. front_node->GetName().c_str(), front_node->GetType().c_str(), back_node->GetName().c_str());
  367. if (ByPassTransNode(front_node, back_node) != SUCCESS) {
  368. GELOGE(INTERNAL_ERROR, "Failed to bypass trans node %s to node %s", front_node->GetName().c_str(),
  369. back_node->GetName().c_str());
  370. return INTERNAL_ERROR;
  371. }
  372. } else {
  373. GELOGD("The trans node %s type %s connecting with %s has only"
  374. " one output data nodes, isolate and remove it.",
  375. front_node->GetName().c_str(), front_node->GetType().c_str(), back_node->GetName().c_str());
  376. if (GraphUtils::IsolateNode(front_node, {0}) != SUCCESS) {
  377. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  378. }
  379. if (GraphUtils::RemoveNodeWithoutRelink(graph, front_node) != SUCCESS) {
  380. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  381. }
  382. }
  383. }
  384. return SUCCESS;
  385. }
  386. Status VariableOpPass::RenewTransOpDesc(ge::NodePtr &node, bool is_reverse) {
  387. int tran_in_index = TransOpUtil::GetTransOpDataIndex(node->GetType());
  388. auto op_desc = node->GetOpDesc();
  389. GE_CHECK_NOTNULL(op_desc);
  390. GeTensorDesc input_desc = op_desc->GetInputDesc(tran_in_index);
  391. GeTensorDesc output_desc = op_desc->GetOutputDesc(kTransOpOutIndex);
  392. GeTensorDesc renew_desc = is_reverse ? output_desc : input_desc;
  393. bool format_changed = false;
  394. bool shape_changed = false;
  395. bool dtype_changed = false;
  396. if (input_desc.GetFormat() != output_desc.GetFormat()) {
  397. format_changed = true;
  398. }
  399. if (input_desc.GetDataType() != output_desc.GetDataType()) {
  400. dtype_changed = true;
  401. }
  402. if (!ge::formats::IsShapeEqual(input_desc.GetShape(), output_desc.GetShape())) {
  403. shape_changed = true;
  404. }
  405. auto cur_node = node;
  406. while (TransOpUtil::IsTransOp(cur_node)) {
  407. tran_in_index = TransOpUtil::GetTransOpDataIndex(cur_node->GetType());
  408. auto next_node = is_reverse ? NodeUtils::GetInDataNodeByIndex(*cur_node, tran_in_index) :
  409. cur_node->GetOutDataNodes().at(kTransOpOutIndex);
  410. if (!TransOpUtil::IsTransOp(next_node)) {
  411. break;
  412. }
  413. auto prev_desc = next_node->GetOpDesc();
  414. tran_in_index = TransOpUtil::GetTransOpDataIndex(next_node->GetType());
  415. auto mutable_output_desc = prev_desc->MutableOutputDesc(kTransOpOutIndex);
  416. auto mutable_input_desc = prev_desc->MutableInputDesc(tran_in_index);
  417. GE_CHECK_NOTNULL(prev_desc->MutableOutputDesc(kTransOpOutIndex));
  418. GE_CHECK_NOTNULL(prev_desc->MutableInputDesc(tran_in_index));
  419. if (shape_changed) {
  420. mutable_input_desc->SetShape(renew_desc.GetShape());
  421. mutable_output_desc->SetShape(renew_desc.GetShape());
  422. }
  423. if (dtype_changed) {
  424. mutable_input_desc->SetDataType(renew_desc.GetDataType());
  425. mutable_output_desc->SetDataType(renew_desc.GetDataType());
  426. }
  427. if (format_changed) {
  428. mutable_input_desc->SetFormat(renew_desc.GetFormat());
  429. mutable_output_desc->SetFormat(renew_desc.GetFormat());
  430. }
  431. cur_node = next_node;
  432. }
  433. return SUCCESS;
  434. }
  435. Status VariableOpPass::CheckIfCouldBeOptimized(const NodePtr &var, vector<string> &first_path_trans_order,
  436. map<string, pair<string, bool>> &trans_type_to_changed_desc,
  437. map<string, vector<NodePtr>> &trans_type_to_trans_ops, bool &flag) {
  438. bool is_match = true;
  439. auto ret = GetSameTransOP(var, first_path_trans_order, trans_type_to_changed_desc,
  440. trans_type_to_trans_ops, is_match);
  441. if (ret != SUCCESS) {
  442. GELOGE(FAILED, "Get same trans op of variable node: %s failed", var->GetName().c_str());
  443. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  444. }
  445. if (!is_match) {
  446. flag = false;
  447. GELOGI("trans nodes after variable do not meet the condition");
  448. return SUCCESS;
  449. }
  450. flag = true;
  451. return SUCCESS;
  452. }
  453. Status VariableOpPass::GetSameTransOP(const NodePtr &var, vector<string> &first_path_trans_order,
  454. map<string, pair<string, bool>> &trans_type_to_changed_desc,
  455. map<string, vector<NodePtr>> &trans_type_to_trans_ops, bool &is_match) {
  456. GELOGD("Begin to get Node: %s trans op info of first path", var->GetName().c_str());
  457. auto ret = GetFisrtPathTransInfo(var, first_path_trans_order,
  458. trans_type_to_changed_desc, trans_type_to_trans_ops);
  459. if (ret != SUCCESS) {
  460. GELOGE(FAILED, "Get var: %s first path trans info failed", var->GetName().c_str());
  461. return FAILED;
  462. }
  463. if (first_path_trans_order.empty()) {
  464. GELOGD("var %s first path has no trans op, not need to pass", var->GetName().c_str());
  465. is_match = false;
  466. return SUCCESS;
  467. }
  468. GELOGD("Begin to depth first search Node: %s ", var->GetName().c_str());
  469. VariableDFS(var, trans_type_to_changed_desc, trans_type_to_trans_ops, is_match);
  470. return SUCCESS;
  471. }
  472. void VariableOpPass::VariableDFS(const NodePtr &node, map<string, pair<string, bool>> &trans_type_to_changed_desc,
  473. map<string, vector<NodePtr>> &trans_type_to_trans_ops, bool &is_match) {
  474. std::stack<NodePtr> node_stack;
  475. std::stack<vector<NodePtr>> path_stack;
  476. for (auto &out_node : node->GetOutDataNodes()) {
  477. if (!is_match) {
  478. break;
  479. }
  480. if (out_node->GetOutDataNodesSize() == 0 || !ge::TransOpUtil::IsTransOp(out_node)) {
  481. is_match = false;
  482. break;
  483. }
  484. node_stack.push(out_node);
  485. path_stack.emplace(vector<NodePtr>{out_node});
  486. while (!node_stack.empty() && is_match) {
  487. auto cur_node = node_stack.top();
  488. auto cur_path = path_stack.top();
  489. node_stack.pop();
  490. path_stack.pop();
  491. if (cur_node->GetOutDataNodesSize() == 0 || !ge::TransOpUtil::IsTransOp(cur_node)) {
  492. UpdateTransInfo(cur_path, is_match, trans_type_to_changed_desc, trans_type_to_trans_ops);
  493. continue;
  494. }
  495. for (auto &next_node : cur_node->GetOutDataNodes()) {
  496. node_stack.push(next_node);
  497. auto next_path = cur_path;
  498. next_path.push_back(next_node);
  499. path_stack.emplace(next_path);
  500. }
  501. }
  502. }
  503. }
  504. Status VariableOpPass::UpdateTransInfo(vector<NodePtr> &cur_path, bool& is_match,
  505. map<string, pair<string, bool>> &trans_type_to_changed_desc,
  506. map<string, vector<NodePtr>> &trans_type_to_trans_ops) {
  507. GELOGD("Begin to update trans info by path");
  508. std::set<string> trans_op_occured;
  509. for (auto &trans_node : cur_path) {
  510. auto trans_node_type = trans_node->GetType();
  511. if (trans_op_occured.find(trans_node_type) != trans_op_occured.end() ||
  512. !ge::TransOpUtil::IsTransOp(trans_node_type)) {
  513. continue;
  514. }
  515. trans_op_occured.insert(trans_node_type);
  516. auto desc_diff = GetInAndOutDecsDiff(trans_node);
  517. if (trans_type_to_changed_desc.find(trans_node_type) != trans_type_to_changed_desc.end() &&
  518. desc_diff == trans_type_to_changed_desc[trans_node_type].first) {
  519. trans_type_to_changed_desc[trans_node_type].second = true;
  520. auto iter = find(trans_type_to_trans_ops[trans_node_type].begin(),
  521. trans_type_to_trans_ops[trans_node_type].end(),
  522. trans_node);
  523. if (iter == trans_type_to_trans_ops[trans_node_type].end()) {
  524. trans_type_to_trans_ops[trans_node_type].push_back(trans_node);
  525. }
  526. }
  527. }
  528. std::set<string> delete_trans_types;
  529. for (auto &trans_item : trans_type_to_changed_desc) {
  530. if (!trans_item.second.second) {
  531. delete_trans_types.insert(trans_item.first);
  532. } else {
  533. trans_item.second.second = false;
  534. }
  535. }
  536. for (auto& delete_item : delete_trans_types) {
  537. trans_type_to_changed_desc.erase(delete_item);
  538. }
  539. if (trans_type_to_changed_desc.empty()) {
  540. is_match = false;
  541. }
  542. return SUCCESS;
  543. }
  544. Status VariableOpPass::GetFisrtPathTransInfo(const NodePtr &var, vector<string> &first_path_trans_order,
  545. map<string, pair<string, bool>> &trans_type_to_changed_desc,
  546. map<string, vector<NodePtr>> &trans_type_to_trans_ops) {
  547. auto cur_node = var;
  548. while (cur_node->GetOutDataNodesSize() != 0) {
  549. cur_node = cur_node->GetOutDataNodes().at(0);
  550. GE_CHECK_NOTNULL(cur_node);
  551. if (!ge::TransOpUtil::IsTransOp(cur_node)) {
  552. break;
  553. }
  554. auto cur_node_type = cur_node->GetType();
  555. // only get the the first occurrence operator of same type
  556. if (trans_type_to_changed_desc.find(cur_node_type) == trans_type_to_changed_desc.end()) {
  557. auto desc_diff = GetInAndOutDecsDiff(cur_node);
  558. trans_type_to_changed_desc[cur_node->GetType()] = make_pair(desc_diff, false);
  559. trans_type_to_trans_ops[cur_node->GetType()] = vector<NodePtr>{cur_node};
  560. first_path_trans_order.push_back(cur_node->GetType());
  561. }
  562. }
  563. GELOGD("get var %s first path trans info success", var->GetName().c_str());
  564. return SUCCESS;
  565. }
  566. Status VariableOpPass::GetAndCheckTransOpOfVarRef(const ge::NodePtr &var_node, bool &pass_check,
  567. map<string, pair<string, bool>> &trans_type_to_changed_desc,
  568. vector<pair<NodePtr, NodePtr>> &delete_var_ref_trans_nodes) {
  569. auto iterator = var_and_var_ref_map_.find(var_node);
  570. if (iterator == var_and_var_ref_map_.end()) {
  571. GELOGD("there is no var_ref of node %s", var_node->GetName().c_str());
  572. return SUCCESS;
  573. }
  574. vector<string> delete_trans_type;
  575. for (auto &trans_type : trans_type_to_changed_desc) {
  576. delete_trans_type.push_back(trans_type.first);
  577. }
  578. for (auto &ref_node : iterator->second) {
  579. GE_CHECK_NOTNULL(ref_node);
  580. auto cur_node = *ref_node->GetInDataNodes().begin();
  581. auto behind_node = ref_node;
  582. GE_CHECK_NOTNULL(cur_node);
  583. vector<string> tmp_delete_trans_type = delete_trans_type;
  584. while (TransOpUtil::IsTransOp(cur_node)) {
  585. GE_CHECK_NOTNULL(cur_node);
  586. auto iter = find(tmp_delete_trans_type.begin(), tmp_delete_trans_type.end(), cur_node->GetType());
  587. if (iter != tmp_delete_trans_type.end()) {
  588. CheckTransOpOfVarAndVarRefSymmetry(cur_node, trans_type_to_changed_desc[cur_node->GetType()].first,
  589. pass_check);
  590. if (!pass_check) {
  591. GELOGD("trans op : %s of var ref %s is illegal", cur_node->GetName().c_str(), ref_node->GetName().c_str());
  592. return SUCCESS;
  593. }
  594. tmp_delete_trans_type.erase(iter);
  595. delete_var_ref_trans_nodes.emplace_back(std::make_pair(cur_node, behind_node));
  596. }
  597. int tran_in_index = TransOpUtil::GetTransOpDataIndex(cur_node->GetType());
  598. behind_node = cur_node;
  599. cur_node = cur_node->GetInDataNodes().at(tran_in_index);
  600. }
  601. if (!tmp_delete_trans_type.empty()) {
  602. pass_check = false;
  603. return SUCCESS;
  604. }
  605. }
  606. return SUCCESS;
  607. }
  608. Status VariableOpPass::CheckTransOpOfVarAndVarRefSymmetry(NodePtr &var_ref_trans_op, const string &desc_diff,
  609. bool &is_symmetry){
  610. auto var_ref_trans_op_desc_diff = GetInAndOutDecsDiff(var_ref_trans_op, true);
  611. is_symmetry = (var_ref_trans_op_desc_diff == desc_diff);
  612. return SUCCESS;
  613. }
  614. Status VariableOpPass::UpdateVarAndRefOutputFormatInfo(const GeTensorDesc &final_output, const ge::NodePtr &node) {
  615. if (node == nullptr || node->GetOpDesc() == nullptr) {
  616. GELOGE(FAILED, "node or opdesc is nullptr");
  617. return FAILED;
  618. }
  619. const Format &format = final_output.GetFormat();
  620. const DataType &data_type = final_output.GetDataType();
  621. const GeShape &shape = final_output.GetShape();
  622. GELOGD("last ref is (%s, %s, %lu), var_ref_name is %s.", TypeUtils::DataTypeToSerialString(data_type).c_str(),
  623. TypeUtils::FormatToSerialString(format).c_str(), shape.GetDims().size(), node->GetName().c_str());
  624. auto node_desc = node->GetOpDesc()->GetOutputDesc(0);
  625. CopyVariableFormatDataTypeAndShape(final_output, node_desc);
  626. if (node->GetOpDesc()->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  627. GELOGE(FAILED, "update output desc fail.");
  628. return FAILED;
  629. }
  630. GELOGD("node ref is (%s, %s, %lu), var_ref_name is %s.",
  631. TypeUtils::DataTypeToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetDataType()).c_str(),
  632. TypeUtils::FormatToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetFormat()).c_str(),
  633. node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims().size(), node->GetName().c_str());
  634. auto iterator = var_and_var_ref_map_.find(node);
  635. if (iterator == var_and_var_ref_map_.end()) {
  636. auto graph = node->GetOwnerComputeGraph();
  637. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  638. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  639. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  640. }
  641. }
  642. iterator = var_and_var_ref_map_.find(node);
  643. if (iterator == var_and_var_ref_map_.end()) {
  644. GELOGW("The var node %s which belongs to graph %s can not be found on the graph", node->GetName().c_str(),
  645. node->GetOwnerComputeGraph()->GetName().c_str());
  646. return SUCCESS;
  647. }
  648. for (const auto &var_ref_node : iterator->second) {
  649. auto var_ref_node_description = var_ref_node->GetOpDesc();
  650. GE_CHECK_NOTNULL(var_ref_node_description);
  651. GELOGD("var_ref_node before is (%s, %s, %zu), var_ref_name is %s.",
  652. TypeUtils::DataTypeToSerialString(data_type).c_str(), TypeUtils::FormatToSerialString(format).c_str(),
  653. shape.GetDims().size(), var_ref_node->GetName().c_str());
  654. if (var_ref_node_description->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  655. GELOGW("UpdateOutputDesc fail.");
  656. }
  657. if (var_ref_node_description->UpdateInputDesc(0, node_desc) != GRAPH_SUCCESS) {
  658. GELOGW("UpdateInputDesc fail.");
  659. }
  660. const auto &input_desc = var_ref_node_description->MutableInputDesc(0);
  661. const auto &output_desc = var_ref_node_description->MutableOutputDesc(0);
  662. GE_CHECK_NOTNULL(input_desc);
  663. GE_CHECK_NOTNULL(output_desc);
  664. GELOGD("var_ref_node ref is (%s, %s, %zu), var_ref_name is %s.",
  665. TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(),
  666. TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), output_desc->GetShape().GetDims().size(),
  667. var_ref_node->GetName().c_str());
  668. }
  669. return SUCCESS;
  670. }
  671. Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &compute_graph) {
  672. std::map<std::string, NodePtr> names_to_var;
  673. std::map<std::string, std::set<NodePtr>> names_to_refs;
  674. GE_CHECK_NOTNULL(compute_graph);
  675. for (auto &node : compute_graph->GetDirectNode()) {
  676. if (node->GetType() != VARIABLE) {
  677. continue;
  678. }
  679. std::string ref_var_name;
  680. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) {
  681. names_to_var[node->GetName()] = node;
  682. } else {
  683. names_to_refs[ref_var_name].insert(node);
  684. }
  685. }
  686. for (auto &name_to_var : names_to_var) {
  687. var_and_var_ref_map_[name_to_var.second] = names_to_refs[name_to_var.first];
  688. }
  689. return SUCCESS;
  690. }
  691. void VariableOpPass::CopyVariableFormatDataTypeAndShape(const GeTensorDesc &src_tensor_desc,
  692. GeTensorDesc &dst_tensor_desc) {
  693. dst_tensor_desc.SetShape(src_tensor_desc.GetShape());
  694. dst_tensor_desc.SetFormat(src_tensor_desc.GetFormat());
  695. dst_tensor_desc.SetDataType(src_tensor_desc.GetDataType());
  696. }
  697. Status VariableOpPass::UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set<NodePtr> &nodes) {
  698. for (auto &need_set_node : nodes) {
  699. auto ret = UpdateVarAndRefOutputFormatInfo(final_output, need_set_node);
  700. if (ret != SUCCESS) {
  701. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  702. }
  703. }
  704. return SUCCESS;
  705. }
  706. Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) {
  707. GE_CHECK_NOTNULL(graph);
  708. // renew var manager desc
  709. Status ret = SUCCESS;
  710. for (auto &node : graph->GetDirectNode()) {
  711. bool is_var_node =
  712. (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP);
  713. if (is_var_node) {
  714. if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) {
  715. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  716. continue;
  717. }
  718. GELOGD("var manager exist var node[%s], graph name[%s]", node->GetName().c_str(), graph->GetName().c_str());
  719. GE_CHECK_NOTNULL(node->GetOpDesc());
  720. ret = ge::VarManager::Instance(graph->GetSessionID())->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  721. if (ret != SUCCESS) {
  722. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  723. return FAILED;
  724. }
  725. }
  726. }
  727. return SUCCESS;
  728. }
  729. Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) {
  730. // renew var desc if the trans_road is all reshape or reformat
  731. for (auto &road : fusion_road) {
  732. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  733. return SUCCESS;
  734. }
  735. }
  736. if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) {
  737. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  738. return SUCCESS;
  739. }
  740. GELOGD("var manager exist var node[%s]", node->GetName().c_str());
  741. GE_CHECK_NOTNULL(node->GetOpDesc());
  742. Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  743. if (ret != SUCCESS) {
  744. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  745. return FAILED;
  746. }
  747. return SUCCESS;
  748. }
  749. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：