
node_item.cc 18 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hybrid/model/node_item.h"
#include "graph/compute_graph.h"
#include "graph/debug/ge_attr_define.h"
#include "hybrid/executor/worker/shape_inference_engine.h"
#include "hybrid/node_executor/node_executor.h"

namespace ge {
namespace hybrid {
namespace {
const uint8_t kMaxTransCount = 3;
const uint32_t kTransOpIoSize = 1;
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char *const kNodeTypeRetVal = "_RetVal";
const std::set<std::string> kControlOpTypes{
  IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
};
const std::set<std::string> kControlFlowOpTypes{
  STREAMACTIVE, STREAMSWITCH, STREAMSWITCHN, ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT,
  LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX
};
const std::set<std::string> kMergeOpTypes{
  MERGE, REFMERGE, STREAMMERGE
};

bool IsEnterFeedNode(NodePtr node) {
  // For: Enter -> node
  // For: Enter -> Cast -> node
  // For: Enter -> TransData -> Cast -> node
  for (uint8_t i = 0; i < kMaxTransCount; ++i) {
    if (kEnterOpTypes.count(NodeUtils::GetNodeType(node)) > 0) {
      GELOGD("Node[%s] is Enter feed node.", node->GetName().c_str());
      return true;
    }
    const auto all_nodes = node->GetInDataNodes();
    if (all_nodes.size() != kTransOpIoSize || node->GetAllInDataAnchorsSize() != kTransOpIoSize) {
      return false;
    }
    node = all_nodes.at(0);
  }
  return false;
}
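
// Map the parent node's input index (ATTR_NAME_PARENT_NODE_INDEX on a Data node of the fused subgraph)
// to the input tensor descs of every consumer of that Data node inside the subgraph.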
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
  uint32_t parent_index = 0;
  if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "[Invoke][GetInt][%s] Failed to get attr [%s]",
           op_desc.GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    REPORT_CALL_ERROR("E19999", "[%s] Failed to get attr [%s]",
                      op_desc.GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    return FAILED;
  }

  for (auto &node_and_anchor : node.GetOutDataNodesAndAnchors()) {
    auto dst_op_desc = node_and_anchor.first->GetOpDesc();
    GE_CHECK_NOTNULL(dst_op_desc);
    auto in_idx = node_and_anchor.second->GetIdx();
    auto tensor_desc = dst_op_desc->MutableInputDesc(in_idx);
    fused_subgraph.input_mapping[static_cast<int>(parent_index)].emplace_back(tensor_desc);
    GELOGD("Input[%u] mapped to [%s:%u]", parent_index, dst_op_desc->GetName().c_str(), in_idx);
  }

  return SUCCESS;
}

Status ParseOutputMapping(const OpDescPtr &op_desc, FusedSubgraph &fused_subgraph) {
  uint32_t parent_index = 0;
  if (!AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "[Invoke][GetInt][%s] Failed to get attr [%s]",
           op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    REPORT_CALL_ERROR("E19999", "[%s] Failed to get attr [%s].",
                      op_desc->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str());
    return FAILED;
  }

  fused_subgraph.output_mapping.emplace(static_cast<int>(parent_index), op_desc);
  return SUCCESS;
}
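
// Load the graph stored in the "_original_fusion_graph" attribute (if present), mark it as unknown-shape,
// and record how its Data / _RetVal nodes map onto the parent node's inputs and outputs.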
Status ParseFusedSubgraph(NodeItem &node_item) {
  if (!node_item.op_desc->HasAttr(kAttrNameOriginalFusionGraph)) {
    return SUCCESS;
  }

  GELOGI("[%s] Start to parse fused subgraph.", node_item.node_name.c_str());
  auto fused_subgraph = std::unique_ptr<FusedSubgraph>(new(std::nothrow)FusedSubgraph());
  GE_CHECK_NOTNULL(fused_subgraph);

  ComputeGraphPtr fused_graph;
  (void) AttrUtils::GetGraph(*node_item.op_desc, kAttrNameOriginalFusionGraph, fused_graph);
  GE_CHECK_NOTNULL(fused_graph);

  fused_graph->SetGraphUnknownFlag(true);
  fused_subgraph->graph = fused_graph;
  GE_CHK_GRAPH_STATUS_RET(fused_graph->TopologicalSorting());

  for (auto &node : fused_graph->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    const std::string node_type = NodeUtils::GetNodeType(node);
    if (node_type == DATA) {
      GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
    } else if (node_type == kNodeTypeRetVal) {
      GE_CHK_GRAPH_STATUS_RET(ParseOutputMapping(op_desc, *fused_subgraph));
    } else {
      fused_subgraph->nodes.emplace_back(node);
    }
  }

  node_item.fused_subgraph = std::move(fused_subgraph);
  GELOGI("[%s] Done parsing fused subgraph successfully.", node_item.NodeName().c_str());
  return SUCCESS;
}
}  // namespace

bool IsControlFlowV2Op(const std::string &op_type) {
  return kControlOpTypes.count(op_type) > 0;
}

NodeItem::NodeItem(NodePtr node) : node(std::move(node)) {
  this->op_desc = this->node->GetOpDesc().get();
  this->node_name = this->node->GetName();
  this->node_type = this->node->GetType();
}

Status NodeItem::Create(const NodePtr &node, std::unique_ptr<NodeItem> &node_item) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(node->GetOpDesc());
  std::unique_ptr<NodeItem> instance(new(std::nothrow)NodeItem(node));
  GE_CHECK_NOTNULL(instance);
  GE_CHK_STATUS_RET(instance->Init(), "[Invoke][Init]Failed to init NodeItem [%s].", node->GetName().c_str());
  node_item = std::move(instance);
  return SUCCESS;
}

void NodeItem::ResolveOptionalInputs() {
  if (op_desc->GetAllInputsSize() != op_desc->GetInputsSize()) {
    has_optional_inputs = true;
    for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) {
      const auto &input_desc = op_desc->MutableInputDesc(i);
      if (input_desc == nullptr) {
        GELOGD("[%s] Input[%zu] is optional and invalid", NodeName().c_str(), i);
      } else {
        input_desc_indices_.emplace_back(static_cast<uint32_t>(i));
      }
    }
  }
}

Status NodeItem::InitInputsAndOutputs() {
  GE_CHECK_LE(op_desc->GetInputsSize(), INT32_MAX);
  GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX);
  num_inputs = static_cast<int>(op_desc->GetInputsSize());
  num_outputs = static_cast<int>(op_desc->GetOutputsSize());
  if (AttrUtils::GetInt(op_desc, ::ge::ATTR_STAGE_LEVEL, group)) {
    GELOGD("[%s] Got stage level from op_desc = %d", op_desc->GetName().c_str(), group);
  } else {
    if (node->GetOwnerComputeGraph() != nullptr) {
      if (AttrUtils::GetInt(node->GetOwnerComputeGraph(), ::ge::ATTR_STAGE_LEVEL, group)) {
        GELOGD("[%s] Got stage level from parent graph = %d", op_desc->GetName().c_str(), group);
      } else {
        auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
        if ((parent_node != nullptr) && (AttrUtils::GetInt(parent_node->GetOpDesc(), ::ge::ATTR_STAGE_LEVEL, group))) {
          GELOGD("[%s] Got stage level from parent node = %d", op_desc->GetName().c_str(), group);
        } else {
          GELOGD("[%s] Node does not set stage level", op_desc->GetName().c_str());
        }
      }
    }
  }
  ResolveOptionalInputs();
  return SUCCESS;
}

Status NodeItem::ResolveDynamicState() {
  (void) AttrUtils::GetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, is_dynamic);
  GELOGD("Node name is %s, dynamic state is %d.", this->node_name.c_str(), is_dynamic);
  if (!is_dynamic) {
    GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_dynamic),
                      "[Invoke][GetNodeUnknownShapeStatus][%s] Failed to get shape status.",
                      node->GetName().c_str());
  }
  return SUCCESS;
}
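
// Record which input shapes are already static; when every output shape is static as well,
// output tensor sizes can be pre-computed via ShapeInferenceEngine::CalcOutputTensorSizes.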
Status NodeItem::ResolveStaticInputsAndOutputs() {
  for (int i = 0; i < num_inputs; ++i) {
    // Data has unconnected input but set by framework
    if (node_type != DATA) {
      int origin_index = i;
      if (has_optional_inputs) {
        origin_index = input_desc_indices_[i];
      }
      auto in_data_anchor = node->GetInDataAnchor(origin_index);
      GE_CHECK_NOTNULL(in_data_anchor);

      // If no node is connected to the current input anchor,
      // increase num_static_input_shapes to avoid a dead wait in ShapeInferenceState::AwaitShapesReady
      if (in_data_anchor->GetPeerOutAnchor() == nullptr ||
          in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr) {
        num_static_input_shapes++;
        is_input_shape_static_.push_back(true);
        GELOGW("[%s] Peer node of input[%d] is empty", NodeName().c_str(), i);
        continue;
      }
    }
    const auto &input_desc = MutableInputDesc(i);
    GE_CHECK_NOTNULL(input_desc);
    if (input_desc->MutableShape().IsUnknownShape()) {
      is_input_shape_static_.push_back(false);
    } else {
      num_static_input_shapes++;
      is_input_shape_static_.push_back(true);
      GELOGD("[%s] The shape of input[%d] is static. shape = [%s]",
             NodeName().c_str(), i, input_desc->MutableShape().ToString().c_str());
    }
  }

  for (int i = 0; i < num_outputs; ++i) {
    const auto &output_desc = op_desc->MutableOutputDesc(i);
    GE_CHECK_NOTNULL(output_desc);
    if (output_desc->MutableShape().IsUnknownShape()) {
      is_output_shape_static = false;
      break;
    }
  }

  if (is_output_shape_static) {
    GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(*this));
  }
  return SUCCESS;
}

void NodeItem::ResolveUnknownShapeType() {
  if (IsControlFlowV2Op() || (is_dynamic && node_type == PARTITIONEDCALL)) {
    shape_inference_type = DEPEND_COMPUTE;
  } else {
    int32_t unknown_shape_type_val = 0;
    (void) AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val);
    shape_inference_type = static_cast<UnknowShapeOpType>(unknown_shape_type_val);
  }
}
Status NodeItem::Init() {
  is_ctrl_flow_v2_op_ = ge::hybrid::IsControlFlowV2Op(node_type);
  is_ctrl_flow_op_ = kControlFlowOpTypes.count(node_type) > 0;
  is_merge_op_ = kMergeOpTypes.count(node_type) > 0;
  is_root_node_ = node->GetInAllNodes().empty();
  GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs());
  GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState());
  ResolveUnknownShapeType();
  if (is_dynamic) {
    GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs());
    GE_CHK_STATUS_RET(ParseFusedSubgraph(*this),
                      "[Invoke][ParseFusedSubgraph][%s] Failed to parse fused subgraph", node_name.c_str());
  }
  copy_mu_ = MakeShared<std::mutex>();
  GE_CHECK_NOTNULL(copy_mu_);
  return SUCCESS;
}

bool NodeItem::IsHcclOp() const {
  return NodeExecutorManager::GetInstance().ResolveExecutorType(*node) == NodeExecutorManager::ExecutorType::HCCL;
}

std::string NodeItem::DebugString() const {
  std::stringstream ss;
  ss << "Node: ";
  ss << "id = " << node_id;
  ss << ", name = [" << node->GetName();
  ss << "], type = " << node->GetType();
  ss << ", is_dynamic = " << (is_dynamic ? "True" : "False");
  ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False");
  ss << ", unknown_shape_op_type = " << shape_inference_type;
  ss << ", stage = " << group;
  ss << ", input_start = " << input_start;
  ss << ", num_inputs = " << num_inputs;
  ss << ", output_start = " << output_start;
  ss << ", num_outputs = " << num_outputs;
  ss << ", dependent_nodes = [";
  for (const auto &dep_node : dependents_for_shape_inference) {
    ss << dep_node->GetName() << ", ";
  }
  ss << "]";
  int index = 0;
  for (auto &items : outputs) {
    ss << ", output[" << index++ << "]: ";
    for (auto &item : items) {
      ss << "(" << item.second->NodeName() << ":" << item.first << "), ";
    }
  }

  return ss.str();
}

void NodeItem::SetToDynamic() {
  num_static_input_shapes = 0;
  is_dynamic = true;
  for (size_t i = 0; i < is_input_shape_static_.size(); ++i) {
    is_input_shape_static_[i] = false;
  }
  if (kernel_task != nullptr && !kernel_task->IsSupportDynamicShape()) {
    GELOGD("[%s] Dynamic shape is not supported, clear node task.", node_name.c_str());
    kernel_task = nullptr;
  }
}
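
// Resolve an input desc by canonical index; when optional inputs exist, translate through
// input_desc_indices_ so that unset optional inputs are skipped.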
GeTensorDescPtr NodeItem::DoGetInputDesc(int index) const {
  if (!has_optional_inputs) {
    return op_desc->MutableInputDesc(static_cast<uint32_t>(index));
  }

  if (index < 0 || index >= num_inputs) {
    GELOGE(PARAM_INVALID, "[Check][Param:index][%s] Invalid input index, num inputs = %d, index = %d",
           node_name.c_str(), num_inputs, index);
    REPORT_INNER_ERROR("E19999", "Invalid input index, node:%s num inputs = %d, index = %d",
                       node_name.c_str(), num_inputs, index);
    return nullptr;
  }

  return op_desc->MutableInputDesc(input_desc_indices_[index]);
}

GeTensorDescPtr NodeItem::MutableInputDesc(int index) const {
  std::lock_guard<std::mutex> lk(mu_);
  return DoGetInputDesc(index);
}

Status NodeItem::GetInputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> lk(mu_);
  auto input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  tensor_desc = *input_desc;
  return SUCCESS;
}

Status NodeItem::GetOutputDesc(int index, GeTensorDesc &tensor_desc) const {
  std::lock_guard<std::mutex> lk(mu_);
  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  GE_CHECK_NOTNULL(output_desc);
  tensor_desc = *output_desc;
  return SUCCESS;
}

GeTensorDescPtr NodeItem::MutableOutputDesc(int index) const {
  std::lock_guard<std::mutex> lk(mu_);
  return op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
}

Status NodeItem::UpdateInputDesc(int index, const GeTensorDesc &tensor_desc) {
  std::lock_guard<std::mutex> lk(mu_);
  auto input_desc = DoGetInputDesc(index);
  GE_CHECK_NOTNULL(input_desc);
  *input_desc = tensor_desc;
  return SUCCESS;
}
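
// Convert an index over all declared inputs (including unset optional ones) into the compacted
// index used by this NodeItem.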
Status NodeItem::GetCanonicalInputIndex(uint32_t index, int &canonical_index) const {
  if (!has_optional_inputs) {
    canonical_index = index;
    return SUCCESS;
  }

  auto iter = std::find(input_desc_indices_.begin(), input_desc_indices_.end(), index);
  if (iter == input_desc_indices_.end()) {
    GELOGE(INTERNAL_ERROR,
           "[Check][Param:index] Input index:%u not found in input_desc_indices_, which is invalid, node:%s",
           index, node_name.c_str());
    REPORT_INNER_ERROR("E19999", "Input index:%u not found in input_desc_indices_, which is invalid, node:%s",
                       index, node_name.c_str());
    return INTERNAL_ERROR;
  }

  canonical_index = static_cast<int>(iter - input_desc_indices_.begin());
  GELOGD("[%s] Canonicalize input index from [%u] to [%d]", node_name.c_str(), index, canonical_index);
  return SUCCESS;
}

bool NodeItem::IsInputShapeStatic(int index) const {
  if (!is_dynamic) {
    return true;
  }

  if (static_cast<size_t>(index) >= is_input_shape_static_.size()) {
    GELOGE(PARAM_INVALID, "[Check][Param:index] Input index(%d) out of range: [0, %zu)",
           index, is_input_shape_static_.size());
    REPORT_INNER_ERROR("E19999", "Input index(%d) out of range: [0, %zu).", index, is_input_shape_static_.size());
    return false;
  }

  return is_input_shape_static_[index];
}
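
// Record a data dependency: this node sends data to node_item at anchor_index. Producers that are
// root nodes, or Enter-fed nodes not feeding StreamMerge, are additionally tracked in root_data_ / enter_data_.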
void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
  data_send_.emplace(node_item);
  node_item->data_recv_[this] = anchor_index;
  if (is_root_node_) {
    auto &data_anchors = node_item->root_data_[this];
    data_anchors.emplace(anchor_index);
  }
  // If Enter feeds a node other than Merge, take this as a root node.
  if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE)) {
    auto &data_anchors = node_item->enter_data_[this];
    data_anchors.emplace(anchor_index);
  }
  GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
  if (switch_index < switch_groups_.size()) {
    auto &switch_group = switch_groups_[switch_index];
    switch_group.emplace(node_item);
  } else {
    ctrl_send_.insert(node_item);
  }
  node_item->ctrl_recv_.emplace(this);
  if (is_root_node_) {
    node_item->root_ctrl_.emplace(this);
  }
  // If Enter feeds a control signal, take this as a root node.
  if (IsEnterFeedNode(node) && (node_item->node_type != STREAMMERGE && node_item->node_type != STREAMACTIVE)) {
    node_item->enter_ctrl_.emplace(this);
  }
  GELOGI("Node[%s] will control node[%s]", NodeName().c_str(), node_item->NodeName().c_str());
}

void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
  if (merge_index >= switch_groups_.size()) {
    GELOGE(FAILED, "[%s] group size: %zu, merge index: %u", NodeName().c_str(), switch_groups_.size(), merge_index);
    return;
  }

  // `this` is the StreamMerge node; node_item is the StreamActive node.
  auto &switch_group = switch_groups_[merge_index];
  switch_group.emplace(node_item);
  node_item->ctrl_send_.emplace(this);
  GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());
}

size_t NodeItem::GetMergeCtrl(uint32_t merge_index) const {
  return ((node_type == STREAMMERGE) && (merge_index < switch_groups_.size())) ? switch_groups_[merge_index].size() : 0;
}
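
// RAII helper that locks the given mutex only when it is not null.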
OptionalMutexGuard::OptionalMutexGuard(std::mutex *mutex, const string &name) : mu_(mutex), name_(name) {
  if (mu_ != nullptr) {
    GELOGD("lock for %s", name_.c_str());
    mu_->lock();
  }
}

OptionalMutexGuard::~OptionalMutexGuard() {
  if (mu_ != nullptr) {
    GELOGD("unlock for %s", name_.c_str());
    mu_->unlock();
    mu_ = nullptr;
  }
}
}  // namespace hybrid
}  // namespace ge
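
For reference, callers obtain a NodeItem through the NodeItem::Create factory defined above. The following is a minimal sketch only: the wrapper name BuildNodeItem is hypothetical, and `node` is assumed to be a valid NodePtr taken from a parsed ComputeGraph.

    // Sketch only: build a NodeItem for one node of a parsed ComputeGraph and dump its summary.
    ge::Status BuildNodeItem(const ge::NodePtr &node, std::unique_ptr<ge::hybrid::NodeItem> &node_item) {
      GE_CHK_STATUS_RET(ge::hybrid::NodeItem::Create(node, node_item),
                        "Failed to create NodeItem for [%s]", node->GetName().c_str());
      GELOGD("%s", node_item->DebugString().c_str());
      return ge::SUCCESS;
    }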

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and bridges the two: GE takes the graph issued by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE is specifically optimized for the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture is shown in the figure below.
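
As a rough illustration of the GE API side of that split, the sketch below drives GE directly through the public ge_api interfaces (GEInitialize, Session::AddGraph, Session::RunGraph, GEFinalize, assumed to come from ge/ge_api.h); in normal training or inference ME issues these calls on the user's behalf. The helper name RunOnce, the empty option maps, and the pre-built ge::Graph are assumptions made for illustration.

    // Rough sketch: drive the GE API layer directly with a graph that has already been constructed.
    int RunOnce(const ge::Graph &graph) {
      std::map<std::string, std::string> options;  // device/build options omitted in this sketch
      if (ge::GEInitialize(options) != ge::SUCCESS) {
        return -1;
      }
      ge::Session session(options);
      const uint32_t graph_id = 0;
      std::vector<ge::Tensor> inputs;   // would be filled with real input tensors
      std::vector<ge::Tensor> outputs;
      if ((session.AddGraph(graph_id, graph) != ge::SUCCESS) ||
          (session.RunGraph(graph_id, inputs, outputs) != ge::SUCCESS)) {
        (void) ge::GEFinalize();
        return -1;
      }
      (void) ge::GEFinalize();
      return 0;
    }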