You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_serialize.cc 27 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/model_serialize.h"
  17. #include <google/protobuf/text_format.h>
  18. #include <queue>
  19. #include <iostream>
  20. #include "debug/ge_attr_define.h"
  21. #include "debug/ge_log.h"
  22. #include "debug/ge_util.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "graph/detail/model_serialize_imp.h"
  25. #include "proto/ge_ir.pb.h"
  26. #include "utils/graph_utils.h"
  27. #include "debug/ge_op_types.h"
  28. using std::map;
  29. using std::string;
  30. namespace ge {
  31. bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) {
  32. auto sep = node_index.rfind(":");
  33. if (sep == string::npos) {
  34. GELOGW("separator is not found in node_index.");
  35. return false;
  36. }
  37. node_name = node_index.substr(0, sep);
  38. auto index_str = node_index.substr(sep + 1);
  39. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  40. return true;
  41. }
  42. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor,
  43. proto::TensorDef *tensor_proto) {
  44. GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null.");
  45. GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null.");
  46. if (tensor->tensor_def_.GetProtoMsg() != nullptr) {
  47. *tensor_proto = *tensor->tensor_def_.GetProtoMsg();
  48. return true;
  49. }
  50. return false;
  51. }
  52. bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) {
  53. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null.");
  54. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  55. op_def_proto->clear_input();
  56. // Inputs
  57. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  58. if (in_data_anchor != nullptr) {
  59. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  60. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  61. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  62. std::to_string(peer_out_anchor->GetIdx()));
  63. } else {
  64. op_def_proto->add_input("");
  65. }
  66. }
  67. }
  68. // Control edge
  69. auto control_anchor = node->GetInControlAnchor();
  70. if (control_anchor != nullptr) {
  71. auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors();
  72. for (const auto &peer_out_anchor : peer_out_anchors) {
  73. if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) {
  74. op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1");
  75. }
  76. }
  77. }
  78. return true;
  79. }
  80. bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) {
  81. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null.");
  82. GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null.");
  83. if (op_desc->op_def_.GetProtoMsg() != nullptr) {
  84. *op_def_proto = *op_desc->op_def_.GetProtoMsg();
  85. //Delete unnecessary attr
  86. if (is_dump) {
  87. auto attr = op_def_proto->mutable_attr();
  88. attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF);
  89. attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF);
  90. attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF);
  91. GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP),
  92. attr->erase(ATTR_NAME_WEIGHTS));
  93. }
  94. op_def_proto->clear_input_desc();
  95. op_def_proto->clear_output_desc();
  96. // Input descs
  97. if (op_desc->GetAllInputsSize() > 0) {
  98. auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize());
  99. for (uint32_t i = 0; i < size; i++) {
  100. auto tensor_desc = op_desc->GetInputDescPtrDfault(i);
  101. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  102. *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  103. }
  104. }
  105. }
  106. // Output descs
  107. if (op_desc->GetOutputsSize() > 0) {
  108. auto size = static_cast<uint32_t>(op_desc->GetOutputsSize());
  109. for (uint32_t i = 0; i < size; i++) {
  110. auto tensor_desc = op_desc->GetOutputDescPtr(i);
  111. if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) {
  112. *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg());
  113. }
  114. }
  115. }
  116. op_def_proto->set_id(op_desc->GetId());
  117. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  118. op_def_proto->add_subgraph_name(name);
  119. }
  120. OpDescToAttrDef(op_desc, op_def_proto);
  121. }
  122. return true;
  123. }
  124. void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) {
  125. proto::AttrDef key_in;
  126. proto::AttrDef value_in;
  127. auto op_desc_attr = op_def_proto->mutable_attr();
  128. if (!op_desc->input_name_idx_.empty()) {
  129. for (auto &item : op_desc->input_name_idx_) {
  130. key_in.mutable_list()->add_s(item.first);
  131. value_in.mutable_list()->add_i(item.second);
  132. }
  133. op_desc_attr->insert({"_input_name_key", key_in});
  134. op_desc_attr->insert({"_input_name_value", value_in});
  135. }
  136. proto::AttrDef key_out;
  137. proto::AttrDef value_out;
  138. if (!op_desc->output_name_idx_.empty()) {
  139. for (auto &item : op_desc->output_name_idx_) {
  140. key_out.mutable_list()->add_s(item.first);
  141. value_out.mutable_list()->add_i(item.second);
  142. }
  143. op_desc_attr->insert({"_output_name_key", key_out});
  144. op_desc_attr->insert({"_output_name_value", value_out});
  145. }
  146. proto::AttrDef opt_input;
  147. if (!op_desc->optional_input_names_.empty()) {
  148. for (auto &item : op_desc->optional_input_names_) {
  149. opt_input.mutable_list()->add_s(item);
  150. }
  151. op_desc_attr->insert({"_opt_input", opt_input});
  152. }
  153. }
  154. bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) {
  155. if (node == nullptr || op_def_proto == nullptr) {
  156. GELOGE(GRAPH_FAILED, "Input Para Node Invalid");
  157. return false;
  158. }
  159. if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) {
  160. GELOGE(GRAPH_FAILED, "Serialize OpDesc failed");
  161. return false;
  162. }
  163. if (SerializeEdge(node, op_def_proto)) {
  164. return true;
  165. } else {
  166. return false;
  167. }
  168. }
  169. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph,
  170. proto::GraphDef *graph_proto,
  171. bool is_dump) {
  172. if (graph == nullptr || graph_proto == nullptr) {
  173. GELOGE(GRAPH_FAILED, "Input para Invalid");
  174. return false;
  175. }
  176. graph_proto->set_name(graph->GetName());
  177. // Inputs
  178. for (const auto &input : graph->GetInputNodes()) {
  179. if (input != nullptr) {
  180. graph_proto->add_input(input->GetName() + ":0");
  181. }
  182. }
  183. // Outputs
  184. for (const auto &output : graph->GetGraphOutNodesInfo()) {
  185. if (output.first != nullptr) {
  186. graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second));
  187. GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second);
  188. }
  189. }
  190. if (graph->attrs_.GetProtoMsg() != nullptr) {
  191. *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg();
  192. }
  193. for (const auto &node : graph->GetDirectNode()) {
  194. if (!SerializeNode(node, graph_proto->add_op(), is_dump)) {
  195. if (node->GetOpDesc() != nullptr) {
  196. GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str());
  197. }
  198. return false;
  199. }
  200. }
  201. return true;
  202. }
  203. bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) {
  204. if (model_proto == nullptr) {
  205. GELOGE(GRAPH_FAILED, "model_proto para Invalid");
  206. return false;
  207. }
  208. model_proto->set_name(model.GetName());
  209. model_proto->set_custom_version(model.GetPlatformVersion());
  210. model_proto->set_version(model.GetVersion());
  211. if (model.attrs_.GetProtoMsg()) {
  212. *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg();
  213. }
  214. auto &graph = model.graph_;
  215. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  216. if (compute_graph == nullptr) {
  217. GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr");
  218. return false;
  219. }
  220. if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) {
  221. GELOGE(GRAPH_FAILED, "SerializeGraph fail");
  222. return false;
  223. }
  224. for (auto subgraph : compute_graph->GetAllSubgraphs()) {
  225. if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) {
  226. GELOGE(GRAPH_FAILED, "Serialize subgraph failed");
  227. return false;
  228. }
  229. }
  230. return true;
  231. }
  232. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor(
  233. GeTensorPtr &tensor, proto::TensorDef &tensor_proto) {
  234. tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto));
  235. if (tensor == nullptr) {
  236. GELOGE(GRAPH_FAILED, "tensor is nullptr");
  237. return false;
  238. } else {
  239. return true;
  240. }
  241. }
  242. void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc,
  243. std::vector<string> &key_in,
  244. std::vector<string> &key_out,
  245. std::vector<uint32_t> &value_in,
  246. std::vector<uint32_t> &value_out,
  247. std::vector<string> &opt_input) {
  248. if (!key_in.empty()) {
  249. if (key_in.size() != value_in.size()) {
  250. GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.",
  251. key_out.size(), value_in.size());
  252. } else {
  253. for (uint32_t i = 0; i < key_in.size(); ++i) {
  254. op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i)));
  255. }
  256. }
  257. }
  258. if (!key_out.empty()) {
  259. if (key_out.size() != value_out.size()) {
  260. GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.",
  261. key_out.size(), value_out.size());
  262. } else {
  263. for (uint32_t i = 0; i < key_out.size(); ++i) {
  264. op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i)));
  265. }
  266. }
  267. }
  268. if (!opt_input.empty()) {
  269. for (const auto &i : opt_input) {
  270. op_desc->optional_input_names_.insert(i);
  271. }
  272. }
  273. }
  274. bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) {
  275. std::vector<string> opt_input;
  276. std::vector<string> key_in;
  277. std::vector<uint32_t> value_in;
  278. if (op_def_proto.attr().count("_opt_input") > 0) {
  279. auto &name_list = op_def_proto.attr().at("_opt_input").list();
  280. for (const auto &item_s : name_list.s()) {
  281. opt_input.push_back(item_s);
  282. }
  283. auto op_desc_attr = op_def_proto.mutable_attr();
  284. op_desc_attr->erase("_opt_input");
  285. }
  286. if (op_def_proto.attr().count("_input_name_key") > 0) {
  287. auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list();
  288. for (const auto &item_s : output_name_key_list.s()) {
  289. key_in.push_back(item_s);
  290. }
  291. auto op_desc_attr = op_def_proto.mutable_attr();
  292. op_desc_attr->erase("_input_name_key");
  293. }
  294. if (op_def_proto.attr().count("_input_name_value") > 0) {
  295. auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list();
  296. for (const auto &item_i : input_name_value_list.i()) {
  297. value_in.push_back(static_cast<uint32_t>(item_i));
  298. }
  299. auto op_desc_attr = op_def_proto.mutable_attr();
  300. op_desc_attr->erase("_input_name_value");
  301. }
  302. std::vector<string> key_out;
  303. std::vector<uint32_t> value_out;
  304. if (op_def_proto.attr().count("_output_name_key") > 0) {
  305. auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list();
  306. for (const auto &item_s : output_name_key_list.s()) {
  307. key_out.push_back(item_s);
  308. }
  309. auto op_desc_attr = op_def_proto.mutable_attr();
  310. op_desc_attr->erase("_output_name_key");
  311. }
  312. if (op_def_proto.attr().count("_output_name_value") > 0) {
  313. auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list();
  314. for (const auto &item_i : output_name_value_list.i()) {
  315. value_out.push_back(static_cast<uint32_t>(item_i));
  316. }
  317. auto op_desc_attr = op_def_proto.mutable_attr();
  318. op_desc_attr->erase("_output_name_value");
  319. }
  320. op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto));
  321. GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr.");
  322. // Input tensor
  323. for (auto &input_desc : *op_def_proto.mutable_input_desc()) {
  324. std::shared_ptr<GeTensorDesc> temp_value =
  325. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc));
  326. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  327. op_desc->inputs_desc_.push_back(temp_value);
  328. }
  329. // Output tensor
  330. for (auto &output_desc : *op_def_proto.mutable_output_desc()) {
  331. std::shared_ptr<GeTensorDesc> temp_value =
  332. std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc));
  333. GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr");
  334. op_desc->outputs_desc_.push_back(temp_value);
  335. }
  336. op_desc->SetId(op_def_proto.id());
  337. uint32_t graph_index = 0;
  338. for (const std::string &name : op_def_proto.subgraph_name()) {
  339. op_desc->AddSubgraphName(name);
  340. op_desc->SetSubgraphInstanceName(graph_index++, name);
  341. }
  342. // insert name index by key and value
  343. AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input);
  344. return true;
  345. }
  346. bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) {
  347. GE_RT_FALSE_CHECK_NOTNULL(graph);
  348. OpDescPtr op_desc = nullptr;
  349. if (!UnserializeOpDesc(op_desc, op_def_proto)) {
  350. GELOGW("UnserializeOpDesc error.");
  351. }
  352. NodePtr node = graph->AddNode(op_desc, op_desc->GetId());
  353. GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr.");
  354. // Inputs
  355. int dst_index = 0;
  356. for (const auto &input : op_def_proto.input()) {
  357. string node_name;
  358. int32_t index = 0;
  359. if (ParseNodeIndex(input, node_name, index)) {
  360. node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()});
  361. }
  362. if (index >= 0) {
  363. dst_index++;
  364. }
  365. }
  366. node_map_[op_def_proto.name()] = node;
  367. return true;
  368. }
  369. bool ModelSerializeImp::HandleNodeNameRef() {
  370. // Edges
  371. for (auto &item : node_input_node_names_) {
  372. auto src_node_it = node_map_.find(item.src_node_name);
  373. if (src_node_it == node_map_.end()) {
  374. GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str());
  375. return false;
  376. }
  377. GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue);
  378. if (item.src_out_index >= 0) {
  379. auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index);
  380. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  381. if (src_anchor == nullptr || dst_anchor == nullptr) {
  382. GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  383. item.dst_node_name.c_str(), item.dst_in_index);
  384. return false;
  385. }
  386. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
  387. } else {
  388. // Control edge
  389. auto src_anchor = src_node_it->second->GetOutControlAnchor();
  390. auto dst_anchor = item.dst_node->GetInControlAnchor();
  391. if (src_anchor != nullptr && dst_anchor != nullptr) {
  392. GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
  393. }
  394. }
  395. }
  396. // Graph input
  397. for (auto &item : graph_input_node_names_) {
  398. auto node_it = node_map_.find(item.node_name);
  399. if (node_it == node_map_.end()) {
  400. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  401. return false;
  402. }
  403. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  404. auto ret = item.graph->AddInputNode(node_it->second);
  405. if (ret == nullptr) {
  406. return false;
  407. }
  408. }
  409. // Graph output
  410. for (auto &item : graph_output_node_names_) {
  411. auto node_it = node_map_.find(item.node_name);
  412. if (node_it == node_map_.end()) {
  413. GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str());
  414. return false;
  415. }
  416. GE_IF_BOOL_EXEC(item.graph == nullptr, continue);
  417. auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index);
  418. GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index);
  419. if (ret == nullptr) {
  420. GELOGE(GRAPH_FAILED, "AddOutputNode failed.");
  421. return false;
  422. }
  423. }
  424. node_input_node_names_.clear();
  425. graph_input_node_names_.clear();
  426. graph_output_node_names_.clear();
  427. node_map_.clear();
  428. return true;
  429. }
  430. bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) {
  431. std::queue<ComputeGraphPtr> all_graphs;
  432. all_graphs.emplace(compute_graph);
  433. while (!all_graphs.empty()) {
  434. ComputeGraphPtr graph = all_graphs.front();
  435. all_graphs.pop();
  436. for (const NodePtr &node : graph->GetDirectNode()) {
  437. const OpDescPtr op_desc = node->GetOpDesc();
  438. for (const std::string &name : op_desc->GetSubgraphInstanceNames()) {
  439. auto it = subgraphs.find(name);
  440. if (it == subgraphs.end()) {
  441. GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.",
  442. op_desc->GetName().c_str(), name.c_str(), subgraphs.size());
  443. return false;
  444. }
  445. ComputeGraphPtr &subgraph = it->second;
  446. subgraph->SetParentGraph(graph);
  447. subgraph->SetParentNode(node);
  448. compute_graph->AddSubgraph(subgraph->GetName(), subgraph);
  449. all_graphs.emplace(subgraph);
  450. }
  451. }
  452. }
  453. return true;
  454. }
  455. bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) {
  456. model.name_ = model_proto.name();
  457. model.version_ = model_proto.version();
  458. model.platform_version_ = model_proto.custom_version();
  459. model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr());
  460. auto &graphs_proto = *model_proto.mutable_graph();
  461. if (!graphs_proto.empty()) {
  462. auto &graph_proto = graphs_proto[0];
  463. ComputeGraphPtr compute_graph_ptr;
  464. if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) {
  465. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
  466. }
  467. // 0 is main graph, following is subgraph.
  468. map<string, ComputeGraphPtr> subgraphs;
  469. for (int idx = 1; idx < graphs_proto.size(); ++idx) {
  470. ComputeGraphPtr subgraph;
  471. ModelSerializeImp impl;
  472. if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) {
  473. GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed");
  474. return false;
  475. }
  476. if (!impl.HandleNodeNameRef()) {
  477. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  478. return false;
  479. }
  480. subgraphs[subgraph->GetName()] = subgraph;
  481. }
  482. if (!RebuildOwnership(compute_graph_ptr, subgraphs)) {
  483. GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed");
  484. return false;
  485. }
  486. }
  487. if (!HandleNodeNameRef()) {
  488. GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed");
  489. return false;
  490. }
  491. return true;
  492. }
  493. bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) {
  494. graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name());
  495. if (graph == nullptr) {
  496. GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed");
  497. return false;
  498. }
  499. // Inputs
  500. for (auto input : graph_proto.input()) {
  501. string node_name;
  502. int32_t index;
  503. if (ParseNodeIndex(input, node_name, index)) {
  504. graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  505. }
  506. }
  507. // Outputs
  508. for (auto output : graph_proto.output()) {
  509. string node_name;
  510. int32_t index;
  511. if (ParseNodeIndex(output, node_name, index)) {
  512. graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph});
  513. }
  514. }
  515. graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr());
  516. for (auto &op_def_proto : *graph_proto.mutable_op()) {
  517. if (!UnserializeNode(graph, op_def_proto)) {
  518. GELOGE(GRAPH_FAILED, "UnserializeNode fail");
  519. return false;
  520. }
  521. }
  522. return true;
  523. }
  524. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph,
  525. proto::GraphDef &graph_proto) {
  526. if (!UnserializeGraphWithoutEdge(graph, graph_proto)) {
  527. GELOGW("UnserializeGraphWithoutEdge fail");
  528. }
  529. if (!HandleNodeNameRef()) {
  530. GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail");
  531. return false;
  532. }
  533. return true;
  534. }
  535. bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) {
  536. GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null.");
  537. GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null.");
  538. google::protobuf::io::CodedInputStream coded_stream(data, len);
  539. // 2048M -1
  540. coded_stream.SetTotalBytesLimit(INT32_MAX, -1);
  541. if (!proto->ParseFromCodedStream(&coded_stream)) {
  542. GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len);
  543. return false;
  544. }
  545. return true;
  546. }
  547. Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) {
  548. proto::ModelDef model_def;
  549. ModelSerializeImp imp;
  550. if (!imp.SerializeModel(model, &model_def, is_dump)) {
  551. return Buffer();
  552. }
  553. #if !defined(__ANDROID__) && !defined(ANDROID)
  554. Buffer buffer(model_def.ByteSizeLong());
  555. #else
  556. Buffer buffer(model_def.ByteSize());
  557. #endif
  558. GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed");
  559. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  560. auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  561. if (ret != true) {
  562. GELOGW("serialize to array fail.");
  563. }
  564. return buffer;
  565. }
  566. size_t ModelSerialize::GetSerializeModelSize(const Model &model) {
  567. proto::ModelDef model_def;
  568. ModelSerializeImp imp;
  569. if (!imp.SerializeModel(model, &model_def)) {
  570. return 0;
  571. }
  572. #if !defined(__ANDROID__) && !defined(ANDROID)
  573. return model_def.ByteSizeLong();
  574. #else
  575. return model_def.ByteSize();
  576. #endif
  577. }
  578. Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) {
  579. if (data == nullptr) {
  580. GELOGE(GRAPH_FAILED, "data is nullptr");
  581. return Model();
  582. }
  583. std::shared_ptr<proto::ModelDef> model_proto_ptr;
  584. model_proto_ptr = ComGraphMakeShared<proto::ModelDef>();
  585. if (model_proto_ptr == nullptr) {
  586. GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed");
  587. return Model();
  588. }
  589. auto &model_proto = *model_proto_ptr;
  590. if (!ReadProtoFromBinaryFile(data, len, &model_proto)) {
  591. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  592. return Model();
  593. }
  594. Model model;
  595. ModelSerializeImp imp;
  596. imp.SetProtobufOwner(model_proto_ptr);
  597. if (!imp.UnserializeModel(model, model_proto)) {
  598. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  599. return Model();
  600. }
  601. return model;
  602. }
  603. Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) {
  604. std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def);
  605. GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed");
  606. ModelSerializeImp imp;
  607. imp.SetProtobufOwner(model_def_ptr);
  608. Model model;
  609. if (!imp.UnserializeModel(model, *model_def_ptr)) {
  610. GELOGE(GRAPH_FAILED, "Unserialize Model fail");
  611. return Model();
  612. }
  613. return model;
  614. }
  615. Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) {
  616. proto::GraphDef graph_def;
  617. ModelSerializeImp imp;
  618. if (!imp.SerializeGraph(graph, &graph_def)) {
  619. return Buffer();
  620. }
  621. #if !defined(__ANDROID__) && !defined(ANDROID)
  622. Buffer buffer(graph_def.ByteSizeLong());
  623. #else
  624. Buffer buffer(graph_def.ByteSize());
  625. #endif
  626. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  627. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  628. auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  629. if (ret != true) {
  630. GE_LOGE("serialize to array fail.");
  631. }
  632. return buffer;
  633. }
  634. ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) {
  635. if (data == nullptr) {
  636. GELOGE(GRAPH_FAILED, "data is nullptr");
  637. return nullptr;
  638. }
  639. std::shared_ptr<proto::GraphDef> graph_proto_ptr;
  640. graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>();
  641. if (graph_proto_ptr == nullptr) {
  642. GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
  643. return nullptr;
  644. }
  645. proto::GraphDef &graph_proto = *graph_proto_ptr;
  646. if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) {
  647. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  648. return nullptr;
  649. }
  650. ComputeGraphPtr graph;
  651. ModelSerializeImp imp;
  652. imp.SetProtobufOwner(graph_proto_ptr);
  653. if (!imp.UnserializeGraph(graph, graph_proto)) {
  654. return nullptr;
  655. }
  656. return graph;
  657. }
  658. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) {
  659. proto::OpDef op_def;
  660. ModelSerializeImp imp;
  661. if (!imp.SerializeOpDesc(op_desc, &op_def)) {
  662. return Buffer();
  663. }
  664. #if !defined(__ANDROID__) && !defined(ANDROID)
  665. Buffer buffer(op_def.ByteSizeLong());
  666. #else
  667. Buffer buffer(op_def.ByteSize());
  668. #endif
  669. GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed");
  670. GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed");
  671. auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize()));
  672. if (ret != true) {
  673. GE_LOGE("serialize to array fail.");
  674. }
  675. return buffer;
  676. }
  677. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data,
  678. size_t len) {
  679. if (data == nullptr) {
  680. GELOGE(GRAPH_FAILED, "data is nullptr");
  681. return nullptr;
  682. }
  683. std::shared_ptr<proto::OpDef> op_def_ptr;
  684. op_def_ptr = ComGraphMakeShared<proto::OpDef>();
  685. if (op_def_ptr == nullptr) {
  686. GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
  687. return nullptr;
  688. }
  689. proto::OpDef &op_def = *op_def_ptr;
  690. if (!ReadProtoFromBinaryFile(data, len, &op_def)) {
  691. GELOGE(GRAPH_FAILED, "ParseFromArray fail");
  692. return nullptr;
  693. }
  694. OpDescPtr op_desc;
  695. ModelSerializeImp imp;
  696. imp.SetProtobufOwner(op_def_ptr);
  697. if (!imp.UnserializeOpDesc(op_desc, op_def)) {
  698. GELOGW("UnserializeOpDesc error.");
  699. }
  700. return op_desc;
  701. }
  702. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: