
shape_refiner.cc 31 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/shape_refiner.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "graph/operator_factory_impl.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"

namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;
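// When a While loop's output shape diverges from its input shape, the static shapes
// recorded in the body subgraph are no longer trustworthy. This pass "reverse brushes"
// (back-propagates) unknown shapes through the body: every node that consumes data
// (plus Data nodes) gets unknown-rank tensors, then source nodes (Const/Variable etc.)
// push their known output shapes forward to their direct consumers.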
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
  GELOGD("Enter reverse brush while body subgraph process!");

  auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
  if (sub_graph_body == nullptr) {
    GELOGE(GRAPH_FAILED, "Get while body graph failed!");
    return GRAPH_FAILED;
  }

  for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
    // Const/Constant/Variable etc. have no data inputs and need no reverse brush.
    if (node_sub->GetInDataNodes().empty() && node_sub->GetType() != DATA) {
      continue;
    }
    for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
      auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
      GE_IF_BOOL_EXEC(input_desc == nullptr,
                      GELOGW("Get null input by index %zu from node %s", i, node_sub->GetName().c_str());
                      continue);
      (void)input_desc->SetUnknownDimNumShape();
    }
    for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
      auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
      GE_IF_BOOL_EXEC(output_desc == nullptr,
                      GELOGW("Get null output by index %zu from node %s", i, node_sub->GetName().c_str());
                      continue);
      (void)output_desc->SetUnknownDimNumShape();
    }
  }

  for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
    if (!node_sub->GetInDataNodes().empty() || node_sub->GetType() == DATA) {
      continue;
    }
    for (const auto &out_data_anchor : node_sub->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_data_anchor);
      auto out_data_anchor_idx = out_data_anchor->GetIdx();
      for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        GE_CHECK_NOTNULL(peer_in_data_anchor);
        auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(peer_in_data_node);
        GE_CHECK_NOTNULL(peer_in_data_node->GetOpDesc());
        int idx = peer_in_data_anchor->GetIdx();
        auto shape = node_sub->GetOpDesc()->MutableOutputDesc(out_data_anchor_idx)->GetShape();
        peer_in_data_node->GetOpDesc()->MutableInputDesc(idx)->SetShape(shape);
      }
    }
  }
  return GRAPH_SUCCESS;
}
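// For multi-batch cases (ATTR_NAME_BATCH_NUM set), each branch subgraph reports one
// candidate desc per parent output. The parent keeps the candidate with the largest
// element count, so the chosen desc covers every batch variant; mixed dtypes are rejected.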
graphStatus UpdateOutputForMultiBatch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  // Check sub-graph shapes and keep the max-size one for the update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }

    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }

      auto shape = tensor.MutableShape();
      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        // Skip unknown (-1) or zero dims so the overflow check below never divides by a
        // non-positive value.
        if (dim <= 0) {
          continue;
        }
        if (INT64_MAX / dim < size) {
          GELOGE(PARAM_INVALID, "The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }

      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }

    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
  }
  return GRAPH_SUCCESS;
}
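// Merges the output tensors reported by branch subgraphs (If/Case style ops) back onto
// the parent node. If branch ranks disagree the output becomes unknown-rank; if only
// some dims disagree, just those dims become unknown (-1).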
graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors);
  }

  // Check sub-graph shapes; if they differ, fall back to unknown-shape processing.
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i: %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }
      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i: %zu, j: %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }
  return GRAPH_SUCCESS;
}
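// For While, the i-th output must keep exactly the shape of the i-th input across
// iterations. If any pair disagrees, the output is demoted to unknown-rank and the body
// subgraph is reverse-brushed so later passes see consistent unknown shapes.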
graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node,
                                     std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] do not match!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }
  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      GELOGE(GRAPH_FAILED, "while op, every output should match exactly one output tensor across all subgraphs!");
      return GRAPH_FAILED;
    }
  }

  bool is_need_reverse_brush = false;
  // Check input against output.
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }
    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto tmp_shape = ref_out_tensor.MutableShape();
    // The data and output tensor shapes of ref_i should be the same.
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str());
        return GRAPH_FAILED;
      }
      auto shape = tensor.MutableShape();
      if (shape.GetDims() != tmp_shape.GetDims()) {
        ref_out_tensor.SetUnknownDimNumShape();
        is_need_reverse_brush = true;
        break;
      }
    }
    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }

  // Reverse-refresh the while body shapes.
  if (is_need_reverse_brush) {
    return ReverseBrushWhileBodySubGraph(node);
  }
  return GRAPH_SUCCESS;
}
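// Before a subgraph is inferred, its Data nodes must see the shapes currently on the
// parent node's inputs. Each Data node carries ATTR_NAME_PARENT_NODE_INDEX naming the
// parent input it mirrors; its own input and output descs are overwritten from there.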
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }

    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Cannot find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }
      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc",
                name.c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute",
                name.c_str(), node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        continue;
      }
      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        GE_LOGE("The ref index(%d) on the data %s on the sub graph %s "
                "parent node %s is out of range, inputs num %u",
                ref_i, node_sub->GetName().c_str(), name.c_str(),
                node->GetName().c_str(), node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s",
             ref_i, input_desc->GetDataType(), node->GetName().c_str());

      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
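// Walks a subgraph backwards, remembering its NETOUTPUT node and collecting the output
// desc of every Data node into ref_data_tensors, bucketed by the parent input index the
// Data node refers to.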
graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph,
                                         NodePtr &netoutput, const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }
      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %u)!",
               sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
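// After a subgraph has been inferred, the NETOUTPUT input descs tagged with
// ATTR_NAME_PARENT_NODE_INDEX are gathered per parent output index and merged back onto
// the parent node: While gets the input/output consistency check, branch ops get the
// shape-merge logic above.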
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }
    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Cannot find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it",
              name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }
    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d",
                name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu",
             edge_anchor->GetIdx(), edge_desc->GetShape().GetDimNum());
      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // If there is no ref index on the TensorDesc, the output data is ignored outside.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }

  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors);
}
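// Helpers that render shapes, shape ranges, dtypes and formats as short strings for the
// debug logs below.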
string Serial(const vector<int64_t> &dims) {
  string serial_string;
  serial_string += "[";
  for (int64_t dim : dims) {
    serial_string += std::to_string(dim) + " ";
  }
  serial_string += "]";
  return serial_string;
}

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
  desc_str += "[";
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  (void)desc->GetShapeRange(shape_range);
  for (const auto &pair : shape_range) {
    desc_str += "{";
    desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
    desc_str += "},";
  }
  desc_str += "] ";
}

void SerialShapeAndDtype(const GeTensorDescPtr &desc, bool is_origin_info, std::string &desc_str) {
  desc_str += "[";
  if (!is_origin_info) {
    for (int64_t dim : desc->GetShape().GetDims()) {
      desc_str += std::to_string(dim) + " ";
    }
    desc_str += "]";
    desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetDataType()) + ":" +
                TypeUtils::FormatToSerialString(desc->GetFormat()) + " ";
  } else {
    for (int64_t dim : desc->GetOriginShape().GetDims()) {
      desc_str += std::to_string(dim) + " ";
    }
    desc_str += "]";
    desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetOriginDataType()) + ":" +
                TypeUtils::FormatToSerialString(desc->GetOriginFormat()) + " ";
  }
}
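// Copies each input desc of the node from the connected peer output desc (shape, origin
// shape, dtype, shape range, real dim count). Mismatches between the recorded input and
// the peer output are logged as warnings but never stop the refresh.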
graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
  GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);

  for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
    auto in_idx = in_anchor->GetIdx();
    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_data_anchor == nullptr) {
      continue;
    }
    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
    if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
      continue;
    }
    int peer_out_idx = peer_out_data_anchor->GetIdx();
    auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
    if (peer_out_desc == nullptr) {
      continue;
    }

    // Check shape and dtype continuity; log but do not stop processing.
    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
    if (in_desc == nullptr) {
      continue;
    }
    auto in_shape = in_desc->MutableShape().GetDims();
    auto in_dtype = in_desc->GetDataType();
    auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
    auto peer_out_dtype = peer_out_desc->GetDataType();
    if (peer_out_dtype != in_dtype) {
      GELOGW("current node [%s] [%d]\'th in_dtype is [%s]. peer output node [%s] [%d]\'th "
             "output_dtype is [%s]. The two dtypes should be the same! Please check the graph and fix it",
             node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
             peer_out_data_node->GetName().c_str(), peer_out_idx,
             TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
    } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
      string in_shape_str = Serial(in_shape);
      string peer_out_shape_str = Serial(peer_out_shape);
      GELOGW("current node [%s] [%d]\'th in_shape is [%s]. peer output node [%s] [%d]\'th "
             "output_shape is [%s]. The two shapes should be the same! Please check the graph and fix it",
             node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(),
             peer_out_data_node->GetName().c_str(), peer_out_idx, peer_out_shape_str.c_str());
    }

    // Refresh the current node's input desc.
    in_desc->SetOriginShape(peer_out_desc->GetOriginShape());
    in_desc->SetShape(peer_out_desc->MutableShape());
    in_desc->SetDataType(peer_out_desc->GetDataType());
    in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType());
    if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) {
      std::vector<std::pair<int64_t, int64_t>> shape_range;
      (void)peer_out_desc->GetShapeRange(shape_range);
      in_desc->SetShapeRange(shape_range);
    }
    ge::TensorUtils::SetRealDimCnt(*in_desc,
                                   static_cast<uint32_t>(peer_out_desc->MutableShape().GetDims().size()));
  }
  return GRAPH_SUCCESS;
}
  441. } // namespace
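// Dumps every input/output shape, origin shape and shape range of a node at debug level;
// `phase` tags whether the dump was taken before or after infer-shape.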
void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return);

  std::string str;
  if (op_desc->GetInputsSize() != 0) {
    std::string input_desc_str = "input shape: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      SerialShapeAndDtype(input_desc, false, input_desc_str);
    }
    str += input_desc_str;

    input_desc_str = "input origin shape: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      SerialShapeAndDtype(input_desc, true, input_desc_str);
    }
    str += input_desc_str;

    input_desc_str = "input shape range: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      SerialShapeRange(input_desc, input_desc_str);
    }
    str += input_desc_str;
  }

  if (op_desc->GetAllOutputsDescSize() != 0) {
    std::string output_desc_str = "output shape: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      if (output_desc == nullptr) {
        continue;
      }
      SerialShapeAndDtype(output_desc, false, output_desc_str);
    }
    str += output_desc_str;

    output_desc_str = "output origin shape: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      if (output_desc == nullptr) {
        continue;
      }
      SerialShapeAndDtype(output_desc, true, output_desc_str);
    }
    str += output_desc_str;

    output_desc_str = "output shape range: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      SerialShapeRange(output_desc, output_desc_str);
    }
    str += output_desc_str;
  }
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
}
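// Builds a fresh InferenceContext for a node, seeding it with the handle shapes/types
// and marks published by its upstream producers (looked up in context_map), i.e. the
// metadata that flows alongside ordinary tensor shapes.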
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
                                           const NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return nullptr;
  }
  InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  if (inference_context == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext");
    return nullptr;
  }

  auto all_in_data_anchors = node->GetAllInDataAnchors();
  std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  std::vector<std::string> marks;

  bool has_input_shapes_and_types = false;
  for (const auto &in_anchor : all_in_data_anchors) {
    const auto &out_anchor = in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }
    auto input_node = out_anchor->GetOwnerNode();
    if (input_node == nullptr) {
      continue;
    }
    auto iter = context_map.find(input_node);
    if (iter != context_map.end()) {
      const auto &src_context = iter->second;
      GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr);
      GELOGD("node:%s get %zu marks from node:%s",
             node->GetName().c_str(), src_context->GetMarks().size(), input_node->GetName().c_str());
      for (const auto &mark : src_context->GetMarks()) {
        marks.push_back(mark);
      }
      auto output_idx = out_anchor->GetIdx();
      auto input_idx = in_anchor->GetIdx();
      auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
      if (output_idx < static_cast<int>(output_shape_and_type.size())) {
        GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
               node->GetName().c_str(), input_idx);
        input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
        has_input_shapes_and_types = true;
      } else {
        GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
               output_shape_and_type.size());
      }
    }
  }

  if (has_input_shapes_and_types) {
    inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  }
  inference_context->SetMarks(marks);
  return inference_context;
}
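// One inference-context map per thread: contexts recorded during InferShapeAndType feed
// CreateInferenceContext for downstream nodes, and ClearContextMap resets the cache
// between graphs.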
namespace {
thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
void ShapeRefiner::ClearContextMap() {
  context_map.clear();
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
  return InferShapeAndType(node, op, true);
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();

  graphStatus ret;
  if (before_subgraph) {
    ret = UpdateSubGraphDataNodes(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  // Get the infer func and execute it.
  ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // The op IR has no infer func; try to get one from the operator factory.
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }

    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    node_op.BreakConnect();
    if (temp_op_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      // If the first output already carries a shape, keep it and report success.
      for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          break;
        }
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }

  if (!before_subgraph) {
    return UpdateParentNodeOutTensor(node);
  }
  return GRAPH_SUCCESS;
}
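// Runtime variant of the infer call: when the registered infer func is missing, it falls
// back to the factory func looked up by the node's original type instead of rebuilding a
// temporary operator.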
graphStatus ShapeRefiner::InferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();

  graphStatus ret;
  if (before_subgraph) {
    ret = UpdateSubGraphDataNodes(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  // Get the infer func and execute it.
  ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    GELOGD("NodeUtils::GetNodeType return value is: [%s]", NodeUtils::GetNodeType(*node).c_str());
    auto origin_type = NodeUtils::GetNodeType(*node);
    auto infer_func = ge::OperatorFactoryImpl::GetInferShapeFunc(origin_type);
    if (infer_func == nullptr) {
      GELOGE(GRAPH_FAILED, "Failed to Get InferFunc. type is %s", origin_type.c_str());
      return GRAPH_FAILED;
    }
    op_desc->AddInferFunc(infer_func);
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }

  if (!before_subgraph) {
    return UpdateParentNodeOutTensor(node);
  }
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
  return InferShapeAndType(node, true);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus ShapeRefiner::InferShapeAndTypeForRunning(const NodePtr &node, bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  auto opdesc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);

  PrintInOutTensorShape(node, "before_infershape when running");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);

  graphStatus status = InferShapeAndTypeForRunning(node, op, before_subgraph);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    PrintInOutTensorShape(node, "after_infershape when running");
    return GRAPH_SUCCESS;
  } else {
    GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);

  // Some ops, such as AIPP, cannot run infer-shape twice.
  bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
  if (need_update_input) {
    auto status = UpdateOpInputDesc(node);
    if (status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "update op input_desc failed!");
      return status;
    }
  }

  if (node->Verify() != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  PrintInOutTensorShape(node, "before_infershape");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);

  if (!is_unknown_graph) {
    auto inference_context = CreateInferenceContext(context_map, node);
    GE_CHECK_NOTNULL(inference_context);
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }

  graphStatus status = InferShapeAndType(node, op, before_subgraph);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    if (is_unknown_graph) {
      PrintInOutTensorShape(node, "after_infershape when running");
      return GRAPH_SUCCESS;
    }
    auto op_desc = node->GetOpDesc();
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
      GE_IF_BOOL_EXEC(output_tensor == nullptr, continue);
      if (output_tensor->MutableShape().GetDims().empty()) {
        output_tensor->SetOriginShape(output_tensor->GetShape());
      }
      ge::TensorUtils::SetRealDimCnt(*output_tensor,
                                     static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
      output_tensor->SetOriginDataType(output_tensor->GetDataType());
      GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
             node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
             TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
             TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
    }
    GE_CHK_STATUS_RET_NOLOG(NodeUtils::UpdatePeerNodeInputDesc(node));
  } else {
    GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        (void)context_map.emplace(node, ctx_after_infer);
      }
    }
  }
  PrintInOutTensorShape(node, "after_infershape");
  return GRAPH_SUCCESS;
}
}  // namespace ge
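For orientation, here is a minimal sketch of how a caller might drive this refiner over a whole graph. It is an illustration only, not part of this file: `RefineGraphShapes` is a hypothetical helper, and it assumes a `ComputeGraphPtr` whose nodes can be visited in topological order via `TopologicalSorting()` and `GetDirectNode()`.

// Hypothetical driver, assuming graph is a valid ge::ComputeGraphPtr.
#include "graph/compute_graph.h"
#include "graph/shape_refiner.h"

namespace ge {
graphStatus RefineGraphShapes(const ComputeGraphPtr &graph) {
  // Drop inference contexts cached from a previous graph.
  ShapeRefiner::ClearContextMap();
  // Sort so every node is inferred after its producers.
  if (graph->TopologicalSorting() != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  for (const auto &node : graph->GetDirectNode()) {
    graphStatus ret = ShapeRefiner::InferShapeAndType(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace ge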

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as a bridge between them: GE takes the graph issued by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that runs efficiently on the underlying hardware. GE is specifically optimized for the hardware architecture of the Ascend AI processor to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture diagram is shown below.