You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ge_ir_utils.cc 51 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/utils/ge_ir_utils.h"
  17. #include <utility>
  18. #include "framework/common/debug/ge_log.h"
  19. namespace {
  20. const char *const kControlAnchorIndex = ":-1";
  21. const char *const kNodeTypeForSubgraph = "subgraph";
  22. const char *const kPrefixForInputDesc = "input_desc_attr_";
  23. const char *const kPrefixForOutputDesc = "output_desc_attr_";
  24. const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
  25. const int8_t kMaxRecursionDepth = 10;
  26. const char *const kDumpGeGraph = std::getenv(kDumpGEGraph);
  27. const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP;
  28. const int64_t kInputPrefixLength = 5;
  29. const int64_t kOutputPrefixLength = 6;
  30. using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>;
  31. } // namespace
  32. namespace ge {
  33. // Part 1: from IR convert to ONNX Protobuf
  34. static const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = {
  35. {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64},
  36. {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32},
  37. {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8},
  38. {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16},
  39. {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16},
  40. {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL},
  41. };
  42. onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) {
  43. auto it = kGeDataTypeToOnnxMap.find(data_type);
  44. if (it != kGeDataTypeToOnnxMap.end()) {
  45. return it->second;
  46. } else {
  47. GELOGW("EncodeDataType: datatype not support %u", data_type);
  48. return onnx::TensorProto_DataType_UNDEFINED;
  49. }
  50. }
  51. void OnnxUtils::AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
  52. onnx::NodeProto *node_proto) {
  53. if (node_proto == nullptr) {
  54. GELOGE(FAILED, "Node proto is nullptr.");
  55. return;
  56. }
  57. auto attr = node_proto->add_attribute();
  58. if (attr == nullptr) {
  59. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  60. return;
  61. }
  62. auto attr_name = string_attr_value.first;
  63. attr->set_name(attr_name);
  64. auto attr_value = string_attr_value.second;
  65. auto value_type = attr_value.GetValueType();
  66. switch (value_type) {
  67. case GeAttrValue::VT_FLOAT: {
  68. GeAttrValue::FLOAT data_f = 0;
  69. (void)attr_value.GetValue(data_f);
  70. attr->set_f(data_f);
  71. attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  72. break;
  73. }
  74. case GeAttrValue::VT_LIST_FLOAT: {
  75. GeAttrValue::LIST_FLOAT data_fs = {};
  76. (void)attr_value.GetValue(data_fs);
  77. attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
  78. for (auto &v : data_fs) {
  79. attr->add_floats(v);
  80. }
  81. break;
  82. }
  83. case GeAttrValue::VT_INT: {
  84. GeAttrValue::INT data_i = 0;
  85. (void)attr_value.GetValue(data_i);
  86. attr->set_type(onnx::AttributeProto_AttributeType_INT);
  87. attr->set_i(data_i);
  88. break;
  89. }
  90. case GeAttrValue::VT_LIST_INT: {
  91. GeAttrValue::LIST_INT data_is = {};
  92. (void)attr_value.GetValue(data_is);
  93. attr->set_type(onnx::AttributeProto_AttributeType_INTS);
  94. for (auto &v : data_is) {
  95. attr->add_ints(v);
  96. }
  97. break;
  98. }
  99. case GeAttrValue::VT_STRING: {
  100. GeAttrValue::STR data_s;
  101. (void)attr_value.GetValue(data_s);
  102. attr->set_type(onnx::AttributeProto_AttributeType_STRING);
  103. attr->set_s(data_s);
  104. break;
  105. }
  106. case GeAttrValue::VT_LIST_STRING: {
  107. GeAttrValue::LIST_STR data_ss = {};
  108. (void)attr_value.GetValue(data_ss);
  109. attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
  110. for (auto &v : data_ss) {
  111. attr->add_strings(v);
  112. }
  113. break;
  114. }
  115. default:
  116. GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type);
  117. break;
  118. }
  119. }
  120. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  121. void *data) {
  122. if (node_proto == nullptr) {
  123. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  124. return;
  125. }
  126. auto attr = node_proto->add_attribute();
  127. if (attr == nullptr) {
  128. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  129. return;
  130. }
  131. attr->set_name(name);
  132. switch (type) {
  133. case onnx::AttributeProto_AttributeType_FLOAT:
  134. attr->set_f((*(static_cast<float *>(data))));
  135. attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
  136. break;
  137. case onnx::AttributeProto_AttributeType_FLOATS:
  138. attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
  139. for (auto &v : (*(static_cast<std::vector<float> *>(data)))) {
  140. attr->add_floats(v);
  141. }
  142. break;
  143. case onnx::AttributeProto_AttributeType_INT:
  144. attr->set_type(onnx::AttributeProto_AttributeType_INT);
  145. attr->set_i((*(static_cast<int64_t *>(data))));
  146. break;
  147. case onnx::AttributeProto_AttributeType_INTS:
  148. attr->set_type(onnx::AttributeProto_AttributeType_INTS);
  149. for (auto &v : *(static_cast<std::vector<int64_t> *>(data))) {
  150. attr->add_ints(v);
  151. }
  152. break;
  153. case onnx::AttributeProto_AttributeType_STRING:
  154. attr->set_type(onnx::AttributeProto_AttributeType_STRING);
  155. attr->set_s((*(static_cast<std::string *>(data))));
  156. break;
  157. case onnx::AttributeProto_AttributeType_STRINGS:
  158. attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
  159. for (auto &v : *(static_cast<std::vector<std::string> *>(data))) {
  160. attr->add_strings(v);
  161. }
  162. break;
  163. default:
  164. GELOGW("AttributeProto AttributeType: %u is not supported for now", type);
  165. break;
  166. }
  167. }
  168. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  169. ::google::protobuf::RepeatedField<::google::protobuf::int64> data) {
  170. if (node_proto == nullptr) {
  171. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  172. return;
  173. }
  174. if (!data.empty()) {
  175. auto attr = node_proto->add_attribute();
  176. if (attr == nullptr) {
  177. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  178. return;
  179. }
  180. attr->set_name(name);
  181. for (auto &v : data) {
  182. attr->add_ints(v);
  183. }
  184. attr->set_type(type);
  185. }
  186. }
  187. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  188. ::google::protobuf::RepeatedField<bool> data) {
  189. if (node_proto == nullptr) {
  190. GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
  191. return;
  192. }
  193. if (!data.empty()) {
  194. auto attr = node_proto->add_attribute();
  195. if (attr == nullptr) {
  196. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  197. return;
  198. }
  199. attr->set_name(name);
  200. for (auto &v : data) {
  201. attr->add_ints(static_cast<int64_t>(v));
  202. }
  203. attr->set_type(type);
  204. }
  205. }
  206. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  207. ::google::protobuf::RepeatedField<float> data) {
  208. if (node_proto == nullptr) {
  209. GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
  210. return;
  211. }
  212. if (!data.empty()) {
  213. auto attr = node_proto->add_attribute();
  214. if (attr == nullptr) {
  215. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  216. return;
  217. }
  218. attr->set_name(name);
  219. for (auto &v : data) {
  220. attr->add_floats(v);
  221. }
  222. attr->set_type(type);
  223. }
  224. }
  225. void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
  226. ::google::protobuf::RepeatedPtrField<::std::string> data) {
  227. if (node_proto == nullptr) {
  228. GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
  229. return;
  230. }
  231. if (!data.empty()) {
  232. auto attr = node_proto->add_attribute();
  233. if (attr == nullptr) {
  234. GELOGE(GRAPH_FAILED, "attr is nullptr.");
  235. return;
  236. }
  237. attr->set_name(name);
  238. for (auto &v : data) {
  239. attr->add_strings(v);
  240. }
  241. attr->set_type(type);
  242. }
  243. }
  244. void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) {
  245. if (node_proto == nullptr || op_desc == nullptr) {
  246. GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr");
  247. return;
  248. }
  249. // Input describes
  250. auto size_in = op_desc->GetAllInputsSize();
  251. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in);
  252. if (size_in > 0) {
  253. for (uint32_t i = 0; i < size_in; i++) {
  254. auto input_desc = op_desc->GetInputDescPtrDfault(i);
  255. if (input_desc != nullptr) {
  256. auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType());
  257. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i),
  258. &data_type);
  259. auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType());
  260. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  261. "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin);
  262. auto dims = input_desc->GetShape().GetDims();
  263. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i),
  264. &dims);
  265. auto dims_origin = input_desc->GetOriginShape().GetDims();
  266. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  267. "input_desc_origin_shape:" + std::to_string(i), &dims_origin);
  268. auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat());
  269. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout:" + std::to_string(i),
  270. &layout);
  271. auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat());
  272. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  273. "input_desc_origin_layout:" + std::to_string(i), &layout_origin);
  274. auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
  275. if (tensor_descriptor != nullptr) {
  276. auto size = tensor_descriptor->size();
  277. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size:" + std::to_string(i),
  278. &size);
  279. auto weight_size = tensor_descriptor->weight_size();
  280. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  281. "input_desc_weight_size:" + std::to_string(i), &weight_size);
  282. auto reuse_input = tensor_descriptor->reuse_input();
  283. auto reuse_input_int = static_cast<int64_t>(reuse_input);
  284. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  285. "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int);
  286. auto output_tensor = tensor_descriptor->output_tensor();
  287. auto output_tensor_int = static_cast<int64_t>(output_tensor);
  288. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  289. "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int);
  290. auto device_type = tensor_descriptor->device_type();
  291. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  292. "input_desc_device_type:" + std::to_string(i), &device_type);
  293. auto input_tensor = tensor_descriptor->input_tensor();
  294. auto input_tensor_int = static_cast<int64_t>(input_tensor);
  295. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  296. "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int);
  297. auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
  298. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  299. "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
  300. auto data_offset = tensor_descriptor->data_offset();
  301. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  302. "input_desc_data_offset:" + std::to_string(i), &data_offset);
  303. auto cmps_size = tensor_descriptor->cmps_size();
  304. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i),
  305. &cmps_size);
  306. auto cmps_tab = tensor_descriptor->cmps_tab();
  307. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  308. "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab);
  309. auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset();
  310. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  311. "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset);
  312. const auto &tensor_desc_map = tensor_descriptor->attr();
  313. std::string suffix = ":" + std::to_string(i);
  314. AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix);
  315. } else {
  316. GELOGW("Tensor descriptor is nullptr");
  317. continue;
  318. }
  319. } else {
  320. GELOGW("Input desc is nullptr");
  321. continue;
  322. }
  323. }
  324. }
  325. // Output describes
  326. auto size_out = op_desc->GetOutputsSize();
  327. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out);
  328. if (size_out > 0) {
  329. for (uint32_t i = 0; i < size_out; i++) {
  330. auto output_desc = op_desc->GetOutputDescPtr(i);
  331. if (output_desc != nullptr) {
  332. auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType());
  333. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i),
  334. &data_type);
  335. auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType());
  336. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  337. "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type);
  338. auto dims = output_desc->GetShape().GetDims();
  339. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i),
  340. &dims);
  341. auto dims_origin = output_desc->GetOriginShape().GetDims();
  342. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
  343. "output_desc_origin_shape:" + std::to_string(i), &dims_origin);
  344. auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat());
  345. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i),
  346. &layout);
  347. auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat());
  348. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  349. "output_desc_origin_layout:" + std::to_string(i), &layout_origin);
  350. auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
  351. if (tensor_descriptor != nullptr) {
  352. auto size = tensor_descriptor->size();
  353. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i),
  354. &size);
  355. auto weight_size = tensor_descriptor->weight_size();
  356. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  357. "output_desc_weight_size:" + std::to_string(i), &weight_size);
  358. auto device_type = tensor_descriptor->device_type();
  359. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  360. "output_desc_device_type:" + std::to_string(i), &device_type);
  361. auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
  362. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT,
  363. "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt);
  364. const auto &tensor_desc_map = tensor_descriptor->attr();
  365. std::string suffix = ":" + std::to_string(i);
  366. AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix);
  367. } else {
  368. GELOGW("Tensor descriptor is nullptr");
  369. continue;
  370. }
  371. } else {
  372. GELOGW("Output desc is nullptr");
  373. continue;
  374. }
  375. }
  376. }
  377. }
  378. void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
  379. const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto,
  380. const std::string &prefix, const std::string &suffix) {
  381. for (const auto &item : attr_map) {
  382. auto attr_name = item.first;
  383. auto attr_def = item.second;
  384. auto attr_type = attr_def.value_case();
  385. if (attr_type == ge::proto::AttrDef::kT) {
  386. const auto &tensor_def = attr_def.t();
  387. const auto &tensor_desc = tensor_def.desc();
  388. auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
  389. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix,
  390. &data_type);
  391. auto dims = tensor_desc.shape().dim();
  392. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix,
  393. dims);
  394. auto layout = tensor_desc.layout();
  395. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix,
  396. &layout);
  397. auto device_type = tensor_desc.device_type();
  398. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
  399. prefix + attr_name + "_desc_device_type" + suffix, &device_type);
  400. if (kDumpLevel == DUMP_ALL) {
  401. auto data = tensor_def.data();
  402. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix,
  403. &data);
  404. }
  405. }
  406. if (attr_type == ge::proto::AttrDef::kS) {
  407. if (kDumpLevel == DUMP_ALL) {
  408. auto str_value = attr_def.s();
  409. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value);
  410. }
  411. }
  412. if (attr_type == ge::proto::AttrDef::kI) {
  413. auto int_value = attr_def.i();
  414. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
  415. }
  416. if (attr_type == ge::proto::AttrDef::kF) {
  417. auto float_value = attr_def.f();
  418. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value);
  419. }
  420. if (attr_type == ge::proto::AttrDef::kB) {
  421. auto int_value = static_cast<int64_t>(attr_def.b());
  422. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
  423. }
  424. if (attr_type == ge::proto::AttrDef::kList) {
  425. const auto &list_value = attr_def.list();
  426. auto list_value_type = list_value.val_type();
  427. if (list_value_type ==
  428. ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) {
  429. if (kDumpLevel == DUMP_ALL) {
  430. const auto &strings = list_value.s();
  431. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings);
  432. }
  433. }
  434. if (list_value_type ==
  435. ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) {
  436. const auto &floats = list_value.f();
  437. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats);
  438. }
  439. if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) {
  440. const auto &ints = list_value.i();
  441. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints);
  442. }
  443. if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) {
  444. const auto &bools = list_value.b();
  445. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools);
  446. }
  447. }
  448. }
  449. }
  450. void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) {
  451. if (node == nullptr) {
  452. GELOGE(GRAPH_FAILED, "node is nullptr");
  453. return;
  454. }
  455. // 1.Attributes added from node's methods
  456. auto send_list = node->send_event_id_list_;
  457. if (!send_list.empty()) {
  458. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list);
  459. }
  460. auto recv_list = node->recv_event_id_list_;
  461. if (!recv_list.empty()) {
  462. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list);
  463. }
  464. auto op_desc = node->op_;
  465. if (op_desc != nullptr) {
  466. // for input_name_idx_ in opdesc
  467. auto input_name_2_indexs = op_desc->GetAllInputName();
  468. ::google::protobuf::RepeatedPtrField<::std::string> input_names;
  469. ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes;
  470. for (const auto &input_name_2_index : input_name_2_indexs) {
  471. std::string input_name = input_name_2_index.first;
  472. input_names.Add(std::move(input_name));
  473. input_indexes.Add(input_name_2_index.second);
  474. }
  475. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names);
  476. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes);
  477. // 2.Attributes added from node's op_(message OpDef)
  478. // Input and out describes
  479. AddAttrProtoForOpInAndOutDesc(node_proto, op_desc);
  480. // Others
  481. auto op_def = op_desc->op_def_.GetProtoMsg();
  482. if (op_def != nullptr) {
  483. auto id = op_def->id();
  484. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id);
  485. auto stream_id = op_def->stream_id();
  486. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id);
  487. const auto &input_name = op_def->input_name();
  488. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name);
  489. const auto &src_name = op_def->src_name();
  490. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name);
  491. const auto &src_index = op_def->src_index();
  492. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index);
  493. const auto &dst_name = op_def->dst_name();
  494. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name);
  495. const auto &dst_index = op_def->dst_index();
  496. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index);
  497. const auto &input_i = op_def->input_i();
  498. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i);
  499. const auto &output_i = op_def->output_i();
  500. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i);
  501. const auto &workspace = op_def->workspace();
  502. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace);
  503. const auto &workspace_bytes = op_def->workspace_bytes();
  504. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
  505. const auto &is_input_const = op_def->is_input_const();
  506. AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
  507. const auto &op_def_attr_map = op_def->attr();
  508. AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
  509. } else {
  510. GELOGE(FAILED, "Opdef is nullptr");
  511. return;
  512. }
  513. } else {
  514. GELOGE(FAILED, "Opdesc is nullptr");
  515. return;
  516. }
  517. }
  518. bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) {
  519. if ((node == nullptr) || (node_proto == nullptr)) {
  520. GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid");
  521. return false;
  522. }
  523. // 2.Encode map<string, GeAttrValue> attrs_ to AttributeProto
  524. for (auto &node_attr : node->attrs_) {
  525. AddAttrProtoFromAttribute(node_attr, node_proto);
  526. }
  527. // 3.Encode ge::Node members to AttributeProto
  528. AddAttrProtoFromNodeMembers(node, node_proto);
  529. return true;
  530. }
  531. void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) {
  532. if ((node == nullptr) || (node_proto == nullptr)) {
  533. GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid");
  534. return;
  535. }
  536. const auto &node_name = node->GetName();
  537. for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  538. if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) {
  539. node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx()));
  540. }
  541. }
  542. auto out_control_anchor = node->GetOutControlAnchor();
  543. if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) {
  544. node_proto->add_output(node_name + kControlAnchorIndex);
  545. }
  546. }
  547. bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) {
  548. if ((node == nullptr) || (node_proto == nullptr)) {
  549. GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid");
  550. return false;
  551. }
  552. node_proto->clear_input();
  553. // 1. Add input by in data edge
  554. for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
  555. auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  556. if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
  557. node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
  558. std::to_string(peer_out_anchor->GetIdx()));
  559. } else {
  560. // Add "" input
  561. node_proto->add_input("");
  562. }
  563. }
  564. // 2. Add input by in control edge
  565. auto in_control_anchor = node->GetInControlAnchor();
  566. if (in_control_anchor != nullptr) {
  567. auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors();
  568. for (const auto &peer_out_anchor : peer_out_anchors) {
  569. if (peer_out_anchor->GetOwnerNode()) {
  570. node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex);
  571. }
  572. }
  573. } else {
  574. GELOGE(FAILED, "Incontrol anchor is nullptr");
  575. return false;
  576. }
  577. // 3. Add output for Netron visual support
  578. EncodeNodeLinkForNetronVisual(node, node_proto);
  579. return true;
  580. }
  581. bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) {
  582. if ((node == nullptr) || (node_proto == nullptr)) {
  583. GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid");
  584. return false;
  585. }
  586. // 1. Encode name and type
  587. node_proto->set_name(node->GetName());
  588. /// Netron believes that some operators, such as the activation operator of softplus, only have one input,
  589. /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix
  590. /// is added to correctly display the link relation at the expense of some color features
  591. node_proto->set_op_type("ge:" + node->GetType());
  592. if (kDumpLevel != DUMP_WITH_OUT_DESC) {
  593. // 2.for attr
  594. if (!EncodeNodeDesc(node, node_proto)) {
  595. GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str());
  596. return false;
  597. }
  598. }
  599. // 3.for link info
  600. return EncodeNodeLink(node, node_proto);
  601. }
  602. void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) {
  603. if ((node == nullptr) || (tensor_type == nullptr)) {
  604. GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid");
  605. return;
  606. }
  607. const auto &op_desc = node->GetOpDesc();
  608. if (op_desc != nullptr) {
  609. uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize());
  610. if (size_out > 0) {
  611. for (uint32_t i = 0; i < size_out; i++) {
  612. const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i);
  613. if (ge_tensor != nullptr) {
  614. auto ge_data_type = ge_tensor->GetDataType();
  615. auto onnx_data_type = EncodeDataType(ge_data_type);
  616. tensor_type->set_elem_type(onnx_data_type);
  617. onnx::TensorShapeProto *shape = tensor_type->mutable_shape();
  618. if (shape != nullptr) {
  619. for (auto d : ge_tensor->GetShape().GetDims()) {
  620. auto dim = shape->add_dim();
  621. dim->set_dim_value(d);
  622. }
  623. } else {
  624. GELOGW("Shape is nullptr");
  625. continue;
  626. }
  627. } else {
  628. GELOGW("Ge tensor is nullptr");
  629. continue;
  630. }
  631. }
  632. }
  633. } else {
  634. GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str());
  635. return;
  636. }
  637. }
  638. void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) {
  639. if ((node == nullptr) || (value_info_proto == nullptr)) {
  640. GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid");
  641. return;
  642. }
  643. value_info_proto->set_name(node->GetName());
  644. onnx::TypeProto *t = value_info_proto->mutable_type();
  645. onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type();
  646. EncodeTypeProtoTensorType(node, tensor_type);
  647. }
  648. bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) {
  649. if ((graph == nullptr) || (graph_proto == nullptr)) {
  650. GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid");
  651. return false;
  652. }
  653. graph_proto->set_name(graph->GetName());
  654. // 1. Add graph inputs
  655. for (const auto &input : graph->GetInputNodes()) {
  656. auto value_info_proto = graph_proto->add_input();
  657. EncodeValueInfo(input, value_info_proto);
  658. }
  659. // 2. Add graph outputs
  660. for (const auto &output : graph->GetOutputNodes()) {
  661. auto value_info_proto = graph_proto->add_output();
  662. EncodeValueInfo(output, value_info_proto);
  663. }
  664. // 3. Add nodes
  665. for (const auto &node : graph->GetDirectNode()) {
  666. if (!EncodeNode(node, graph_proto->add_node())) {
  667. GELOGW("EncodeNode failed");
  668. continue;
  669. }
  670. }
  671. return true;
  672. }
  673. bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) {
  674. model_proto.set_model_version(model.GetVersion());
  675. model_proto.set_ir_version(onnx::IR_VERSION);
  676. model_proto.set_producer_name(model.GetName());
  677. auto &graph = model.graph_;
  678. auto compute_graph = GraphUtils::GetComputeGraph(graph);
  679. if (compute_graph == nullptr) {
  680. GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr");
  681. return false;
  682. }
  683. auto graph_proto = model_proto.mutable_graph();
  684. if (graph_proto == nullptr) {
  685. GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str());
  686. return false;
  687. }
  688. if (!EncodeGraph(compute_graph, graph_proto)) {
  689. GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str());
  690. return false;
  691. }
  692. // For subgraphs: a subgraph is represented by a node
  693. for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) {
  694. if (sub_compute_graph != nullptr) {
  695. auto node_proto = graph_proto->add_node();
  696. if (node_proto == nullptr) {
  697. GELOGW("Node proto is nullptr");
  698. continue;
  699. }
  700. node_proto->set_name(sub_compute_graph->GetName());
  701. node_proto->set_op_type(kNodeTypeForSubgraph);
  702. auto attr = node_proto->add_attribute();
  703. attr->set_name("graph");
  704. attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
  705. auto sub_graph_proto = attr->mutable_g();
  706. if (sub_graph_proto == nullptr) {
  707. GELOGW("Sub graph proto is nullptr");
  708. continue;
  709. }
  710. if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) {
  711. GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str());
  712. continue;
  713. }
  714. } else {
  715. GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str());
  716. continue;
  717. }
  718. }
  719. return true;
  720. }
