
infer_base_pass.cc 30 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "infer_base_pass.h"
#include "common/ge/ge_util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_util.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
namespace ge {
namespace {
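// Local helpers: serialize dims, shape ranges and value ranges into
// human-readable strings for the debug logs emitted by this pass.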
string Serial(const vector<int64_t> &dims) {
  string serial_string;
  serial_string += "[";
  for (int64_t dim : dims) {
    serial_string += std::to_string(dim) + " ";
  }
  serial_string += "]";
  return serial_string;
}
void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
  shape_range.clear();
  (void)desc->GetOriginShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += ",{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
}
void SerialValueRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> value_range;
  (void)desc->GetValueRange(value_range);
  for (const auto &pair : value_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "]";
}
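// Walks a subgraph (in reverse direct-node order) to locate its NetOutput node
// and to collect the output desc of every Data node, bucketed by the Data
// node's parent-node input index (ATTR_NAME_PARENT_NODE_INDEX).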
graphStatus FindSubgraphDataAndNetoutput(const ComputeGraphPtr &sub_graph, NodePtr &netoutput, const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        REPORT_INNER_ERROR("E19999", "data node[%s]'s ref index[%d] is not in range [0, %u)!",
                           sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!",
               sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace
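// Entry point of the pass. A template method: subclasses customize NeedInfer,
// Infer, UpdatePeerInputs and the repass hooks; the driving logic stays fixed.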
Status InferBasePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  bool need_infer = NeedInfer(node);
  if (!need_infer) {
    GELOGD("Node %s does not need to infer.", node->GetName().c_str());
    return SUCCESS;
  }
  std::set<NodePtr> changed_nodes;
  auto ret = InferAndUpdate(node, !OptionExists(kOptimizeAfterSubGraph), changed_nodes);
  if (ret != GRAPH_SUCCESS) {
    (void)AnalyzeFailedInfo(node);
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  /*
   * we will use changed nodes to do repass for control_ops.
   * AddChangedNodesImmediateRepass(changed_nodes);
   */
  auto status = DoRepassForLoopNode(node);
  if (status != SUCCESS) {
    GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "repass failed. node: %s", node->GetName().c_str());
    return GE_GRAPH_INFERSHAPE_FAILED;
  }
  return SUCCESS;
}
bool InferBasePass::NeedInfer(const NodePtr &node) { return true; }
void InferBasePass::AnalyzeFailedInfo(const NodePtr &node) { /* Analyze and select failed info */ }
Status InferBasePass::DoRepassForLoopNode(NodePtr &node) { return SUCCESS; }
graphStatus InferBasePass::UpdatePeerInputs(NodePtr &node) { return GRAPH_SUCCESS; }
void InferBasePass::AddChangedNodesImmediateRepass(std::set<NodePtr> &changed_nodes) {
  for (const auto &node_ele : changed_nodes) {
    AddImmediateRePassNode(node_ele);
  }
}
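// Core flow: refresh the node's input descs from its peers, push the descs
// into subgraph Data nodes (before-subgraph case), run the actual Infer, pull
// the results back up to the parent node (after-subgraph case), then propagate
// the new output descs to the peer inputs.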
graphStatus InferBasePass::InferAndUpdate(NodePtr &node, bool before_subgraph, std::set<NodePtr> &changed_nodes) {
  auto ret = GRAPH_SUCCESS;
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  // Some ops, such as AIPP, cannot infer shape twice.
  bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
  if (need_update_input) {
    ret = UpdateCurOpInputDesc(node);
    if (ret != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "update op input_desc failed! ret:%d, node:%s", ret, node->GetName().c_str());
      GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%d", ret);
      return ret;
    }
  }
  bool contain_subgraph = ContainsSubgraph(node);
  if (contain_subgraph && before_subgraph) {
    ret = UpdateTensorDescToSubgraphData(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  ret = Infer(node);
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }
  if (contain_subgraph && !before_subgraph) {
    ret = UpdateTensorDescToParentNode(node, changed_nodes);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  ret = UpdatePeerInputs(node);
  return ret;
}
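// Copies each peer output desc onto the corresponding input desc of the
// current node. Mismatched dtypes or shapes only produce warnings and do not
// stop the pass.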
graphStatus InferBasePass::UpdateCurOpInputDesc(const NodePtr &node_ptr) {
  for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
    auto in_idx = in_anchor->GetIdx();
    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_data_anchor == nullptr) {
      continue;
    }
    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
    if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
      continue;
    }
    int peer_out_idx = peer_out_data_anchor->GetIdx();
    auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
    // Check shape and dtype continuity; do not stop the process on mismatch.
    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
    if (in_desc == nullptr) {
      continue;
    }
    auto in_shape = in_desc->MutableShape().GetDims();
    auto in_dtype = in_desc->GetDataType();
    auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
    auto peer_out_dtype = peer_out_desc->GetDataType();
    if (peer_out_dtype != in_dtype) {
      GELOGW(
          "current node [%s] [%d]'th in_dtype is [%s]. peer output node [%s] [%d]'th "
          "output_dtype is [%s]. The two dtypes should be the same! Please check the graph and fix it",
          node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
          peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
    } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
      string in_shape_str = Serial(in_shape);
      string peer_out_shape_str = Serial(peer_out_shape);
      GELOGW(
          "current node [%s] [%d]'th in_shape is [%s]. peer output node [%s] [%d]'th "
          "output_shape is [%s]. The two shapes should be the same! Please check the graph and fix it",
          node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx,
          peer_out_shape_str.c_str());
    }
    // Refresh the current node's input desc.
    bool output_changed = false;
    (void)UpdateInputDescAttr(peer_out_desc, in_desc, output_changed);
  }
  return GRAPH_SUCCESS;
}
graphStatus InferBasePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  changed = false;
  return GRAPH_SUCCESS;
}
bool InferBasePass::ContainsSubgraph(const NodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return false;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  if (root_graph == nullptr) {
    return false;
  }
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph != nullptr) {
      return true;
    }
  }
  return false;
}
std::vector<ComputeGraphPtr> InferBasePass::GetCurNodeSubgraphs(const NodePtr &node) {
  std::vector<ComputeGraphPtr> cur_node_subgraph;
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return cur_node_subgraph;
  }
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      REPORT_INNER_ERROR("E19999", "Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][Graph] can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      continue;
    }
    cur_node_subgraph.emplace_back(sub_graph);
  }
  return cur_node_subgraph;
}
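// Pushes the parent node's input descs (or, when inferring again, its output
// descs) down onto the Data nodes of every subgraph, so the subgraph sees
// up-to-date shapes before its own inference runs.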
graphStatus InferBasePass::UpdateTensorDescToSubgraphData(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  // if infer again, update output of while into subgraph data node
  auto op_desc = node->GetOpDesc();
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      auto name = sub_graph->GetName();
      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                           node->GetName().c_str());
        GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        REPORT_INNER_ERROR("E19999", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                           name.c_str(), node->GetName().c_str());
        GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        continue;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "The ref index(%d) on the data %s on the sub graph %s "
                           "parent node %s are incompatible, inputs num %u",
                           ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                           node->GetAllInDataAnchorsSize());
        GE_LOGE(
            "[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s "
            "parent node %s are incompatible, inputs num %u",
            ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());
      // if need infer again, refresh subgraph input with output
      bool is_infer_again = false;
      AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again);
      if (is_infer_again) {
        input_desc = op_desc->MutableOutputDesc(ref_i);
        if (input_desc == nullptr) {
          REPORT_INNER_ERROR("E19999",
                             "The ref index(%d) on the data %s on the subgraph %s "
                             "parent node %s are incompatible, outputs num %u.",
                             ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                             node->GetAllOutDataAnchorsSize());
          GELOGE(PARAM_INVALID,
                 "[Call][MutableOutputDesc] The ref index(%d) on the data %s on the subgraph %s "
                 "parent node %s are incompatible, outputs num %u.",
                 ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
                 node->GetAllOutDataAnchorsSize());
          return GRAPH_FAILED;  // return here so the null desc is not dereferenced below
        }
        GELOGD("Update input desc of data %s on the sub graph %s of node %s, output idx: %d from [%s] to [%s]",
               node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), ref_i,
               data_opdesc->GetInputDescPtr(0)->GetShape().ToString().c_str(),
               input_desc->GetShape().ToString().c_str());
      }
      auto data_input_desc = data_opdesc->MutableInputDesc(0);
      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update input desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", node_sub->GetName().c_str(),
                name.c_str(), node->GetName().c_str());
        return ret;
      }
      bool input_changed = TensorDescChanged(input_desc, data_input_desc);
      auto data_output_desc = data_opdesc->MutableOutputDesc(0);
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Failed to update output desc of data %s on the sub graph %s parent node %s",
                          node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      bool output_changed = TensorDescChanged(input_desc, data_output_desc);
      if (input_changed || output_changed) {
        changed_nodes.insert(node_sub);
      }
    }
  }
  return GRAPH_SUCCESS;
}
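// Pulls the inferred descs back out of the subgraphs: collects the Data-node
// descs and the NetOutput input descs of every subgraph, then merges them into
// the parent node's outputs. While ops and branch ops merge differently.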
graphStatus InferBasePass::UpdateTensorDescToParentNode(NodePtr &node, std::set<NodePtr> &changed_nodes) {
  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  for (const auto &sub_graph : GetCurNodeSubgraphs(node)) {
    auto name = sub_graph->GetName();
    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      REPORT_INNER_ERROR("E19999", "No NetOutput node on sub graph %s, parent node %s", name.c_str(),
                         node->GetName().c_str());
      GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      REPORT_INNER_ERROR("E19999", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
                         name.c_str(), node->GetName().c_str());
      GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        REPORT_INNER_ERROR("E19999",
                           "Invalid NetOutput node on sub graph %s, parent node %s, "
                           "can not find input tensor %d",
                           name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
             edge_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // If there is no ref index on the TensorDesc, the output data will be ignored by the outer graph.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }
  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors, changed_nodes);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors, changed_nodes);
}
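// For While: each output must match the corresponding input shape. Any
// mismatching dimension becomes UNKNOWN_DIM with range {1, -1}; if a fixed
// input dimension changed after inference, the loop must be inferred again.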
graphStatus InferBasePass::UpdateParentNodeForWhile(NodePtr &node,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                                    std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                    std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    REPORT_INNER_ERROR("E19999", "op:%s(%s) input number[%zu] and output number[%zu] is not same!",
                       node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(),
                       ref_out_tensors.size());
    GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      REPORT_INNER_ERROR("E19999", "while op, every output should only find one output tensor in all graph!");
      GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!");
      return GRAPH_FAILED;
    }
  }
  bool need_infer_again = false;
  // check input and output
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto out_shape = ref_out_tensor.MutableShape();
    vector<std::pair<int64_t, int64_t>> data_shape_range;
    // ref_i's data and output tensor shape should be the same
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype or format among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto data_shape = tensor.MutableShape();
      // input is dynamic, here use dim_num
      if (data_shape.GetDims() != out_shape.GetDims()) {
        GELOGI("After infer, While %s %zu output shape [%s] does not match input shape [%s]. Need infer again.",
               node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str());
        if (data_shape.GetDimNum() != out_shape.GetDimNum()) {
          ref_out_tensor.SetUnknownDimNumShape();
        } else {
          for (size_t j = 0; j < data_shape.GetDimNum(); ++j) {
            if (data_shape.GetDim(j) != out_shape.GetDim(j)) {
              if (data_shape.GetDim(j) != UNKNOWN_DIM) {
                // if input data is fixed shape and the output differs, need_infer_again
                need_infer_again = true;
              }
              data_shape.SetDim(j, UNKNOWN_DIM);
            }
            // set shape range of while; if a dim is unknown, set its range as {1, -1}
            if (data_shape.GetDim(j) == UNKNOWN_DIM) {
              data_shape_range.emplace_back(std::make_pair(1, UNKNOWN_DIM));
            } else {
              data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j)));
            }
          }
          ref_out_tensor.SetShape(data_shape);
          ref_out_tensor.SetShapeRange(data_shape_range);
        }
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again);
  return GRAPH_SUCCESS;
}
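// For multi-batch branch ops: per output index, pick the candidate tensor with
// the largest element count among all subgraphs and write it back to the
// parent node, guarding the size product against int64 overflow.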
graphStatus InferBasePass::UpdateOutputForMultiBatch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  // check sub_graph shape. Get max for update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output",
                           node->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        if (dim != 0 && INT64_MAX / dim < size) {
          REPORT_INNER_ERROR("E19999", "The shape:%s size overflow, node:%s", shape.ToString().c_str(),
                             node->GetName().c_str());
          GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }
      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
    bool output_changed =
        TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensors[i].at(max_shape_index)), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
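// For branch ops (If/Case-like): if subgraph outputs disagree on rank, the
// parent output becomes unknown-rank; if they disagree on a dimension, that
// dimension becomes UNKNOWN_DIM.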
graphStatus InferBasePass::UpdateParentNodeForBranch(NodePtr &node,
                                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors,
                                                     std::set<NodePtr> &changed_nodes) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors, changed_nodes);
  }
  // check sub_graph shape. If not the same, do unknown-shape processing
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        REPORT_INNER_ERROR("E19999", "node[%s] does not support diff dtype among all ref output, shape:%s",
                           node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str());
        GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i : %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i : %zu, j: %zu, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(),
               i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    auto output_desc = node->GetOpDesc()->MutableOutputDesc(i);
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
    bool output_changed = TensorDescChanged(ComGraphMakeShared<GeTensorDesc>(ref_out_tensor), output_desc);
    if (output_changed) {
      changed_nodes.insert(node);
    }
  }
  return GRAPH_SUCCESS;
}
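// Debug helper: dumps shape/format/dtype (plus origin and range info) of all
// input and output descs of a node at the given phase.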
void InferBasePass::PrintInOutTensorShape(const NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    REPORT_INNER_ERROR("E19999", "param node is nullptr, check invalid");
    GELOGE(GRAPH_FAILED, "[Check][Param] node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, REPORT_INNER_ERROR("E19999", "node has no opdesc, check invalid");
                  GELOGE(GRAPH_FAILED, "[Get][OpDesc] op_desc is null."); return );
  std::stringstream ss;
  ss << "{";
  int32_t in_idx = 0;
  int32_t out_idx = 0;
  for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
    if (input_desc == nullptr) {
      in_idx++;
      continue;
    }
    if (in_idx > 0) {
      ss << " ";
    }
    ss << "input_" << in_idx << " "
       << "tensor: [";
    ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(input_desc, range_str);
    ss << "(shape_range:" << range_str << "),";
    string value_range_str;
    SerialValueRange(input_desc, value_range_str);
    ss << "(value_range:" << value_range_str << ")]";
    in_idx++;
  }
  for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
    if (output_desc == nullptr) {
      out_idx++;
      continue;
    }
    ss << " ";
    ss << "output_" << out_idx << " "
       << "tensor: [";
    ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),";
    ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),";
    ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),";
    ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),";
    ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),";
    ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),";
    string range_str;
    SerialShapeRange(output_desc, range_str);
    ss << "(shape_range:" << range_str << "),";
    string value_range_str;
    SerialValueRange(output_desc, value_range_str);
    ss << "(value_range:" << value_range_str << ")]";
    out_idx++;
  }
  ss << "}";
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), ss.str().c_str());
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware, bridging the two. GE takes the graph delivered by ME as input, performs a series of deep graph-optimization operations, and finally outputs a graph that can run efficiently on the underlying hardware. GE is specifically optimized for the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.