You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_serialize.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/model_serialize.h"

  #include <cstdint>
  #include <cstdlib>
  #include <iostream>

  #include <google/protobuf/text_format.h>

  #include "debug/ge_attr_define.h"
  #include "debug/ge_log.h"
  #include "debug/ge_util.h"
  #include "framework/common/debug/ge_log.h"
  #include "graph/detail/model_serialize_imp.h"
  #include "proto/ge_ir.pb.h"
  #include "utils/graph_utils.h"
  using std::string;
  27. namespace ge {
  28. bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) {
  29. auto sep = node_index.rfind(":");
  30. if (sep == string::npos) {
  31. GELOGW("separator is not found in node_index.");
  32. return false;
  33. }
  34. node_name = node_index.substr(0, sep);
  35. auto index_str = node_index.substr(sep + 1);
  36. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  37. return true;
  38. }
  39. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor,
  40. proto::TensorDef *tensor_proto) {
  41. GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null.");
  42. GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null.");
  43. if (tensor->tensor_def_.GetProtoMsg() != nullptr) {
  44. *tensor_proto = *tensor->tensor_def_.GetProtoMsg();
  45. return true;
  46. }
  47. return false;
  48. }
  49. bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) {
  50. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null.");
  51. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  52. op_def_proto->clear_input();
  53. // Inputs
  54. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  55. if (in_data_anchor != nullptr) {
  56. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  57. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  58. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  59. std::to_string(peer_out_anchor->GetIdx()));
  60. } else {
  61. op_def_proto->add_input("");
  62. }
  63. }
  64. }
  65. // Control edge
  66. auto control_anchor = node->GetInControlAnchor();
  67. if (control_anchor != nullptr) {
  68. auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors();
  69. for (const auto &peer_out_anchor : peer_out_anchors) {
  70. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  71. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1");
  72. }
  73. }
  74. }
  75. return true;
  76. }
  77. bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
  78. if (op_desc == nullptr || op_def_proto == nullptr) {
  79. GELOGE(GRAPH_FAILED, "Input Para Invalid");
  80. return false;
  81. }
  82. if (op_desc->op_def_.GetProtoMsg() != nullptr) {
  83. *op_def_proto = *op_desc->op_def_.GetProtoMsg();
  84. op_def_proto->clear_input_desc();
  85. op_def_proto->clear_output_desc();
  86. // Input descs
  87. if (op_desc->GetInputsSize() > 0) {
  88. auto size = static_cast<uint32_t>(op_desc->GetInputsSize());
  89. for (uint32_t i = 0; i < size; i++) {
  90. auto tensor_desc = op_desc->GetInputDescPtr(i);
  91. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  92. *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  93. }
  94. }
  95. }
  96. // Output descs
  97. if (op_desc->GetOutputsSize() > 0) {
  98. auto size = static_cast<uint32_t>(op_desc->GetOutputsSize());
  99. for (uint32_t i = 0; i < size; i++) {
  100. auto tensor_desc = op_desc->GetOutputDescPtr(i);
  101. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  102. *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  103. }
  104. }
  105. }
  106. }
  107. return true;
  108. }
  109. bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) {
  110. if (node == nullptr || op_def_proto == nullptr) {
  111. GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
  112. return false;
  113. }
  114. if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) {
  115. GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
  116. return false;
  117. }
  118. if (SerializeEdge(node, op_def_proto)) {
  119. return true;
  120. } else {
  121. return false;
  122. }
  123. }
  124. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
  125. proto::GraphDef *graph_proto) {
  126. if (graph == nullptr || graph_proto == nullptr) {
  127. GELOGE(GRAPH_FAILED, "Input para Invalid");
  128. return false;
  129. }
  130. graph_proto->set_name(graph->GetName());
  131. // Inputs
  132. for (const auto &input : graph->GetInputNodes()) {
  133. if (input != nullptr) {
  134. graph_proto->add_input(input->GetName() + ":0");
  135. }
  136. }
  137. // Outputs
  138. for (const auto &output : graph->GetOutputNodes()) {
  139. if (output != nullptr) {
  140. graph_proto->add_output(output->GetName() + ":0");
  141. }
  142. }
  143. if (graph->attrs_.GetProtoMsg() != nullptr) {
  144. *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
  145. }
  146. for (const auto &node : graph->GetDirectNode()) {
  147. if (!SerializeNode(node, graph_proto->add_op())) {
  148. if (node->GetOpDesc() != nullptr) {
  149. GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
  150. }
  151. return false;
  152. }
  153. }
  154. return true;
  155. }
  156. bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) {
  157. if (model_proto == nullptr) {
  158. GELOGE(GRAPH_FAILED, "model_proto para Invalid");
  159. return false;
  160. }
  161. model_proto->set_name(model.GetName());
  162. model_proto->set_custom_version(model.GetPlatformVersion());
  163. model_proto->set_version(model.GetVersion());
  164. if (model.attrs_.GetProtoMsg()) {
  165. *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg();
  166. }
  167. auto &graph = model.graph_;
  168. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  169. if (compute_graph == nullptr) {
  170. GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
  171. return false;
  172. }
  173. if (!SerializeGraph(compute_graph, model_proto->add_graph())) {
  174. GELOGE(GRAPH_FAILED, "SerializeGraph fail");
  175. return false;
  176. }
  177. return true;
  178. }
  179. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor(
  180. GeTensorPtr &tensor, proto::TensorDef &tensor_proto) {
  181. tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto));
  182. if (tensor == nullptr) {
  183. GELOGE(GRAPH_FAILED, "tensor is nullptr");
  184. return false;
  185. } else {
  186. return true;
  187. }
  188. }
  189. bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
  190. op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto));
  191. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr.");
  192. // Input tensor
  193. for (auto &input_desc : *op_def_proto.mutable_input_desc()) {
  194. std::shared_ptr<GeTensorDesc> temp_value =
  195. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc));
  196. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  197. op_desc->inputs_desc_.push_back(temp_value);
  198. }
  199. // Output tensor
  200. for (auto &output_desc : *op_def_proto.mutable_output_desc()) {
  201. std::shared_ptr<GeTensorDesc> temp_value =
  202. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc));
  203. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  204. op_desc->outputs_desc_.push_back(temp_value);
  205. }
  206. return true;
  207. }
  208. bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) {
  209. GE_RT_FALSE_CHECK_NOTNULL(graph);
  210. OpDescPtr op_desc = nullptr;
  211. if (!UnserializeOpDesc(op_desc, op_def_proto)) {
  212. GELOGW("UnserializeOpDesc error.");
  213. }
  214. NodePtr node = graph->AddNode(op_desc);
  215. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr.");
  216. // Inputs
  217. int dst_index = 0;
  218. for (const auto &input : op_def_proto.input()) {
  219. string node_name;
  220. int32_t index = 0;
  221. if (ParseNodeIndex(input, node_name, index)) {
  222. node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()});
  223. }
  224. if (index >= 0) {
  225. dst_index++;
  226. }
  227. }
  228. node_map_[op_def_proto.name()] = node;
  229. return true;
  230. }
  231. bool ModelSerializeImp::HandleNodeNameRef() {
  232. // Edges
  233. for (auto &item : node_input_node_names_) {
  234. auto src_node_it = node_map_.find(item.src_node_name);
  235. if (src_node_it == node_map_.end()) {
  236. GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str());
  237. return false;
  238. }
  239. GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue);
  240. if (item.src_out_index >= 0) {
  241. auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index);
  242. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  243. if (src_anchor == nullptr || dst_anchor == nullptr) {
  244. GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  245. item.dst_node_name.c_str(), item.dst_in_index);
  246. return false;
  247. }
  248. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
  249. } else {
  250. // Control edge
  251. auto src_anchor = src_node_it->second->GetOutControlAnchor();
  252. auto dst_anchor = item.dst_node->GetInControlAnchor();
  253. if (src_anchor != nullptr && dst_anchor != nullptr) {
  254. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
  255. }
  256. }
  257. }
  258. // Graph input
  259. for (auto &item : graph_input_node_names_) {
  260. auto node_it = node_map_.find(item.node_name);
  261. if (node_it == node_map_.end()) {
  262. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  263. return false;
  264. }
  265. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  266. auto ret = item.graph->AddInputNode(node_it->second);
  267. if (ret == nullptr) {
  268. return false;
  269. }
  270. }
  271. // Graph output
  272. for (auto &item : graph_output_node_names_) {
  273. auto node_it = node_map_.find(item.node_name);
  274. if (node_it == node_map_.end()) {
  275. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  276. return false;
  277. }
  278. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  279. auto ret = item.graph->AddOutputNode(node_it->second);
  280. if (ret == nullptr) {
  281. GELOGE(GRAPH_FAILED, "AddOutputNode failed.");
  282. return false;
  283. }
  284. }
  285. node_input_node_names_.clear();
  286. graph_input_node_names_.clear();
  287. graph_output_node_names_.clear();
  288. node_map_.clear();
  289. return true;
  290. }
  291. bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) {
  292. model.name_ = model_proto.name();
  293. model.version_ = model_proto.version();
  294. model.platform_version_ = model_proto.custom_version();
  295. model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr());
  296. auto &graphs_proto = *model_proto.mutable_graph();
  297. if (!graphs_proto.empty()) {
  298. auto &graph_proto = graphs_proto[0];
  299. ComputeGraphPtr compute_graph_ptr;
  300. if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) {
  301. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
  302. }
  303. }
  304. if (!HandleNodeNameRef()) {
  305. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  306. return false;
  307. }
  308. return true;
  309. }
  310. bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) {
  311. graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
  312. if (graph == nullptr) {
  313. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  314. return false;
  315. }
  316. // Inputs
  317. for (auto input : graph_proto.input()) {
  318. string node_name;
  319. int32_t index;
  320. if (ParseNodeIndex(input, node_name, index)) {
  321. graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  322. }
  323. }
  324. // Outputs
  325. for (auto output : graph_proto.output()) {
  326. string node_name;
  327. int32_t index;
  328. if (ParseNodeIndex(output, node_name, index)) {
  329. graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  330. }
  331. }
  332. graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr());
  333. for (auto &op_def_proto : *graph_proto.mutable_op()) {
  334. if (!UnserializeNode(graph, op_def_proto)) {
  335. GELOGE(GRAPH_FAILED, "UnserializeNode fail");
  336. return false;
  337. }
  338. }
  339. return true;
  340. }
  341. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph,
  342. proto::GraphDef &graph_proto) {
  343. if (!UnserializeGraphWithoutEdge(graph, graph_proto)) {
  344. GELOGW("UnserializeGraphWithoutEdge fail");
  345. }
  346. if (!HandleNodeNameRef()) {
  347. GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail");
  348. return false;
  349. }
  350. return true;
  351. }
  352. bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) {
  353. GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null.");
  354. GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null.");
  355. google::protobuf::io::CodedInputStream coded_stream(data, len);
  356. // 2048M -1
  357. coded_stream.SetTotalBytesLimit(INT32_MAX, -1);
  358. if (!proto->ParseFromCodedStream(&coded_stream)) {
  359. GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len);
  360. return false;
  361. }
  362. return true;
  363. }
  364. Buffer ModelSerialize::SerializeModel(const Model &model) {
  365. proto::ModelDef model_def;
  366. ModelSerializeImp imp;
  367. if (!imp.SerializeModel(model, &model_def)) {
  368. return Buffer();
  369. }
  370. #if !defined(__ANDROID__) && !defined(ANDROID)
  371. Buffer buffer(model_def.ByteSizeLong());
  372. #else
  373. Buffer buffer(model_def.ByteSize());
  374. #endif
  375. GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed");
  376. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  377. auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  378. if (ret != true) {
  379. GELOGW("serialize to array fail.");
  380. }
  381. return buffer;
  382. }
  383. size_t ModelSerialize::GetSerializeModelSize(const Model &model) {
  384. proto::ModelDef model_def;
  385. ModelSerializeImp imp;
  386. if (!imp.SerializeModel(model, &model_def)) {
  387. return 0;
  388. }
  389. #if !defined(__ANDROID__) && !defined(ANDROID)
  390. return model_def.ByteSizeLong();
  391. #else
  392. return model_def.ByteSize();
  393. #endif
  394. }
  395. Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) {
  396. if (data == nullptr) {
  397. GELOGE(GRAPH_FAILED, "data is nullptr");
  398. return Model();
  399. }
  400. std::shared_ptr<proto::ModelDef> model_proto_ptr;
  401. model_proto_ptr = ComGraphMakeShared<proto::ModelDef>();
  402. if (model_proto_ptr == nullptr) {
  403. GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
  404. return Model();
  405. }
  406. auto &model_proto = *model_proto_ptr;
  407. if (!ReadProtoFromBinaryFile(data, len, &model_proto)) {
  408. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  409. return Model();
  410. }
  411. Model model;
  412. ModelSerializeImp imp;
  413. imp.SetProtobufOwner(model_proto_ptr);
  414. if (!imp.UnserializeModel(model, model_proto)) {
  415. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  416. return Model();
  417. }
  418. return model;
  419. }
  420. Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) {
  421. std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def);
  422. GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed");
  423. ModelSerializeImp imp;
  424. imp.SetProtobufOwner(model_def_ptr);
  425. Model model;
  426. if (!imp.UnserializeModel(model, *model_def_ptr)) {
  427. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  428. return Model();
  429. }
  430. return model;
  431. }
  432. Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) {
  433. proto::GraphDef graph_def;
  434. ModelSerializeImp imp;
  435. if (!imp.SerializeGraph(graph, &graph_def)) {
  436. return Buffer();
  437. }
  438. #if !defined(__ANDROID__) && !defined(ANDROID)
  439. Buffer buffer(graph_def.ByteSizeLong());
  440. #else
  441. Buffer buffer(graph_def.ByteSize());
  442. #endif
  443. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  444. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  445. auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  446. if (ret != true) {
  447. GE_LOGE("serialize to array fail.");
  448. }
  449. return buffer;
  450. }
  451. ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) {
  452. if (data == nullptr) {
  453. GELOGE(GRAPH_FAILED, "data is nullptr");
  454. return nullptr;
  455. }
  456. std::shared_ptr<proto::GraphDef> graph_proto_ptr;
  457. graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>();
  458. if (graph_proto_ptr == nullptr) {
  459. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  460. return nullptr;
  461. }
  462. proto::GraphDef &graph_proto = *graph_proto_ptr;
  463. if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) {
  464. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  465. return nullptr;
  466. }
  467. ComputeGraphPtr graph;
  468. ModelSerializeImp imp;
  469. imp.SetProtobufOwner(graph_proto_ptr);
  470. if (!imp.UnserializeGraph(graph, graph_proto)) {
  471. return nullptr;
  472. }
  473. return graph;
  474. }
  475. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) {
  476. proto::OpDef op_def;
  477. ModelSerializeImp imp;
  478. if (!imp.SerializeOpDesc(op_desc, &op_def)) {
  479. return Buffer();
  480. }
  481. #if !defined(__ANDROID__) && !defined(ANDROID)
  482. Buffer buffer(op_def.ByteSizeLong());
  483. #else
  484. Buffer buffer(op_def.ByteSize());
  485. #endif
  486. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  487. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  488. auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  489. if (ret != true) {
  490. GE_LOGE("serialize to array fail.");
  491. }
  492. return buffer;
  493. }
  494. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data,
  495. size_t len) {
  496. if (data == nullptr) {
  497. GELOGE(GRAPH_FAILED, "data is nullptr");
  498. return nullptr;
  499. }
  500. std::shared_ptr<proto::OpDef> op_def_ptr;
  501. op_def_ptr = ComGraphMakeShared<proto::OpDef>();
  502. if (op_def_ptr == nullptr) {
  503. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  504. return nullptr;
  505. }
  506. proto::OpDef &op_def = *op_def_ptr;
  507. if (!ReadProtoFromBinaryFile(data, len, &op_def)) {
  508. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  509. return nullptr;
  510. }
  511. OpDescPtr op_desc;
  512. ModelSerializeImp imp;
  513. imp.SetProtobufOwner(op_def_ptr);
  514. if (!imp.UnserializeOpDesc(op_desc, op_def)) {
  515. GELOGW("UnserializeOpDesc error.");
  516. }
  517. return op_desc;
  518. }
  519. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：