721. // Part 2: convert ONNX protobuf back to GE IR
  722. static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = {
  723. {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64},
  724. {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32},
  725. {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8},
  726. {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16},
  727. {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16},
  728. {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL},
  729. };
  730. ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) {
  731. auto it = onnxDataTypeToGeMap.find(data_type);
  732. if (it != onnxDataTypeToGeMap.end()) {
  733. return it->second;
  734. } else {
  735. GELOGW("DecodeDataType: datatype not support %u", data_type);
  736. return ge::DT_UNDEFINED;
  737. }
  738. }
  739. bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) {
  740. auto sep = node_name_index.rfind(':');
  741. if (sep == std::string::npos) {
  742. return false;
  743. }
  744. node_name = node_name_index.substr(0, sep);
  745. auto index_str = node_name_index.substr(sep + 1);
  746. index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10));
  747. return true;
  748. }
  749. bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) {
  750. if (node_ptr == nullptr) {
  751. GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr");
  752. return false;
  753. }
  754. // Data edge
  755. if (item.src_out_index >= 0) {
  756. auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index);
  757. auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
  758. if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
  759. GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  760. item.dst_node_name.c_str(), item.dst_in_index);
  761. return false;
  762. }
  763. if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
  764. GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed");
  765. return false;
  766. }
  767. // Control edge
  768. } else {
  769. auto src_anchor = node_ptr->GetOutControlAnchor();
  770. auto dst_anchor = item.dst_node->GetInControlAnchor();
  771. if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
  772. GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
  773. item.dst_node_name.c_str(), item.dst_in_index);
  774. return false;
  775. }
  776. if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
  777. GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed");
  778. return false;
  779. }
  780. }
  781. return true;
  782. }
  783. bool OnnxUtils::DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
  784. const std::map<std::string, NodePtr> &node_map) {
  785. for (const auto &node_proto : node_proto_vector) {
  786. const auto &node_name = node_proto.name();
  787. auto dst_node = node_map.find(node_name);
  788. if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) {
  789. GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str());
  790. return false;
  791. }
  792. int32_t dst_index = 0;
  793. for (const auto &input : node_proto.input()) {
  794. std::string input_node_name;
  795. int32_t index = 0;
  796. if (ParseNameIndex(input, input_node_name, index)) {
  797. auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()};
  798. auto src_node = node_map.find(input_node_name);
  799. if (src_node == node_map.end()) {
  800. GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str());
  801. return false;
  802. }
  803. auto node_ptr = src_node->second;
  804. if (node_ptr == nullptr) {
  805. GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str());
  806. return false;
  807. }
  808. if (!DecodeNodeLinkImp(item, node_ptr)) {
  809. GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str());
  810. return false;
  811. }
  812. }
  813. if (index >= 0) {
  814. dst_index++;
  815. }
  816. }
  817. }
  818. return true;
  819. }
  820. void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings) {
  821. if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) {
  822. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  823. return;
  824. }
  825. for (int i = 0; i < attr_proto.strings_size(); i++) {
  826. strings.push_back(attr_proto.strings(i));
  827. }
  828. }
  829. void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) {
  830. if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRING) {
  831. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  832. return;
  833. }
  834. value = attr_proto.s();
  835. }
  836. void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints) {
  837. if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) {
  838. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  839. return;
  840. }
  841. for (int i = 0; i < attr_proto.ints_size(); i++) {
  842. ints.push_back(attr_proto.ints(i));
  843. }
  844. }
  845. void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) {
  846. if (attr_proto.type() != onnx::AttributeProto_AttributeType_INT) {
  847. GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
  848. return;
  849. }
  850. value = attr_proto.i();
  851. }
  852. void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
  853. const std::string &attr_name_for_input_desc, int32_t index,
  854. OpDescPtr &op_desc) {
  855. if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) {
  856. GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr",
  857. op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
  858. return;
  859. }
  860. if (attr_name_for_input_desc == "input_desc_dtype") {
  861. auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
  862. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
  863. } else if (attr_name_for_input_desc == "input_desc_shape") {
  864. std::vector<std::int64_t> ints;
  865. DecodeAttribute(attr_proto, ints);
  866. GeShape ge_shape(ints);
  867. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
  868. } else if (attr_name_for_input_desc == "input_desc_layout") {
  869. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  870. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
  871. } else if (attr_name_for_input_desc == "input_desc_origin_shape") {
  872. std::vector<std::int64_t> ints;
  873. DecodeAttribute(attr_proto, ints);
  874. GeShape ge_shape(ints);
  875. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
  876. } else if (attr_name_for_input_desc == "input_desc_origin_layout") {
  877. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  878. op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
  879. } else if (attr_name_for_input_desc == "input_desc_size") {
  880. int64_t input_size = 0;
  881. auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  882. DecodeAttribute(attr_proto, input_size);
  883. tensor_descriptor->set_size(input_size);
  884. } else if (attr_name_for_input_desc == "input_desc_data_offset") {
  885. auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  886. int64_t offset = 0;
  887. DecodeAttribute(attr_proto, offset);
  888. tensor_descriptor->set_data_offset(offset);
  889. } else {
  890. return;
  891. }
  892. }
  893. void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
  894. const std::string &attr_name_for_output_desc, int32_t index,
  895. OpDescPtr &op_desc) {
  896. if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) {
  897. GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr",
  898. op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
  899. return;
  900. }
  901. if (attr_name_for_output_desc == "output_desc_dtype") {
  902. auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
  903. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type);
  904. } else if (attr_name_for_output_desc == "output_desc_shape") {
  905. std::vector<std::int64_t> ints;
  906. DecodeAttribute(attr_proto, ints);
  907. GeShape ge_shape(ints);
  908. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape);
  909. } else if (attr_name_for_output_desc == "output_desc_layout") {
  910. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  911. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format);
  912. } else if (attr_name_for_output_desc == "output_desc_origin_shape") {
  913. std::vector<std::int64_t> ints;
  914. DecodeAttribute(attr_proto, ints);
  915. GeShape ge_shape(ints);
  916. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape);
  917. } else if (attr_name_for_output_desc == "output_desc_origin_layout") {
  918. auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
  919. op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format);
  920. } else if (attr_name_for_output_desc == "output_desc_size") {
  921. int64_t output_size = 0;
  922. auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  923. DecodeAttribute(attr_proto, output_size);
  924. tensor_descriptor->set_size(output_size);
  925. } else if (attr_name_for_output_desc == "output_desc_data_offset") {
  926. auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg();
  927. int64_t offset = 0;
  928. DecodeAttribute(attr_proto, offset);
  929. tensor_descriptor->set_data_offset(offset);
  930. } else {
  931. return;
  932. }
  933. }
  934. void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
  935. const std::string &attr_name_for_input_output_desc, int32_t index,
  936. OpDescPtr &op_desc) {
  937. if (op_desc == nullptr) {
  938. GELOGE(GRAPH_FAILED, "op_desc is nullptr");
  939. return;
  940. }
  941. if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") {
  942. DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  943. } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") {
  944. DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  945. } else {
  946. return;
  947. }
  948. }
  949. void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) {
  950. auto attr_map = op_def.mutable_attr();
  951. const auto &attr_name = attr_proto.name();
  952. ge::proto::AttrDef op_attr;
  953. int64_t value = 0;
  954. DecodeAttribute(attr_proto, value);
  955. op_attr.set_i(value);
  956. attr_map->insert(AttrDefPair(attr_name, op_attr));
  957. }
  958. void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) {
  959. if (op_desc == nullptr) {
  960. GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr");
  961. return;
  962. }
  963. const auto &attr_name = attr_proto.name();
  964. std::string attr_name_for_input_output_desc;
  965. int32_t index = 0;
  966. if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) {
  967. if (attr_name == "id") {
  968. op_desc->SetId(attr_proto.i());
  969. } else if (attr_name == "stream_id") {
  970. op_desc->SetStreamId(attr_proto.i());
  971. } else if (attr_name == "src_name") {
  972. std::vector<std::string> strings;
  973. DecodeAttribute(attr_proto, strings);
  974. op_desc->SetSrcName(strings);
  975. } else if (attr_name == "dst_name") {
  976. std::vector<std::string> strings;
  977. DecodeAttribute(attr_proto, strings);
  978. op_desc->SetDstName(strings);
  979. } else if (attr_name == "src_index") {
  980. std::vector<std::int64_t> ints;
  981. DecodeAttribute(attr_proto, ints);
  982. op_desc->SetSrcIndex(ints);
  983. } else if (attr_name == "dst_index") {
  984. std::vector<std::int64_t> ints;
  985. DecodeAttribute(attr_proto, ints);
  986. op_desc->SetDstIndex(ints);
  987. } else if (attr_name == "fusion_scope") {
  988. DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg());
  989. } else if (attr_name == "input_i") {
  990. std::vector<std::int64_t> ints;
  991. DecodeAttribute(attr_proto, ints);
  992. op_desc->SetInputOffset(ints);
  993. } else if (attr_name == "output_i") {
  994. std::vector<std::int64_t> ints;
  995. DecodeAttribute(attr_proto, ints);
  996. op_desc->SetOutputOffset(ints);
  997. } else {
  998. return;
  999. }
  1000. // Update input and output desc
  1001. } else {
  1002. DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  1003. }
  1004. }
  1005. bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) {
  1006. if (op_desc == nullptr || node_proto == nullptr) {
  1007. GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr");
  1008. return false;
  1009. }
  1010. // 1. Decode node_proto name and type
  1011. op_desc->SetName(node_proto->name());
  1012. const auto &node_type_with_ge_prefix = node_proto->op_type();
  1013. auto sep = node_type_with_ge_prefix.find(':');
  1014. if (sep == std::string::npos) {
  1015. return false;
  1016. }
  1017. auto node_type = node_type_with_ge_prefix.substr(sep + 1);
  1018. op_desc->SetType(node_type);
  1019. // 2. Add empty input and output desc
  1020. for (const auto &attr : node_proto->attribute()) {
  1021. if (attr.name() == "input_desc_nums") {
  1022. auto size_in = attr.i();
  1023. for (int64_t i = 0; i < size_in; i++) {
  1024. GeTensorDesc ge_tensor_desc;
  1025. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
  1026. }
  1027. }
  1028. if (attr.name() == "output_desc_nums") {
  1029. auto size_out = attr.i();
  1030. for (int64_t i = 0; i < size_out; i++) {
  1031. GeTensorDesc ge_tensor_desc;
  1032. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
  1033. }
  1034. }
  1035. }
  1036. // 3.Decode node_proto attributes
  1037. for (int i = 0; i < node_proto->attribute_size(); i++) {
  1038. DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc);
  1039. }
  1040. return true;
  1041. }
  1042. bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) {
  1043. if (recursion_depth > kMaxRecursionDepth) {
  1044. GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort");
  1045. return false;
  1046. }
  1047. graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name());
  1048. GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
  1049. /// 1. Decode all nodes first, node should include input
  1050. /// and output nodes and nodes which represent sub graphs
  1051. std::map<std::string, NodePtr> node_map;
  1052. std::vector<onnx::NodeProto> node_proto_vector;
  1053. for (const auto &node_proto : graph_proto.node()) {
  1054. // a. nodes represent sub graphs
  1055. if (node_proto.op_type() == kNodeTypeForSubgraph) {
  1056. ComputeGraphPtr compute_graph;
  1057. // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH
  1058. const auto &node_attr = node_proto.attribute(0);
  1059. if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) &&
  1060. DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) {
  1061. (void)graph->AddSubGraph(compute_graph);
  1062. } else {
  1063. GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(),
  1064. node_attr.type());
  1065. return false;
  1066. }
  1067. // b. direct nodes in graph
  1068. } else {
  1069. node_proto_vector.push_back(node_proto);
  1070. OpDescPtr op_desc = ComGraphMakeShared<OpDesc>();
  1071. // b.1 For node desc
  1072. if (!DecodeNodeDesc(&node_proto, op_desc)) {
  1073. GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str());
  1074. return false;
  1075. }
  1076. auto node = graph->AddNode(op_desc);
  1077. node_map.insert(std::make_pair(node_proto.name(), node));
  1078. }
  1079. }
  1080. /// We get all nodes in graph here
  1081. /// b.2 For node link
  1082. if (!DecodeNodeLink(node_proto_vector, node_map)) {
  1083. GELOGE(GRAPH_FAILED, "Decode node link failed");
  1084. return false;
  1085. }
  1086. // 2. Add inputs nodes for graph
  1087. for (const auto &input : graph_proto.input()) {
  1088. const auto &input_node_name = input.name();
  1089. auto input_node_item = node_map.find(input_node_name);
  1090. if (input_node_item == node_map.end()) {
  1091. GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str());
  1092. return false;
  1093. }
  1094. auto ret = graph->AddInputNode(input_node_item->second);
  1095. GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed");
  1096. }
  1097. // 3. Add outputs nodes for graph
  1098. for (const auto &output : graph_proto.output()) {
  1099. const auto &output_node_name = output.name();
  1100. auto output_node_item = node_map.find(output_node_name);
  1101. if (output_node_item == node_map.end()) {
  1102. GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str());
  1103. return false;
  1104. }
  1105. auto ret = graph->AddOutputNode(output_node_item->second);
  1106. if (ret == nullptr) {
  1107. GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str());
  1108. continue;
  1109. }
  1110. }
  1111. return true;
  1112. }
  1113. bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) {
  1114. model.name_ = model_proto.producer_name();
  1115. model.version_ = static_cast<uint32_t>(model_proto.model_version());
  1116. auto &graph_proto = model_proto.graph();
  1117. ComputeGraphPtr compute_graph;
  1118. // 0 means recursion depth, father call
  1119. if (!DecodeGraph(0, graph_proto, compute_graph)) {
  1120. GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed");
  1121. return false;
  1122. }
  1123. model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  1124. return true;
  1125. }
  1126. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细架构如下图所示。