You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_op_pass.cc 25 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/variable_op_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include "common/formats/formats.h"
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/graph.h"
  23. #include "graph/manager/graph_var_manager.h"
  24. #include "graph/utils/graph_utils.h"
  25. #include "graph/utils/tensor_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. namespace ge {
  28. namespace {
  // Output index used in this file when reading the output descriptor of a trans node.
  constexpr int kTransOpOutIndex = 0;
  30. std::string GetKey(Format format, DataType type, const std::vector<int64_t> &dims) {
  31. std::stringstream key;
  32. key << static_cast<int>(format) << '-';
  33. key << static_cast<int>(type) << '-';
  34. for (auto dim : dims) {
  35. key << dim << '-';
  36. }
  37. return key.str();
  38. }
  39. Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) {
  40. GE_CHECK_NOTNULL(trans_node);
  41. GE_CHECK_NOTNULL(ref_node);
  42. GELOGD("Begin to bypass trans node %s", trans_node->GetName().c_str());
  43. auto ret = GraphUtils::CopyInCtrlEdges(trans_node, ref_node);
  44. if (ret != GRAPH_SUCCESS) {
  45. GELOGE(INTERNAL_ERROR,
  46. "Failed to move control edges from trans "
  47. "node %s to var-ref %s",
  48. trans_node->GetName().c_str(), ref_node->GetName().c_str());
  49. return INTERNAL_ERROR;
  50. }
  51. auto ref_in_anchor = ref_node->GetInDataAnchor(0);
  52. if (ref_in_anchor == nullptr) {
  53. GELOGE(INTERNAL_ERROR,
  54. "The variable ref node %s does not have an "
  55. "input anchor",
  56. ref_node->GetName().c_str());
  57. return INTERNAL_ERROR;
  58. }
  59. ref_in_anchor->UnlinkAll();
  60. auto trans_in_anchor = trans_node->GetInDataAnchor(0);
  61. if (trans_in_anchor == nullptr) {
  62. GELOGE(INTERNAL_ERROR,
  63. "Failed to get the in data anchor from trans"
  64. " node %s type %s",
  65. trans_node->GetName().c_str(), trans_node->GetType().c_str());
  66. return INTERNAL_ERROR;
  67. }
  68. auto prev_trans_node_out_anchor = trans_in_anchor->GetPeerOutAnchor();
  69. if (prev_trans_node_out_anchor == nullptr) {
  70. GELOGW(
  71. "The trans node %s does not have an input, so the ref node %s does"
  72. " not have any inputs after bypass",
  73. trans_node->GetName().c_str(), trans_node->GetName().c_str());
  74. } else {
  75. ret = GraphUtils::AddEdge(prev_trans_node_out_anchor, ref_in_anchor);
  76. if (ret != GRAPH_SUCCESS) {
  77. GELOGE(INTERNAL_ERROR,
  78. "Failed to add edge between ref node %s "
  79. "and the prev node of trans node %s",
  80. ref_node->GetName().c_str(), trans_node->GetName().c_str());
  81. return INTERNAL_ERROR;
  82. }
  83. }
  84. return SUCCESS;
  85. }
  86. bool IsTransSupport(const TransNodeInfo &trans_info) {
  87. if (trans_info.output.GetShape().IsUnknownShape()) {
  88. return false;
  89. }
  90. if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
  91. return true;
  92. } else if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
  93. formats::TransArgs args{nullptr,
  94. trans_info.input.GetFormat(),
  95. trans_info.output.GetFormat(),
  96. trans_info.input.GetShape().GetDims(),
  97. trans_info.output.GetShape().GetDims(),
  98. trans_info.input.GetDataType()};
  99. return formats::IsTransFormatSupport(args);
  100. } else if (trans_info.node_type == CAST) {
  101. formats::CastArgs datatype_args{nullptr, static_cast<size_t>(trans_info.input.GetShape().GetShapeSize()),
  102. trans_info.input.GetDataType(), trans_info.output.GetDataType()};
  103. return formats::IsTransDataTypeSupport(datatype_args);
  104. } else {
  105. return false;
  106. }
  107. }
  108. } // namespace
  109. Status VariableOpPass::Run(ge::ComputeGraphPtr graph) {
  110. if (graph == nullptr) {
  111. GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph");
  112. return INTERNAL_ERROR;
  113. }
  114. GELOGD("Begin to run variable op pass on graph %s, session %lu, graph id %u", graph->GetName().c_str(),
  115. GetContext().SessionId(), graph->GetGraphID());
  116. if (var_accelerate_ctrl_ == nullptr) {
  117. GELOGE(INTERNAL_ERROR, "Failed to run var op pass, the variable accelerate control is null");
  118. return INTERNAL_ERROR;
  119. }
  120. GELOGD("Begin to generate ref map for variable and refs, graph name:%s.", graph->GetName().c_str());
  121. if (RenewVarDesc(graph) != SUCCESS) {
  122. GELOGE(INTERNAL_ERROR, "Failed to renew var desc on graph");
  123. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  124. }
  125. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  126. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  127. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  128. }
  129. GELOGD("Begin to fusion variables and trans nodes");
  130. for (auto &var_to_refs : var_and_var_ref_map_) {
  131. auto &node = var_to_refs.first;
  132. GE_CHECK_NOTNULL(node);
  133. GE_CHECK_NOTNULL(var_accelerate_ctrl_);
  134. if (!var_accelerate_ctrl_->IsVarPermitToChangeFormats(node->GetName())) {
  135. GELOGD("The var %s does not permit to change formats, skip it", node->GetName().c_str());
  136. continue;
  137. }
  138. VarTransRoad fusion_road;
  139. auto ret = FusionIfNeed(node, fusion_road);
  140. if (ret != SUCCESS) {
  141. return ret;
  142. }
  143. if (fusion_road.empty()) {
  144. GELOGD("No need to fusion variable and trans op for var %s", node->GetName().c_str());
  145. continue;
  146. }
  147. auto start_iter = fusion_road.begin();
  148. auto end_iter = fusion_road.rbegin();
  149. GELOGD(
  150. "Trans variable data for %s from format %s to %s, shape %s to %s "
  151. "data-type %s to %s, path len %zu success",
  152. node->GetName().c_str(), TypeUtils::FormatToSerialString(start_iter->input.GetFormat()).c_str(),
  153. TypeUtils::FormatToSerialString(end_iter->output.GetFormat()).c_str(),
  154. formats::ShapeToString(start_iter->input.GetShape().GetDims()).c_str(),
  155. formats::ShapeToString(end_iter->output.GetShape().GetDims()).c_str(),
  156. TypeUtils::DataTypeToSerialString(start_iter->input.GetDataType()).c_str(),
  157. TypeUtils::DataTypeToSerialString(end_iter->output.GetDataType()).c_str(), fusion_road.size());
  158. ret = VarManager::Instance(graph->GetSessionID())->SetTransRoad(node->GetName(), fusion_road);
  159. if (ret != SUCCESS) {
  160. GELOGE(INTERNAL_ERROR, "Failed to update the format fusion road for var %s", node->GetName().c_str());
  161. return INTERNAL_ERROR;
  162. }
  163. ret = VarManager::Instance(graph->GetSessionID())->SetChangedGraphId(node->GetName(), graph->GetGraphID());
  164. if (ret != SUCCESS) {
  165. GELOGE(INTERNAL_ERROR, "Failed to update the graph id for var %s", node->GetName().c_str());
  166. return INTERNAL_ERROR;
  167. }
  168. var_accelerate_ctrl_->SetVarChanged(node->GetName());
  169. GELOGD("Begin to update format info for var %s.", node->GetName().c_str());
  170. std::set<ge::NodePtr> node_set({node});
  171. if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) {
  172. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  173. }
  174. // renew var desc if the trans_road is all reshape or reformat
  175. ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road);
  176. if (ret != SUCCESS) {
  177. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  178. return FAILED;
  179. }
  180. }
  181. return SUCCESS;
  182. }
  183. Status VariableOpPass::DealFusion(const ge::NodePtr &var_node) {
  184. GE_CHECK_NOTNULL(var_node);
  185. GELOGD("Begin to fusion var %s with trans", var_node->GetName().c_str());
  186. auto graph = var_node->GetOwnerComputeGraph();
  187. for (auto &trans_node : var_node->GetOutDataNodes()) {
  188. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  189. trans_node->GetType().c_str(), var_node->GetName().c_str());
  190. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  191. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  192. }
  193. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  194. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  195. }
  196. }
  197. auto iterator = var_and_var_ref_map_.find(var_node);
  198. if (iterator == var_and_var_ref_map_.end()) {
  199. GELOGD("there is no var_ref of node %s", var_node->GetName().c_str());
  200. return SUCCESS;
  201. }
  202. for (auto ref_node : iterator->second) {
  203. GE_CHECK_NOTNULL(ref_node);
  204. for (auto &trans_node : ref_node->GetInDataNodes()) {
  205. GELOGD("Remove node %s type %s when fusion with variable %s", trans_node->GetName().c_str(),
  206. trans_node->GetType().c_str(), var_node->GetName().c_str());
  207. if (trans_node->GetOutDataNodes().size() > 1) {
  208. GELOGD(
  209. "The trans node %s type %s connecting with var-ref %s has more"
  210. " than one output data nodes, unlink the edge between them",
  211. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  212. if (ByPassTransNode(trans_node, ref_node) != SUCCESS) {
  213. GELOGE(INTERNAL_ERROR, "Failed to bypass trans node %s to ref %s", trans_node->GetName().c_str(),
  214. ref_node->GetName().c_str());
  215. return INTERNAL_ERROR;
  216. }
  217. } else {
  218. GELOGD(
  219. "The trans node %s type %s connecting with var-ref %s has only"
  220. " one output data nodes, isolate and remove it.",
  221. trans_node->GetName().c_str(), trans_node->GetType().c_str(), ref_node->GetName().c_str());
  222. if (GraphUtils::IsolateNode(trans_node, {0}) != SUCCESS) {
  223. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  224. }
  225. if (GraphUtils::RemoveNodeWithoutRelink(graph, trans_node) != SUCCESS) {
  226. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  227. }
  228. }
  229. }
  230. }
  231. return SUCCESS;
  232. }
  233. Status VariableOpPass::CheckSameAndTransOp(const ge::NodePtr &var_node, bool &is_matched, VarTransRoad &fusion_road) {
  234. std::set<std::string> data_type_and_formats;
  235. std::string trans_op_type;
  236. ge::NodePtr out_node;
  237. ge::GeTensorDesc output_desc;
  238. GE_CHECK_NOTNULL(var_node);
  239. for (auto &out_node_and_anchor : var_node->GetOutDataNodesAndAnchors()) {
  240. auto in_anchor = out_node_and_anchor.second;
  241. GE_CHECK_NOTNULL(in_anchor);
  242. out_node = out_node_and_anchor.first;
  243. GE_CHECK_NOTNULL(out_node);
  244. auto trans_op_desc = out_node->GetOpDesc();
  245. GE_CHECK_NOTNULL(trans_op_desc);
  246. trans_op_type = trans_op_desc->GetType();
  247. GELOGD("current node type is %s.", trans_op_type.c_str());
  248. int data_index = TransOpUtil::GetTransOpDataIndex(trans_op_type);
  249. if (data_index < 0) {
  250. GELOGD("Variables only can be fusion with trans_op, the next op is %s type %s", out_node->GetName().c_str(),
  251. out_node->GetType().c_str());
  252. return SUCCESS;
  253. }
  254. if (data_index != in_anchor->GetIdx()) {
  255. GELOGD(
  256. "Variables only can be fusion with trans nodes, the next node %s"
  257. " type %s index %d does not trans anything(correct index %d)",
  258. out_node->GetName().c_str(), out_node->GetType().c_str(), in_anchor->GetIdx(), data_index);
  259. return SUCCESS;
  260. }
  261. output_desc = trans_op_desc->GetOutputDesc(kTransOpOutIndex);
  262. auto trans_op_format = output_desc.GetFormat();
  263. auto trans_op_data_type = output_desc.GetDataType();
  264. auto shape = output_desc.GetShape().GetDims();
  265. auto datatype_and_format = GetKey(trans_op_format, trans_op_data_type, shape);
  266. data_type_and_formats.insert(datatype_and_format);
  267. }
  268. if (data_type_and_formats.empty()) {
  269. return SUCCESS;
  270. }
  271. if (data_type_and_formats.size() > 1) {
  272. std::stringstream type_and_formats_stream;
  273. bool first_time = true;
  274. for (const auto &data_type_and_format : data_type_and_formats) {
  275. if (first_time) {
  276. first_time = false;
  277. } else {
  278. type_and_formats_stream << "|";
  279. }
  280. type_and_formats_stream << data_type_and_format;
  281. }
  282. GELOGW(
  283. "trans_op type size for var Node(%s) is over 1, Currently not"
  284. " supported, dataTypeAndFormats is %s.",
  285. var_node->GetName().c_str(), type_and_formats_stream.str().c_str());
  286. return SUCCESS;
  287. }
  288. int tran_in_index = TransOpUtil::GetTransOpDataIndex(out_node->GetType());
  289. auto out_op_desc = out_node->GetOpDesc();
  290. GE_CHECK_NOTNULL(out_op_desc);
  291. TransNodeInfo trans_node_info;
  292. trans_node_info.node_type = out_node->GetType();
  293. trans_node_info.input = out_op_desc->GetInputDesc(tran_in_index);
  294. trans_node_info.output = out_op_desc->GetOutputDesc(kTransOpOutIndex);
  295. if (!IsTransSupport(trans_node_info)) {
  296. GELOGD("The trans node %s does not support, skip the variable accelerating", trans_node_info.node_type.c_str());
  297. return SUCCESS;
  298. }
  299. is_matched = true;
  300. fusion_road.emplace_back(trans_node_info);
  301. return SUCCESS;
  302. }
  303. Status VariableOpPass::CheckVariableRefLegally(const ge::NodePtr &var_node, bool &is_var_ref_legally) {
  304. is_var_ref_legally = true;
  305. GE_CHECK_NOTNULL(var_node);
  306. auto iterator = var_and_var_ref_map_.find(var_node);
  307. if (iterator == var_and_var_ref_map_.end()) {
  308. GELOGD("var name %s are not in var var_ref map", var_node->GetName().c_str());
  309. return SUCCESS;
  310. }
  311. GELOGD("var name %s, ref var count %zu.", var_node->GetName().c_str(), iterator->second.size());
  312. for (const auto &var_ref_node : iterator->second) {
  313. if (CheckVarAndVarRefAreAlike(var_node, var_ref_node, is_var_ref_legally) != SUCCESS) {
  314. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  315. }
  316. GELOGD("is_var_ref_legally is %d", is_var_ref_legally);
  317. if (!is_var_ref_legally) {
  318. return SUCCESS;
  319. }
  320. }
  321. return SUCCESS;
  322. }
  323. Status VariableOpPass::UpdateVarAndRefOutputFormatInfo(const GeTensorDesc &final_output, const ge::NodePtr &node) {
  324. if (node == nullptr || node->GetOpDesc() == nullptr) {
  325. GELOGE(FAILED, "node or opdesc is nullptr");
  326. return FAILED;
  327. }
  328. const Format &format = final_output.GetFormat();
  329. const DataType &data_type = final_output.GetDataType();
  330. const GeShape &shape = final_output.GetShape();
  331. GELOGD("last ref is (%s, %s, %lu), var_ref_name is %s.", TypeUtils::DataTypeToSerialString(data_type).c_str(),
  332. TypeUtils::FormatToSerialString(format).c_str(), shape.GetDims().size(), node->GetName().c_str());
  333. auto node_desc = node->GetOpDesc()->GetOutputDesc(0);
  334. CopyVariableFormatDataTypeAndShape(final_output, node_desc);
  335. if (node->GetOpDesc()->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  336. GELOGE(FAILED, "update output desc fail.");
  337. return FAILED;
  338. }
  339. GELOGD("node ref is (%s, %s, %lu), var_ref_name is %s.",
  340. TypeUtils::DataTypeToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetDataType()).c_str(),
  341. TypeUtils::FormatToSerialString(node->GetOpDesc()->GetOutputDesc(0).GetFormat()).c_str(),
  342. node->GetOpDesc()->GetOutputDesc(0).GetShape().GetDims().size(), node->GetName().c_str());
  343. auto iterator = var_and_var_ref_map_.find(node);
  344. if (iterator == var_and_var_ref_map_.end()) {
  345. auto graph = node->GetOwnerComputeGraph();
  346. if (GenerateVariableVariableRefMap(graph) != SUCCESS) {
  347. GELOGE(INTERNAL_ERROR, "Failed to generate variable map for graph %s", graph->GetName().c_str());
  348. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  349. }
  350. }
  351. iterator = var_and_var_ref_map_.find(node);
  352. if (iterator == var_and_var_ref_map_.end()) {
  353. GELOGW("The var node %s which belongs to graph %s can not be found on the graph", node->GetName().c_str(),
  354. node->GetOwnerComputeGraph()->GetName().c_str());
  355. return SUCCESS;
  356. }
  357. for (const auto &var_ref_node : iterator->second) {
  358. auto var_ref_node_description = var_ref_node->GetOpDesc();
  359. GE_CHECK_NOTNULL(var_ref_node_description);
  360. GELOGD("var_ref_node before is (%s, %s, %zu), var_ref_name is %s.",
  361. TypeUtils::DataTypeToSerialString(data_type).c_str(), TypeUtils::FormatToSerialString(format).c_str(),
  362. shape.GetDims().size(), var_ref_node->GetName().c_str());
  363. if (var_ref_node_description->UpdateOutputDesc(0, node_desc) != GRAPH_SUCCESS) {
  364. GELOGW("UpdateOutputDesc fail.");
  365. }
  366. if (var_ref_node_description->UpdateInputDesc(0, node_desc) != GRAPH_SUCCESS) {
  367. GELOGW("UpdateInputDesc fail.");
  368. }
  369. const auto &input_desc = var_ref_node_description->MutableInputDesc(0);
  370. const auto &output_desc = var_ref_node_description->MutableOutputDesc(0);
  371. GE_CHECK_NOTNULL(input_desc);
  372. GE_CHECK_NOTNULL(output_desc);
  373. GELOGD("var_ref_node ref is (%s, %s, %zu), var_ref_name is %s.",
  374. TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str(),
  375. TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), output_desc->GetShape().GetDims().size(),
  376. var_ref_node->GetName().c_str());
  377. }
  378. return SUCCESS;
  379. }
  380. Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &compute_graph) {
  381. std::map<std::string, NodePtr> names_to_var;
  382. std::map<std::string, std::set<NodePtr>> names_to_refs;
  383. GE_CHECK_NOTNULL(compute_graph);
  384. for (auto &node : compute_graph->GetDirectNode()) {
  385. if (node->GetType() != VARIABLE) {
  386. continue;
  387. }
  388. std::string ref_var_name;
  389. if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) {
  390. names_to_var[node->GetName()] = node;
  391. } else {
  392. names_to_refs[ref_var_name].insert(node);
  393. }
  394. }
  395. for (auto &name_to_var : names_to_var) {
  396. var_and_var_ref_map_[name_to_var.second] = names_to_refs[name_to_var.first];
  397. }
  398. return SUCCESS;
  399. }
  400. Status VariableOpPass::CheckVarAndVarRefAreAlike(const NodePtr &var_node, const NodePtr &var_ref_node,
  401. bool &is_var_and_variable_ref_are_alike) {
  402. GE_CHECK_NOTNULL(var_node);
  403. GE_CHECK_NOTNULL(var_ref_node);
  404. GELOGD("var_node GetOutDataNodes. name is %s.", var_node->GetName().c_str());
  405. const auto &var_node_trans_nodes = var_node->GetOutDataNodes();
  406. GELOGD("var_node_trans_nodes size is %zu.", var_node_trans_nodes.size());
  407. GELOGD("var_ref_node GetOutDataNodes. name is %s.", var_ref_node->GetName().c_str());
  408. const auto &var_ref_node_trans_nodes = var_ref_node->GetInDataNodes();
  409. GELOGD("var_ref_node_trans_nodes size is %zu.", var_ref_node_trans_nodes.size());
  410. if (var_ref_node_trans_nodes.size() > 1) {
  411. GELOGE(GE_GRAPH_VARIABLE_OP_PASS_FAILED, "var_ref_node_trans_nodes.size() > 1.");
  412. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  413. }
  414. const auto &var_node_trans_node = var_node_trans_nodes.at(0);
  415. const auto &var_ref_node_trans_node = var_ref_node_trans_nodes.at(0);
  416. if (CheckTransNodeAreInverse(var_node_trans_node, var_ref_node_trans_node, is_var_and_variable_ref_are_alike) !=
  417. SUCCESS) {
  418. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  419. }
  420. return SUCCESS;
  421. }
  422. Status VariableOpPass::CheckTransNodeAreInverse(const NodePtr &node_a, const NodePtr &node_b, bool &is_same) {
  423. GELOGD("In CheckTransNodeAreInverse.");
  424. GE_CHECK_NOTNULL(node_a);
  425. GE_CHECK_NOTNULL(node_b);
  426. const auto &node_a_op_desc = node_a->GetOpDesc();
  427. const auto &node_b_op_desc = node_b->GetOpDesc();
  428. GE_CHECK_NOTNULL(node_a_op_desc);
  429. GE_CHECK_NOTNULL(node_b_op_desc);
  430. const auto &node_a_out_op_desc = node_a_op_desc->MutableOutputDesc(0);
  431. const auto &node_a_in_op_desc = node_a_op_desc->MutableInputDesc(0);
  432. GE_CHECK_NOTNULL(node_a_out_op_desc);
  433. GE_CHECK_NOTNULL(node_a_in_op_desc);
  434. const auto &node_b_out_op_desc = node_b_op_desc->MutableOutputDesc(0);
  435. const auto &node_b_in_op_desc = node_b_op_desc->MutableInputDesc(0);
  436. GE_CHECK_NOTNULL(node_b_out_op_desc);
  437. GE_CHECK_NOTNULL(node_b_in_op_desc);
  438. is_same = IsOpDescSame(node_a_out_op_desc, node_b_in_op_desc) && IsOpDescSame(node_b_out_op_desc, node_a_in_op_desc);
  439. return SUCCESS;
  440. }
  441. bool VariableOpPass::IsOpDescSame(const GeTensorDescPtr &op_desc_a, const GeTensorDescPtr &op_desc_b) {
  442. const auto &format_a = op_desc_a->GetFormat();
  443. const auto &type_a = op_desc_a->GetDataType();
  444. const auto &shape_a = op_desc_a->GetShape();
  445. const auto &format_b = op_desc_b->GetFormat();
  446. const auto &type_b = op_desc_b->GetDataType();
  447. const auto &shape_b = op_desc_b->GetShape();
  448. const auto &dims_a = shape_a.GetDims();
  449. const auto &dims_b = shape_b.GetDims();
  450. GELOGD("(format, data type, shape) = (%s, %s, %zu) (%s, %s, %zu)", TypeUtils::FormatToSerialString(format_a).c_str(),
  451. TypeUtils::DataTypeToSerialString(type_a).c_str(), dims_a.size(),
  452. TypeUtils::FormatToSerialString(format_b).c_str(), TypeUtils::DataTypeToSerialString(type_b).c_str(),
  453. dims_b.size());
  454. return (format_a == format_b) && (type_a == type_b) && (dims_a == dims_b);
  455. }
  456. void VariableOpPass::CopyVariableFormatDataTypeAndShape(const GeTensorDesc &src_tensor_desc,
  457. GeTensorDesc &dst_tensor_desc) {
  458. dst_tensor_desc.SetShape(src_tensor_desc.GetShape());
  459. dst_tensor_desc.SetFormat(src_tensor_desc.GetFormat());
  460. dst_tensor_desc.SetDataType(src_tensor_desc.GetDataType());
  461. }
  462. Status VariableOpPass::CheckIfCouldBeOptimized(const ge::NodePtr &node, bool &flag, VarTransRoad &fusion_road) {
  463. if (node == nullptr) {
  464. return FAILED;
  465. }
  466. bool is_matched = false;
  467. auto ret = CheckSameAndTransOp(node, is_matched, fusion_road);
  468. if (ret != SUCCESS) {
  469. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  470. }
  471. if (!is_matched) {
  472. flag = false;
  473. return SUCCESS;
  474. }
  475. bool is_var_ref_legally = false;
  476. ret = CheckVariableRefLegally(node, is_var_ref_legally);
  477. if (ret != SUCCESS) {
  478. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  479. }
  480. GELOGD("is_var_ref_legally is %d.", is_var_ref_legally);
  481. if (!is_var_ref_legally) {
  482. GELOGI("variable ref connection are illegally");
  483. flag = false;
  484. fusion_road.clear();
  485. return SUCCESS;
  486. }
  487. flag = true;
  488. GELOGD("node %s, is_matched = %d is_var_ref_legally = %d, flag = %d", node->GetName().c_str(), is_matched,
  489. is_var_ref_legally, flag);
  490. return SUCCESS;
  491. }
  492. Status VariableOpPass::FusionIfNeed(const NodePtr &var, VarTransRoad &fusion_road) {
  493. bool can_fusion = false;
  494. while (true) {
  495. auto ret = CheckIfCouldBeOptimized(var, can_fusion, fusion_road);
  496. if (ret != SUCCESS) {
  497. return ret;
  498. }
  499. if (!can_fusion) {
  500. break;
  501. }
  502. ret = DealFusion(var);
  503. if (ret != SUCCESS) {
  504. return ret;
  505. }
  506. }
  507. return SUCCESS;
  508. }
  509. Status VariableOpPass::UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set<NodePtr> &nodes) {
  510. for (auto &need_set_node : nodes) {
  511. auto ret = UpdateVarAndRefOutputFormatInfo(final_output, need_set_node);
  512. if (ret != SUCCESS) {
  513. return GE_GRAPH_VARIABLE_OP_PASS_FAILED;
  514. }
  515. }
  516. return SUCCESS;
  517. }
  518. Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) {
  519. GE_CHECK_NOTNULL(graph);
  520. // renew var manager desc
  521. Status ret = SUCCESS;
  522. for (auto &node : graph->GetDirectNode()) {
  523. bool is_var_node =
  524. (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP);
  525. if (is_var_node) {
  526. if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) {
  527. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  528. continue;
  529. }
  530. GELOGD("var manager exist var node[%s], graph name[%s]", node->GetName().c_str(), graph->GetName().c_str());
  531. GE_CHECK_NOTNULL(node->GetOpDesc());
  532. ret = ge::VarManager::Instance(graph->GetSessionID())->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  533. if (ret != SUCCESS) {
  534. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  535. return FAILED;
  536. }
  537. }
  538. }
  539. return SUCCESS;
  540. }
  541. Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) {
  542. // renew var desc if the trans_road is all reshape or reformat
  543. for (auto &road : fusion_road) {
  544. if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
  545. return SUCCESS;
  546. }
  547. }
  548. if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) {
  549. GELOGD("var manager does not exist var node[%s]", node->GetName().c_str());
  550. return SUCCESS;
  551. }
  552. GELOGD("var manager exist var node[%s]", node->GetName().c_str());
  553. GE_CHECK_NOTNULL(node->GetOpDesc());
  554. Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc());
  555. if (ret != SUCCESS) {
  556. GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str());
  557. return FAILED;
  558. }
  559. return SUCCESS;
  560. }
  561. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示