
shape_refiner.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/shape_refiner.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_log.h"
#include "debug/ge_op_types.h"
#include "external/graph/operator.h"
#include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h"
#include "graph/compute_graph.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"
namespace ge {
namespace {
const uint32_t kWhileBodySubGraphIdx = 1;
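
// When a While output shape turned out to differ from the corresponding input,
// the body subgraph must be re-inferred; mark every data tensor in the
// while-body subgraph as unknown-rank so the next pass recomputes them.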
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
  GELOGD("Enter reverse brush while body subgraph process!");

  auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
  if (sub_graph_body == nullptr) {
    GELOGE(GRAPH_FAILED, "Get while body graph failed!");
    return GRAPH_FAILED;
  }

  for (const auto &node_sub : sub_graph_body->GetAllNodes()) {
    for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) {
      auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i);
      GE_IF_BOOL_EXEC(input_desc == nullptr,
                      GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str());
                      continue);
      (void)input_desc->SetUnknownDimNumShape();
    }
    for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) {
      auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i);
      GE_IF_BOOL_EXEC(output_desc == nullptr,
                      GELOGW("Get null output by index %zu from node %s ", i, node_sub->GetName().c_str());
                      continue);
      (void)output_desc->SetUnknownDimNumShape();
    }
  }

  return GRAPH_SUCCESS;
}
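
// For nodes carrying the multi-batch attribute, pick, for every output index,
// the subgraph output tensor with the largest element count and write it back
// to the parent node's output desc.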
graphStatus UpdateOutputForMultiBatch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  // Check the subgraph shapes and keep the largest one for the update.
  for (size_t i = 0; i < ref_out_tensors.size(); ++i) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }

    int64_t max_size = 0;
    size_t max_shape_index = 0;
    auto &ref_out_tensor = ref_out_tensors[i].at(0);
    const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) {
      auto &tensor = ref_out_tensors[i].at(j);
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support different dtypes among outputs", node->GetName().c_str());
        return GRAPH_FAILED;
      }

      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGE(GRAPH_FAILED, "node is %s, i: %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        return GRAPH_FAILED;
      }

      int64_t size = 1;
      for (auto dim : shape.GetDims()) {
        if (dim != 0 && INT64_MAX / dim < size) {
          GELOGE(PARAM_INVALID, "The shape size overflow");
          return PARAM_INVALID;
        }
        size *= dim;
      }

      if (size > max_size) {
        max_size = size;
        max_shape_index = j;
      }
    }

    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index));
  }

  return GRAPH_SUCCESS;
}
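
// Writes the subgraph output shapes back to the parent node of a branch-like
// op. If the branches disagree on rank, the output becomes unknown-rank; if
// they disagree only on a dimension, that dimension becomes unknown (-1).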
graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node,
                                      std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class branch op process");
  if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) {
    return UpdateOutputForMultiBatch(node, ref_out_tensors);
  }

  // Check the subgraph shapes. If they are not all the same, do the unknown-shape process.
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }

    auto ref_out_tensor = ref_out_tensors[i].at(0);
    ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape();
    for (auto &tensor : ref_out_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support different dtypes among outputs", node->GetName().c_str());
        return GRAPH_FAILED;
      }

      auto shape = tensor.MutableShape();
      if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) {
        GELOGD("node is %s, i: %zu, shape size: %ld, ref_out_tensor_shape size: %ld", node->GetName().c_str(), i,
               shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        ref_out_tensor_shape = GeShape(UNKNOWN_RANK);
        break;
      }

      for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) {
        if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) {
          continue;
        }
        GELOGD("node is %s, i: %zu, j: %zu, shape size: %ld, ref_out_tensor_shape size: %ld",
               node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize());
        (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM);
      }
    }

    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }

  return GRAPH_SUCCESS;
}
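
// Writes the subgraph output shapes back to the parent While node. A While
// output must keep the shape of the corresponding input; if they differ, the
// output is marked unknown-rank and the body subgraph is reverse-refreshed.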
graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors,
                                     std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) {
  GELOGD("Enter update parent node shape for class while op process");
  if (ref_data_tensors.size() != ref_out_tensors.size()) {
    GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] are not the same!",
           node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size());
    return GRAPH_FAILED;
  }

  for (size_t i = 0; i < ref_data_tensors.size(); i++) {
    if (ref_out_tensors[i].size() != 1) {
      GELOGE(GRAPH_FAILED, "while op, every output should map to exactly one output tensor across the subgraphs!");
      return GRAPH_FAILED;
    }
  }

  bool is_need_reverse_brush = false;
  // Check input against output.
  for (size_t i = 0; i < ref_out_tensors.size(); i++) {
    if (ref_out_tensors[i].empty()) {
      continue;
    }

    auto ref_out_tensor = ref_out_tensors[i].at(0);
    auto tmp_shape = ref_out_tensor.MutableShape();
    // The data tensors and the output tensor at ref index i should have the same shape.
    for (auto &tensor : ref_data_tensors[i]) {
      if (ref_out_tensor.GetDataType() != tensor.GetDataType()) {
        GELOGE(GRAPH_FAILED, "node[%s] does not support different dtypes or formats among outputs.",
               node->GetName().c_str());
        return GRAPH_FAILED;
      }

      auto shape = tensor.MutableShape();
      if (shape.GetDims() != tmp_shape.GetDims()) {
        ref_out_tensor.SetUnknownDimNumShape();
        is_need_reverse_brush = true;
        break;
      }
    }

    (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor);
  }

  // Reverse-refresh the while-body shapes.
  if (is_need_reverse_brush) {
    return ReverseBrushWhileBodySubGraph(node);
  }

  return GRAPH_SUCCESS;
}
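
// Pushes the parent node's input tensor descs down to the Data nodes of its
// subgraphs, matching each Data node via its parent-node-index attribute.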
graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());
  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }

    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }

    for (const auto &node_sub : sub_graph->GetDirectNode()) {
      if (node_sub->GetType() != DATA) {
        continue;
      }

      int ref_i;
      auto data_opdesc = node_sub->GetOpDesc();
      if (data_opdesc == nullptr) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(),
                node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) {
        continue;
      }

      auto input_desc = op_desc->MutableInputDesc(ref_i);
      if (input_desc == nullptr) {
        GE_LOGE(
            "The ref index(%d) on the data %s on the sub graph %s "
            "parent node %s is incompatible, inputs num %u",
            ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(),
            node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(),
             node->GetName().c_str());

      auto ret = data_opdesc->UpdateInputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
      ret = data_opdesc->UpdateOutputDesc(0, *input_desc);
      if (ret != GRAPH_SUCCESS) {
        GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s",
                node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str());
        return ret;
      }
    }
  }
  return GRAPH_SUCCESS;
}
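
// Walks a subgraph backwards to locate its NetOutput node and collects, per
// parent input index, the output descs of the subgraph's Data nodes.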
graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput,
                                         const ConstNodePtr &node,
                                         std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) {
  auto sub_nodes = sub_graph->GetDirectNode();
  for (size_t i = sub_nodes.size(); i > 0; --i) {
    auto sub_node = sub_nodes.at(i - 1);
    if (sub_node->GetType() == NETOUTPUT) {
      netoutput = sub_node;
    }
    if (sub_node->GetType() == DATA) {
      if (sub_node->GetOpDesc() == nullptr) {
        return GRAPH_FAILED;
      }

      int ref_i;
      if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str());
        return GRAPH_FAILED;
      }
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) {
        GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %u)!", sub_node->GetName().c_str(),
               ref_i, node->GetAllInDataAnchorsSize());
        return GRAPH_FAILED;
      }
      ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0));
    }
  }
  return GRAPH_SUCCESS;
}
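
// Collects the Data and NetOutput tensors of every subgraph of the node, then
// refreshes the parent node's output descs: While nodes get the input/output
// consistency check, other branch-style nodes get the shape merge.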
graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) {
  auto op_desc = node->GetOpDesc();
  auto sub_graph_names = op_desc->GetSubgraphInstanceNames();
  if (sub_graph_names.empty()) {
    return GRAPH_SUCCESS;
  }

  std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize());
  std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize());
  auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph());

  for (const auto &name : sub_graph_names) {
    if (name.empty()) {
      GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str());
      continue;
    }

    auto sub_graph = root_graph->GetSubgraph(name);
    if (sub_graph == nullptr) {
      GE_LOGE("Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }

    NodePtr netoutput = nullptr;
    auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
    if (netoutput == nullptr) {
      GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str());
      return GRAPH_FAILED;
    }

    auto netoutput_opdesc = netoutput->GetOpDesc();
    if (netoutput_opdesc == nullptr) {
      GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(),
              node->GetName().c_str());
      return GRAPH_FAILED;
    }

    for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) {
      auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx());
      if (edge_desc == nullptr) {
        GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(),
                node->GetName().c_str(), edge_anchor->GetIdx());
        return GRAPH_FAILED;
      }
      GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", edge_anchor->GetIdx(),
             edge_desc->GetShape().GetDimNum());

      int ref_i;
      if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) {
        // If there is no ref index on the TensorDesc, the output data is ignored by the outer graph.
        continue;
      }
      GELOGI("Parent node index of edge desc is %d", ref_i);
      if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) {
        return GRAPH_FAILED;
      }
      ref_out_tensors[ref_i].emplace_back(*edge_desc);
    }
  }

  if (node->GetType() == WHILE) {
    return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors);
  }
  return UpdateParentNodeForBranch(node, ref_out_tensors);
}
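
// Serializes a dims vector as "[d0 d1 ... ]" for logging.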
string Serial(const vector<int64_t> &dims) {
  string serial_string;
  serial_string += "[";
  for (int64_t dim : dims) {
    serial_string += std::to_string(dim) + " ";
  }
  serial_string += "]";
  return serial_string;
}
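
// Refreshes the node's input descs from the output descs of its peer nodes,
// warning (but not failing) when shape or dtype disagree across an edge.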
graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
  GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);

  for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) {
    auto in_idx = in_anchor->GetIdx();
    auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
    if (peer_out_data_anchor == nullptr) {
      continue;
    }
    auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
    if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
      continue;
    }

    int peer_out_idx = peer_out_data_anchor->GetIdx();
    auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx));
    if (peer_out_desc == nullptr) {
      continue;
    }

    // Check shape and dtype consistency across the edge; mismatches are logged but do not stop the process.
    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx));
    if (in_desc == nullptr) {
      continue;
    }
    auto in_shape = in_desc->GetShape().GetDims();
    auto in_dtype = in_desc->GetDataType();
    auto peer_out_shape = peer_out_desc->GetShape().GetDims();
    auto peer_out_dtype = peer_out_desc->GetDataType();
    if (peer_out_dtype != in_dtype) {
      GELOGW(
          "current node [%s] input [%d] dtype is [%s], but peer node [%s] output [%d] "
          "dtype is [%s]. The two dtypes should be the same! Please check the graph and fix it",
          node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(),
          peer_out_data_node->GetName().c_str(), peer_out_idx,
          TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str());
    } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) {
      string in_shape_str = Serial(in_shape);
      string peer_out_shape_str = Serial(peer_out_shape);
      GELOGW(
          "current node [%s] input [%d] shape is [%s], but peer node [%s] output [%d] "
          "shape is [%s]. The two shapes should be the same! Please check the graph and fix it",
          node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(),
          peer_out_idx, peer_out_shape_str.c_str());
    }

    // Refresh the current node's input desc from the peer output desc.
    in_desc->SetOriginShape(peer_out_desc->GetOriginShape());
    in_desc->SetShape(peer_out_desc->GetShape());
    in_desc->SetDataType(peer_out_desc->GetDataType());
    in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType());
    std::vector<std::pair<int64_t, int64_t>> shape_range;
    (void)peer_out_desc->GetShapeRange(shape_range);
    in_desc->SetShapeRange(shape_range);
    ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->GetShape().GetDims().size()));
  }
  return GRAPH_SUCCESS;
}
}  // namespace

void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) {
  if (!IsLogEnable(GE, DLOG_DEBUG)) {
    return;
  }
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return;
  }
  ge::OpDescPtr op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return );

  std::string str;
  if (op_desc->GetInputsSize() != 0) {
    std::string input_desc_str = "input shape: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      input_desc_str += "[";
      for (int64_t dim : input_desc->GetShape().GetDims()) {
        input_desc_str += std::to_string(dim) + " ";
      }
      input_desc_str += "]";
      input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" +
                        TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
    }
    str += input_desc_str;

    input_desc_str = "input origin shape: ";
    for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
      input_desc_str += "[";
      for (int64_t dim : input_desc->GetOriginShape().GetDims()) {
        input_desc_str += std::to_string(dim) + " ";
      }
      input_desc_str += "]";
      input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" +
                        TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " ";
    }
    str += input_desc_str;
  }

  if (op_desc->GetAllOutputsDescSize() != 0) {
    std::string output_desc_str = "output shape: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      if (output_desc == nullptr) {
        continue;
      }
      output_desc_str += "[";
      for (int64_t dim : output_desc->GetShape().GetDims()) {
        output_desc_str += std::to_string(dim) + " ";
      }
      output_desc_str += "]";
      output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" +
                         TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
    }
    str += output_desc_str;

    output_desc_str = "output origin shape: ";
    for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
      if (output_desc == nullptr) {
        continue;
      }
      output_desc_str += "[";
      for (int64_t dim : output_desc->GetOriginShape().GetDims()) {
        output_desc_str += std::to_string(dim) + " ";
      }
      output_desc_str += "]";
      output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" +
                         TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " ";
    }
    str += output_desc_str;
  }
  GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str());
}
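
// Runs the operator's infer-shape function for the node. The two-argument
// overload defaults before_subgraph to true, i.e. the subgraph Data nodes are
// refreshed before inference; with false, the parent's outputs are refreshed
// from the subgraphs afterwards.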
graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) {
  return InferShapeAndType(node, op, true);
}

graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) {
  auto op_desc = node->GetOpDesc();
  const auto &op_type = op_desc->GetType();

  graphStatus ret;
  if (before_subgraph) {
    ret = UpdateSubGraphDataNodes(node);
    if (ret != GRAPH_SUCCESS) {
      return ret;
    }
  }

  // Get the infer function and execute it.
  ret = op_desc->CallInferFunc(op);
  if (ret == GRAPH_PARAM_INVALID) {
    // The op IR has no infer function; try to get one from the operator factory.
    auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType());
    if (node_op.IsEmpty()) {
      GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str());
      return ret;
    }

    GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str());
    auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
    node_op.BreakConnect();
    if (temp_op_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "temp op desc is null");
      return GRAPH_FAILED;
    }
    if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
      GELOGW("InferShapeAndType UpdateInputName failed");
      // If the outputs already carry shapes, treat the inference as done.
      for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) {
        if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) {
          break;
        }
        return GRAPH_SUCCESS;
      }
    }
    if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
      GELOGW("InferShapeAndType UpdateOutputName failed");
    }
    op_desc->AddInferFunc(temp_op_desc->GetInferFunc());
    ret = op_desc->CallInferFunc(op);
    GELOGI("op CallInferFunc second. ret: %u", ret);
  }
  if (ret != GRAPH_SUCCESS) {
    return ret;
  }

  if (!before_subgraph) {
    return UpdateParentNodeOutTensor(node);
  }
  return GRAPH_SUCCESS;
}
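
// Builds an InferenceContext for a node by merging the marks and the output
// handle shapes/types recorded for its input nodes in context_map.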
InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map,
                                           const NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "node is null");
    return nullptr;
  }
  InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create());
  if (inference_context == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext");
    return nullptr;
  }

  auto all_in_data_anchors = node->GetAllInDataAnchors();
  std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size());
  std::vector<std::string> marks;

  bool has_input_shapes_and_types = false;
  for (const auto &in_anchor : all_in_data_anchors) {
    const auto &out_anchor = in_anchor->GetPeerOutAnchor();
    if (out_anchor == nullptr) {
      continue;
    }
    auto input_node = out_anchor->GetOwnerNode();
    if (input_node == nullptr) {
      continue;
    }

    auto iter = context_map.find(input_node);
    if (iter != context_map.end()) {
      const auto &src_context = iter->second;
      GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr);
      GELOGD("node:%s get %zu marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(),
             input_node->GetName().c_str());
      for (auto mark : src_context->GetMarks()) {
        marks.push_back(mark);
      }

      auto output_idx = out_anchor->GetIdx();
      auto input_idx = in_anchor->GetIdx();
      auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes();
      if (output_idx < static_cast<int>(output_shape_and_type.size())) {
        GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx,
               node->GetName().c_str(), input_idx);
        input_shapes_and_types[input_idx] = output_shape_and_type[output_idx];
        has_input_shapes_and_types = true;
      } else {
        GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx,
               output_shape_and_type.size());
      }
    }
  }

  if (has_input_shapes_and_types) {
    inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types));
  }
  inference_context->SetMarks(marks);

  return inference_context;
}

namespace {
thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); }

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) {
  return InferShapeAndType(node, true);
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node,
                                                                                           bool before_subgraph) {
  GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
  bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag();
  auto opdesc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);

  // Some ops, such as AIPP, can not run infer shape twice.
  bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified");
  if (need_update_input) {
    auto status = UpdateOpInputDesc(node);
    if (status != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "update op input_desc failed!");
      return status;
    }
  }

  if (node->Verify() != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }
  PrintInOutTensorShape(node, "before_infershape");
  Operator op = OpDescUtils::CreateOperatorFromNode(node);

  if (!is_unknown_graph) {
    auto inference_context = CreateInferenceContext(context_map, node);
    if (inference_context == nullptr) {
      GELOGE(GRAPH_FAILED, "inference context is null");
      return GRAPH_FAILED;
    }
    GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
    op.SetInferenceContext(inference_context);
  }

  graphStatus status = InferShapeAndType(node, op, before_subgraph);
  if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) {
    if (is_unknown_graph) {
      PrintInOutTensorShape(node, "after_infershape when running");
      return GRAPH_SUCCESS;
    }

    auto op_desc = node->GetOpDesc();
    for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
      auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
      ge::TensorUtils::SetRealDimCnt(*output_tensor,
                                     static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
      output_tensor->SetOriginShape(output_tensor->GetShape());
      output_tensor->SetOriginDataType(output_tensor->GetDataType());
      GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
             node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
             TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
             TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
    }
  } else {
    GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str());
    return GRAPH_FAILED;
  }

  if (!is_unknown_graph) {
    auto ctx_after_infer = op.GetInferenceContext();
    if (ctx_after_infer != nullptr) {
      GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size());
      if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) {
        GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(),
               ctx_after_infer->GetMarks().size());
        (void)context_map.emplace(node, ctx_after_infer);
      }
    }
  }
  PrintInOutTensorShape(node, "after_infershape");
  return GRAPH_SUCCESS;
}
}  // namespace ge

The Graph Engine module (GE) is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between the two. GE takes the graph delivered by ME as its input, performs a series of deep graph optimizations on it, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE is mainly composed of two parts, GE API and GE Core. (A detailed architecture diagram follows at this point in the original page.)
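
For orientation, below is a minimal sketch of how a caller might drive the ShapeRefiner above over a whole graph. It is an illustration only, not code from the repository: InferShapesForGraph is a hypothetical helper and the sketch assumes the graph's direct nodes are already in topological order, while ShapeRefiner::InferShapeAndType, ShapeRefiner::ClearContextMap, and ComputeGraph::GetDirectNode are the APIs used in the file itself.

#include <memory>

#include "graph/compute_graph.h"
#include "graph/shape_refiner.h"

namespace ge {
// Hypothetical driver: infers shapes for every node of a graph in direct-node
// order, clearing the thread-local inference-context map between graphs.
graphStatus InferShapesForGraph(const std::shared_ptr<ComputeGraph> &graph) {
  if (graph == nullptr) {
    return GRAPH_FAILED;
  }
  // Contexts recorded for a previous graph must not leak into this one.
  ShapeRefiner::ClearContextMap();
  for (const auto &node : graph->GetDirectNode()) {
    // Refreshes the node's inputs from its peers, calls the op's infer
    // function, and propagates shapes into/out of any subgraphs it owns.
    graphStatus ret = ShapeRefiner::InferShapeAndType(node);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(ret, "Infer shape failed for node %s", node->GetName().c_str());
      return ret;
    }
  }
  return GRAPH_SUCCESS;
}
}  // namespace ge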