You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

format_refiner.cc 19 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "format_refiner.h"
  17. #include <deque>
  18. #include <iostream>
  19. #include <set>
  20. #include <unordered_map>
  21. #include <unordered_set>
  22. #include "graph/ref_relation.h"
  23. #include "./compute_graph.h"
  24. #include "./ge_error_codes.h"
  25. #include "./graph/ge_tensor.h"
  26. #include "./operator.h"
  27. #include "./operator_factory.h"
  28. #include "debug/ge_log.h"
  29. #include "debug/ge_op_types.h"
  30. #include "debug/ge_util.h"
  31. #include "framework/common/debug/ge_log.h"
  32. #include "utils/node_utils.h"
  33. #include "utils/op_desc_utils.h"
  34. #include "utils/tensor_utils.h"
  35. #include "utils/type_utils.h"
  36. using namespace ge;
  37. using namespace std;
  38. namespace ge {
  39. namespace {
  40. static const std::unordered_set<string> kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE};
  41. static bool net_format_is_nd = true;
  42. static Format g_user_set_format = FORMAT_ND;
  43. static bool is_first_infer = true;
  44. static RefRelations reflection_builder;
  45. } // namespace
  46. graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection,
  47. std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) {
  48. for (const auto &cell : reflection) {
  49. auto node = cell.node;
  50. auto in_out_idx = cell.in_out_idx;
  51. GE_CHECK_NOTNULL(node);
  52. GE_CHECK_NOTNULL(node->GetOpDesc());
  53. if (cell.in_out == ge::NODE_IN) {
  54. auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx));
  55. desc.SetOriginFormat(to_be_set_format);
  56. desc.SetFormat(to_be_set_format);
  57. (void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc);
  58. } else {
  59. auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx));
  60. desc.SetOriginFormat(to_be_set_format);
  61. desc.SetFormat(to_be_set_format);
  62. (void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc);
  63. }
  64. nodes.push_back(cell.node);
  65. }
  66. return GRAPH_SUCCESS;
  67. }
  68. graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
  69. GE_CHECK_NOTNULL(op_desc);
  70. if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) {
  71. ConstGeTensorPtr tensor_value;
  72. if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
  73. GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
  74. return GRAPH_FAILED;
  75. }
  76. GE_CHECK_NOTNULL(tensor_value);
  77. (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc());
  78. }
  79. return GRAPH_SUCCESS;
  80. }
  81. graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points,
  82. std::vector<ge::NodePtr> &data_nodes,
  83. std::unordered_map<ge::NodePtr, bool> &node_status) {
  84. if (graph == nullptr) {
  85. GELOGE(GRAPH_FAILED, "input graph is null");
  86. return GRAPH_FAILED;
  87. }
  88. anchor_points.clear();
  89. // Get all anchor point nodes and switch nodes
  90. for (const auto &node_ptr : graph->GetAllNodes()) {
  91. if (node_ptr == nullptr) {
  92. return GRAPH_FAILED;
  93. }
  94. auto op_desc = node_ptr->GetOpDesc();
  95. if (op_desc == nullptr) {
  96. return GRAPH_FAILED;
  97. }
  98. graphStatus status = RefreshConstantOutProcess(op_desc);
  99. if (status != GRAPH_SUCCESS) {
  100. GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
  101. return GRAPH_FAILED;
  102. }
  103. // consider special node save process
  104. // get all input desc format
  105. bool node_is_all_nd = false;
  106. auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize());
  107. for (uint32_t i = 0; i < input_size; i++) {
  108. // Operator pre-set format but not origin format
  109. auto input_format = op_desc->MutableInputDesc(i)->GetFormat();
  110. // Pre-save data node (only main graph data) and default infer fail
  111. if (node_ptr->GetType() == DATA) {
  112. data_nodes.push_back(node_ptr);
  113. }
  114. if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) {
  115. node_is_all_nd = true;
  116. }
  117. }
  118. // Get all output desc format
  119. auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize());
  120. for (uint32_t i = 0; i < output_size; i++) {
  121. auto output_format = op_desc->MutableOutputDesc(i)->GetFormat();
  122. if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) {
  123. node_is_all_nd = true;
  124. }
  125. }
  126. // check anchor point valid
  127. if (!node_is_all_nd) {
  128. continue;
  129. }
  130. GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
  131. anchor_points.push_back(node_ptr);
  132. }
  133. GELOGI("anchor_points number is %zu", anchor_points.size());
  134. return GRAPH_SUCCESS;
  135. }
  136. graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node,
  137. std::unordered_map<ge::NodePtr, bool> &node_status) {
  138. if (anchor_node == nullptr) {
  139. GELOGE(GRAPH_FAILED, "anchor node is null!");
  140. return GRAPH_FAILED;
  141. }
  142. std::deque<ge::NodePtr> nodes;
  143. nodes.push_back(anchor_node);
  144. while (!nodes.empty()) {
  145. ge::NodePtr node = nodes.front();
  146. nodes.pop_front();
  147. graphStatus status = BackInferProcess(nodes, node, node_status);
  148. if (status != GRAPH_SUCCESS && node != nullptr) {
  149. GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str());
  150. return status;
  151. }
  152. status = ForwardInferProcess(nodes, node, node_status);
  153. if (status != GRAPH_SUCCESS && node != nullptr) {
  154. GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str());
  155. return status;
  156. }
  157. }
  158. return GRAPH_SUCCESS;
  159. }
  160. graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
  161. std::unordered_map<ge::NodePtr, bool> &node_status) {
  162. GE_CHECK_NOTNULL(node);
  163. GE_CHECK_NOTNULL(node->GetOpDesc());
  164. GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str());
  165. for (const auto &in_anchor : node->GetAllInDataAnchors()) {
  166. GELOGD("Node is [%s] [B]", (node->GetName()).c_str());
  167. auto in_data_anchor_idx = in_anchor->GetIdx();
  168. auto to_be_set_format =
  169. node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat();
  170. if (to_be_set_format == FORMAT_ND) {
  171. GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str());
  172. continue;
  173. }
  174. auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor();
  175. if (peer_out_data_anchor == nullptr) {
  176. GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx);
  177. continue;
  178. }
  179. auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode();
  180. if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) {
  181. GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str());
  182. continue;
  183. }
  184. // Check format whether have been set
  185. int idx = peer_out_data_anchor->GetIdx();
  186. // do peer_out_node name and index as key to lookup reflections
  187. ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx);
  188. std::unordered_set<RefCell, RefCellHash> reflection;
  189. auto status = reflection_builder.LookUpRefRelations(key, reflection);
  190. if (status != GRAPH_SUCCESS) {
  191. GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge",
  192. (peer_out_data_node->GetName()).c_str(), idx);
  193. return GRAPH_FAILED;
  194. }
  195. auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx));
  196. if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
  197. auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
  198. if (dim_num == 0) {
  199. GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx);
  200. continue;
  201. }
  202. /// Check whether node to change dims ()
  203. /// Because some node will calculate with 5D, C dim maybe multi meaning
  204. auto peer_out_data_node_type = peer_out_data_node->GetType();
  205. auto iter1 = kChangeDimNodes.find(peer_out_data_node_type);
  206. // 4 means dims num
  207. if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) {
  208. GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format",
  209. (peer_out_data_node->GetName()).c_str());
  210. continue;
  211. }
  212. if (reflection.empty()) {
  213. ge_tensor_desc.SetOriginFormat(to_be_set_format);
  214. ge_tensor_desc.SetFormat(to_be_set_format);
  215. (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc);
  216. // Call operator infer format api (forward) to get out format
  217. GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str());
  218. status = peer_out_data_node->InferOriginFormat();
  219. if (status != GRAPH_SUCCESS) {
  220. GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str());
  221. return GRAPH_FAILED;
  222. }
  223. nodes.push_back(peer_out_data_node);
  224. } else {
  225. auto status = ReflectionProcess(reflection, nodes, to_be_set_format);
  226. if (status != GRAPH_SUCCESS) {
  227. GELOGE(GRAPH_FAILED, "reflection process failed!");
  228. return GRAPH_FAILED;
  229. }
  230. }
  231. }
  232. }
  233. return GRAPH_SUCCESS;
  234. }
  235. graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node,
  236. std::unordered_map<ge::NodePtr, bool> &node_status) {
  237. GE_CHECK_NOTNULL(node);
  238. GE_CHECK_NOTNULL(node->GetOpDesc());
  239. GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str());
  240. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  241. GELOGD("Node is [%s] [F]", (node->GetName()).c_str());
  242. GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
  243. auto out_data_anchor_idx = out_data_anchor->GetIdx();
  244. auto to_be_set_format =
  245. node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat();
  246. if (to_be_set_format == FORMAT_ND) {
  247. GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str());
  248. continue;
  249. }
  250. for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  251. GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue);
  252. auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
  253. GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue);
  254. GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue);
  255. // Check format whether have been set
  256. int idx = peer_in_data_anchor->GetIdx();
  257. // do peer_out_node name and index as key to lookup reflections
  258. ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx);
  259. std::unordered_set<RefCell, RefCellHash> reflection;
  260. auto status = reflection_builder.LookUpRefRelations(key, reflection);
  261. if (status != GRAPH_SUCCESS) {
  262. GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge",
  263. (peer_in_data_node->GetName()).c_str(), idx);
  264. return GRAPH_FAILED;
  265. }
  266. auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx));
  267. if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) {
  268. auto dim_num = ge_tensor_desc.GetShape().GetDimNum();
  269. if (dim_num == 0) {
  270. GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx);
  271. continue;
  272. }
  273. /// Check whether node to change dims ()
  274. /// Because some node will calculate with 5D, C dim maybe multi meaning
  275. auto peer_in_data_node_type = peer_in_data_node->GetType();
  276. auto iter1 = kChangeDimNodes.find(peer_in_data_node_type);
  277. // 4 means dims num
  278. if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) {
  279. GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str());
  280. continue;
  281. }
  282. if (reflection.empty()) {
  283. ge_tensor_desc.SetOriginFormat(to_be_set_format);
  284. ge_tensor_desc.SetFormat(to_be_set_format);
  285. (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc);
  286. /// Because netoutput node added before infer format ,so netoutput is end condition
  287. /// must set netoutput format , because saved result depend on format
  288. if (peer_in_data_node_type == NETOUTPUT) {
  289. continue;
  290. }
  291. // Call operator infer format api (forward) to get out format
  292. GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str());
  293. status = peer_in_data_node->InferOriginFormat();
  294. if (status != GRAPH_SUCCESS) {
  295. GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str());
  296. return GRAPH_FAILED;
  297. }
  298. nodes.push_back(peer_in_data_node);
  299. } else {
  300. auto status = ReflectionProcess(reflection, nodes, to_be_set_format);
  301. if (status != GRAPH_SUCCESS) {
  302. GELOGE(GRAPH_FAILED, "reflection process failed!");
  303. return GRAPH_FAILED;
  304. }
  305. }
  306. }
  307. }
  308. }
  309. return GRAPH_SUCCESS;
  310. }
  311. void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor_points) {
  312. for (const auto &node : anchor_points) {
  313. if (node == nullptr || node->GetOpDesc() == nullptr) {
  314. continue;
  315. }
  316. for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) {
  317. if (input_desc != nullptr) {
  318. input_desc->SetOriginFormat(input_desc->GetFormat());
  319. }
  320. }
  321. for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) {
  322. if (output_desc != nullptr) {
  323. output_desc->SetOriginFormat(output_desc->GetFormat());
  324. }
  325. }
  326. }
  327. }
  328. void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; }
  329. graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
  330. std::unordered_map<ge::NodePtr, bool> &node_status) {
  331. bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
  332. bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
  333. if (!need_process) {
  334. GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
  335. TypeUtils::FormatToSerialString(data_format).c_str());
  336. return GRAPH_SUCCESS;
  337. }
  338. GELOGD("Enter DataNodeFormatProcess");
  339. std::vector<ge::NodePtr> uninfered_data_nodes;
  340. // Check and renew data nodes format
  341. for (const auto &data_node : data_nodes) {
  342. GE_CHECK_NOTNULL(data_node);
  343. auto op_desc = data_node->GetOpDesc();
  344. GE_CHECK_NOTNULL(op_desc);
  345. GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0));
  346. auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat();
  347. if (curr_format != FORMAT_ND) {
  348. // Data format has been infered , continue
  349. continue;
  350. }
  351. // Set format for un-infered data node
  352. auto input_descs = op_desc->GetAllInputsDescPtr();
  353. auto output_descs = op_desc->GetAllOutputsDescPtr();
  354. for (const auto &input_desc : input_descs) {
  355. if (input_desc != nullptr) {
  356. input_desc->SetOriginFormat(data_format);
  357. input_desc->SetFormat(data_format);
  358. }
  359. }
  360. for (const auto &output_desc : output_descs) {
  361. if (output_desc != nullptr) {
  362. output_desc->SetOriginFormat(data_format);
  363. output_desc->SetFormat(data_format);
  364. }
  365. }
  366. uninfered_data_nodes.push_back(data_node);
  367. }
  368. // Reinfer format from uninfered data nodes
  369. for (const auto &node : uninfered_data_nodes) {
  370. if (node == nullptr) {
  371. continue;
  372. }
  373. GELOGD("data node [%s] start infer format process", node->GetName().c_str());
  374. auto status = AnchorProcess(node, node_status);
  375. if (status != GRAPH_SUCCESS) {
  376. GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str());
  377. return GRAPH_FAILED;
  378. }
  379. }
  380. GELOGD("DataNodeFormatProcess success");
  381. return GRAPH_SUCCESS;
  382. }
  383. graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) {
  384. GELOGI("Enter InferOrigineFormat process!");
  385. // True: infered false:no-infered
  386. std::unordered_map<ge::NodePtr, bool> node_status;
  387. std::vector<ge::NodePtr> anchor_points;
  388. std::vector<ge::NodePtr> data_nodes;
  389. // global net format
  390. net_format_is_nd = true;
  391. g_user_set_format = FORMAT_ND;
  392. if (graph == nullptr) {
  393. GELOGE(GRAPH_FAILED, "input graph is null");
  394. return GRAPH_FAILED;
  395. }
  396. // build reflection relations of boundary
  397. (void)reflection_builder.Clear();
  398. auto status = reflection_builder.BuildRefRelations(*graph);
  399. if (status != GRAPH_SUCCESS) {
  400. GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!");
  401. return GRAPH_FAILED;
  402. }
  403. // User set global net format
  404. status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status);
  405. if (status != GRAPH_SUCCESS) {
  406. GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!");
  407. return GRAPH_FAILED;
  408. }
  409. // Refresh origin format of anchor point
  410. RefreshOriginFormatOfAnchor(anchor_points);
  411. // Infer format process
  412. for (const auto &anchor_node : anchor_points) {
  413. if (anchor_node == nullptr) {
  414. continue;
  415. }
  416. status = AnchorProcess(anchor_node, node_status);
  417. if (status != GRAPH_SUCCESS) {
  418. GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str());
  419. return GRAPH_FAILED;
  420. }
  421. }
  422. /// According to discuss with sys-enginer, data node default format is ND.Its format
  423. /// should be set by infered.But if some data-node can not be got by infer, set context's
  424. /// format for these data nodes.
  425. /// Notice: ignore 5D formats
  426. auto data_format = graph->GetDataFormat();
  427. status = DataNodeFormatProcess(data_nodes, data_format, node_status);
  428. // Set infer flag to false
  429. SetInferOrigineFormatFlag(false);
  430. return status;
  431. }
  432. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。