You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_parser_util.cc 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph_parser_util.h"
  17. #include <memory>
  18. #include "common/auth/file_saver.h"
  19. #include "common/convert/pb2json.h"
  20. #include "common/debug/log.h"
  21. #include "common/debug/memory_dumper.h"
  22. #include "common/model_parser/base.h"
  23. #include "common/model_saver.h"
  24. #include "common/properties_manager.h"
  25. #include "common/string_util.h"
  26. #include "common/types.h"
  27. #include "common/util.h"
  28. #include "common/util/error_manager/error_manager.h"
  29. #include "external/register/register_types.h"
  30. #include "framework/common/debug/ge_log.h"
  31. #include "framework/omg/parser/parser_inner_ctx.h"
  32. #include "graph/compute_graph.h"
  33. #include "graph/debug/ge_attr_define.h"
  34. #include "graph/debug/ge_attr_define.h"
  35. #include "graph/optimize/common/params.h"
  36. #include "graph/utils/type_utils.h"
  37. #include "omg/omg_inner_types.h"
  38. #include "omg/parser/model_parser.h"
  39. #include "omg/parser/parser_factory.h"
  40. #include "omg/parser/weights_parser.h"
  41. #include "parser/common/pre_checker.h"
  42. #include "proto/ge_ir.pb.h"
  43. #include "register/op_registry.h"
  44. namespace ge {
  45. namespace {
  46. // The function is incomplete. Currently, only l2_optimize, off_optimize is supported.
  47. const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\"";
  48. const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\"";
  49. const char *const kSplitError1 = "size not equal to 2 split by \":\"";
  50. const char *const kEmptyError = "can not be empty";
  51. const char *const kFloatNumError = "exist float number";
  52. const char *const kDigitError = "is not digit";
  53. const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\"";
  54. const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8";
  55. const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes.";
  56. vector<string> SplitInputShape(const std::string &input_shape) {
  57. vector<string> shape_pair_vec;
  58. size_t pos = input_shape.rfind(":");
  59. if (pos != std::string::npos) {
  60. shape_pair_vec.emplace_back(input_shape.substr(0, pos));
  61. shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos));
  62. }
  63. return shape_pair_vec;
  64. }
  65. static std::map<std::string, ge::DataType> output_type_str_to_datatype = {
  66. {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}};
  67. static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) {
  68. if ((s == "true") || (s == "false")) {
  69. return true;
  70. } else {
  71. ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s});
  72. GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str());
  73. return false;
  74. }
  75. }
  76. bool CheckDigitStr(std::string &str) {
  77. for (char c : str) {
  78. if (!isdigit(c)) {
  79. GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str());
  80. return false;
  81. }
  82. }
  83. return true;
  84. }
  85. Status StringToInt(std::string &str, int32_t &value) {
  86. try {
  87. if (!CheckDigitStr(str)) {
  88. GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str());
  89. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  90. {"--output_type", str, "is not positive integer"});
  91. return PARAM_INVALID;
  92. }
  93. value = stoi(str);
  94. } catch (std::invalid_argument &) {
  95. GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str());
  96. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str});
  97. return PARAM_INVALID;
  98. } catch (std::out_of_range &) {
  99. GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str());
  100. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str});
  101. return PARAM_INVALID;
  102. }
  103. return SUCCESS;
  104. }
  105. Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) {
  106. std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes;
  107. std::set<std::string> out_nodes_info;
  108. for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
  109. // out_nodes set should include output_type and output_format
  110. std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second);
  111. out_nodes_info.emplace(tmp);
  112. }
  113. for (uint32_t i = 0; i < out_type_vec.size(); ++i) {
  114. if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) {
  115. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  116. {"--output_type", out_type_vec[i], kOutputTypeError});
  117. GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError);
  118. return domi::FAILED;
  119. }
  120. }
  121. return domi::SUCCESS;
  122. }
  123. Status ParseOutputType(const std::string &output_type, std::map<std::string, vector<uint32_t>> &out_type_index_map,
  124. std::map<std::string, vector<ge::DataType>> &out_type_dt_map) {
  125. if (output_type.find(':') == std::string::npos) {
  126. GELOGI("output_type is not multiple nodes, means all out nodes");
  127. auto it = output_type_str_to_datatype.find(output_type);
  128. if (it == output_type_str_to_datatype.end()) {
  129. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  130. {"--output_type", output_type, kOutputTypeSupport});
  131. GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport);
  132. return domi::FAILED;
  133. }
  134. return domi::SUCCESS;
  135. }
  136. std::vector<std::string> out_type_vec;
  137. vector<string> nodes_v = StringUtils::Split(output_type, ';');
  138. for (const string &node : nodes_v) {
  139. vector<string> node_index_type_v = StringUtils::Split(node, ':');
  140. if (node_index_type_v.size() != 3) { // The size must be 3.
  141. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  142. {"--output_type", node, kOutputTypeSample});
  143. GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample);
  144. return domi::FAILED;
  145. }
  146. ge::DataType tmp_dt;
  147. std::string node_name = StringUtils::Trim(node_index_type_v[0]);
  148. std::string index_str = StringUtils::Trim(node_index_type_v[1]);
  149. int32_t index;
  150. if (StringToInt(index_str, index) != SUCCESS) {
  151. GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str());
  152. return domi::FAILED;
  153. }
  154. std::string dt_value = StringUtils::Trim(node_index_type_v[2]);
  155. auto it = output_type_str_to_datatype.find(dt_value);
  156. if (it == output_type_str_to_datatype.end()) {
  157. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  158. {"--output_type", dt_value, kOutputTypeSupport});
  159. GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport);
  160. return domi::FAILED;
  161. } else {
  162. tmp_dt = it->second;
  163. }
  164. out_type_vec.push_back(node_name + ":" + index_str);
  165. auto it_index = out_type_index_map.find(node_name);
  166. if (it_index == out_type_index_map.end()) {
  167. vector<uint32_t> tmp_vec;
  168. tmp_vec.push_back(index);
  169. out_type_index_map.emplace(node_name, tmp_vec);
  170. } else {
  171. it_index->second.push_back(index);
  172. }
  173. auto it_dt = out_type_dt_map.find(node_name);
  174. if (it_dt == out_type_dt_map.end()) {
  175. vector<ge::DataType> tmp_vec;
  176. tmp_vec.push_back(tmp_dt);
  177. out_type_dt_map.emplace(node_name, tmp_vec);
  178. } else {
  179. it_dt->second.push_back(tmp_dt);
  180. }
  181. }
  182. return VerifyOutputTypeAndOutNodes(out_type_vec);
  183. }
  184. Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) {
  185. int32_t out_size = op_desc->GetOutputsSize();
  186. if (index < 0 || index >= out_size) {
  187. GELOGE(domi::FAILED,
  188. "out_node [%s] output index:%d must be smaller "
  189. "than node output size:%d and can not be negative!",
  190. op_desc->GetName().c_str(), index, out_size);
  191. std::string fail_reason = "output index:" + to_string(index) +
  192. " must be smaller than output size:" + to_string(out_size) + " and can not be negative!";
  193. ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"},
  194. {"out_nodes", op_desc->GetName(), fail_reason});
  195. return domi::FAILED;
  196. }
  197. return domi::SUCCESS;
  198. }
  199. Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) {
  200. ge::OpDescPtr tmpDescPtr = node->GetOpDesc();
  201. if (tmpDescPtr == nullptr) {
  202. GELOGE(domi::FAILED, "Get outnode op desc fail.");
  203. return domi::FAILED;
  204. }
  205. size_t size = tmpDescPtr->GetOutputsSize();
  206. if (node->GetType() != NETOUTPUT) {
  207. for (size_t index = 0; index < size; ++index) {
  208. output_nodes_info.push_back(std::make_pair(node, index));
  209. }
  210. } else {
  211. const auto in_anchors = node->GetAllInDataAnchors();
  212. for (auto in_anchor : in_anchors) {
  213. auto out_anchor = in_anchor->GetPeerOutAnchor();
  214. if (out_anchor == nullptr) {
  215. GELOGE(domi::FAILED, "Get leaf node op desc fail.");
  216. return domi::FAILED;
  217. }
  218. auto out_node = out_anchor->GetOwnerNode();
  219. output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx()));
  220. }
  221. }
  222. return SUCCESS;
  223. }
  224. void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
  225. std::vector<std::string> &output_nodes_name) {
  226. output_nodes_name.clear();
  227. if (domi::GetContext().out_top_names.empty()) {
  228. // tf process, no top name.
  229. for (const auto output_node_info : output_nodes_info) {
  230. std::string node_name = output_node_info.first->GetName();
  231. int32_t index = output_node_info.second;
  232. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  233. }
  234. return;
  235. }
  236. // caffe process, need add top name after node_name:index
  237. for (size_t i = 0; i < output_nodes_info.size(); ++i) {
  238. std::string node_name = output_nodes_info[i].first->GetName();
  239. int32_t index = output_nodes_info[i].second;
  240. if (i < domi::GetContext().out_top_names.size()) {
  241. output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]);
  242. } else {
  243. GELOGW("Get top name of node [%s] fail.", node_name.c_str());
  244. output_nodes_name.push_back(node_name + ":" + std::to_string(index));
  245. }
  246. }
  247. }
  248. } // namespace
  249. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputFp16NodesFormat(const string &is_output_fp16) {
  250. if (is_output_fp16.empty()) {
  251. return SUCCESS;
  252. }
  253. vector<domiTensorFormat_t> &output_formats = domi::GetContext().output_formats;
  254. output_formats.clear();
  255. vector<string> node_format_vec = StringUtils::Split(is_output_fp16, ',');
  256. for (auto &is_fp16 : node_format_vec) {
  257. StringUtils::Trim(is_fp16);
  258. if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) {
  259. GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]",
  260. is_output_fp16.c_str());
  261. return PARAM_INVALID;
  262. }
  263. if (is_fp16 == "false") {
  264. output_formats.push_back(DOMI_TENSOR_ND);
  265. } else if (is_fp16 == "true") {
  266. output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0);
  267. }
  268. }
  269. return SUCCESS;
  270. }
  271. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph,
  272. const std::string &output_type,
  273. const std::string &output) {
  274. ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
  275. GE_CHECK_NOTNULL(compute_graph);
  276. std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes;
  277. std::vector<domiTensorFormat_t> output_formats = domi::GetContext().output_formats;
  278. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info;
  279. std::vector<std::string> output_nodes_name;
  280. std::map<std::string, vector<uint32_t>> out_type_index_map;
  281. std::map<std::string, vector<ge::DataType>> out_type_dt_map;
  282. if (!output_type.empty()) {
  283. if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) {
  284. GELOGE(domi::FAILED, "Parse output_type failed.");
  285. return domi::FAILED;
  286. }
  287. }
  288. // User declared outputs
  289. for (uint32_t i = 0; i < user_out_nodes.size(); ++i) {
  290. ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first);
  291. if (out_node == nullptr) {
  292. GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str());
  293. return domi::FAILED;
  294. }
  295. auto op_desc = out_node->GetOpDesc();
  296. GE_CHECK_NOTNULL(op_desc);
  297. if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) {
  298. GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str());
  299. return domi::FAILED;
  300. }
  301. if (i < output_formats.size()) {
  302. if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) {
  303. GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str());
  304. if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) {
  305. GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str());
  306. }
  307. }
  308. }
  309. auto it_index = out_type_index_map.find(user_out_nodes[i].first);
  310. auto it_dt = out_type_dt_map.find(user_out_nodes[i].first);
  311. if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) {
  312. GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str());
  313. (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second);
  314. (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second);
  315. }
  316. output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second));
  317. }
  318. // default output node (leaf)
  319. if (user_out_nodes.empty()) {
  320. for (ge::NodePtr node : compute_graph->GetDirectNode()) {
  321. if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) {
  322. Status ret = GetOutputLeaf(node, output_nodes_info);
  323. GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail.");
  324. }
  325. }
  326. }
  327. GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name);
  328. compute_graph->SetGraphOutNodesInfo(output_nodes_info);
  329. domi::GetContext().net_out_nodes = output_nodes_name;
  330. return domi::SUCCESS;
  331. }
  332. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ParseInputShape(
  333. const string &input_shape, unordered_map<string, vector<int64_t>> &shape_map,
  334. vector<pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input) {
  335. vector<string> shape_vec = StringUtils::Split(input_shape, ';');
  336. const int DEFAULT_SHAPE_PAIR_SIZE = 2;
  337. for (const auto &shape : shape_vec) {
  338. vector<string> shape_pair_vec = SplitInputShape(shape);
  339. if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) {
  340. ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
  341. {shape, kSplitError1, kInputShapeSample1});
  342. GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
  343. shape.c_str(), kSplitError1, kInputShapeSample1);
  344. return false;
  345. }
  346. if (shape_pair_vec[1].empty()) {
  347. ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
  348. {shape, kEmptyError, kInputShapeSample1});
  349. GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
  350. shape.c_str(), kEmptyError, kInputShapeSample1);
  351. return false;
  352. }
  353. vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ',');
  354. vector<int64_t> shape_values;
  355. for (auto &shape_value_str : shape_value_strs) {
  356. // stoul: The method may throw an exception: invalid_argument/out_of_range
  357. if (std::string::npos != shape_value_str.find('.')) {
  358. ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
  359. {shape, kFloatNumError, kInputShapeSample2});
  360. GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.",
  361. shape.c_str(), kFloatNumError, kInputShapeSample2);
  362. return false;
  363. }
  364. long left_result = 0;
  365. try {
  366. left_result = stol(StringUtils::Trim(shape_value_str));
  367. if (!shape_value_str.empty() && (shape_value_str.front() == '-')) {
  368. // The value maybe dynamic shape [-1], need substr it and verify isdigit.
  369. shape_value_str = shape_value_str.substr(1);
  370. }
  371. for (char c : shape_value_str) {
  372. if (!isdigit(c)) {
  373. ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"},
  374. {shape, kDigitError, kInputShapeSample2});
  375. GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str());
  376. return false;
  377. }
  378. }
  379. } catch (const std::out_of_range &) {
  380. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"},
  381. {"input_shape", shape_value_str});
  382. GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str());
  383. return false;
  384. } catch (const std::invalid_argument &) {
  385. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"},
  386. {"input_shape", shape_value_str});
  387. GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str());
  388. return false;
  389. } catch (...) {
  390. ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"},
  391. {"input_shape", shape_value_str});
  392. GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str());
  393. return false;
  394. }
  395. int64_t result = left_result;
  396. // - 1 is not currently supported
  397. if (!is_dynamic_input && result <= 0) {
  398. ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)});
  399. GELOGW(
  400. "Input parameter[--input_shape]’s shape value[%s] is invalid, "
  401. "expect positive integer, but value is %ld.",
  402. shape.c_str(), result);
  403. return false;
  404. }
  405. shape_values.push_back(result);
  406. }
  407. shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
  408. user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values));
  409. }
  410. return true;
  411. }
  412. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputNodes(const string &out_nodes) {
  413. try {
  414. // parse output node
  415. if (!out_nodes.empty()) {
  416. domi::GetContext().out_nodes_map.clear();
  417. domi::GetContext().user_out_nodes.clear();
  418. vector<string> nodes_v = StringUtils::Split(out_nodes, ';');
  419. for (const string &node : nodes_v) {
  420. vector<string> key_value_v = StringUtils::Split(node, ':');
  421. if (key_value_v.size() != 2) { // The size must be 2.
  422. ErrorManager::GetInstance().ATCReportErrMessage(
  423. "E10001", {"parameter", "value", "reason"},
  424. {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""});
  425. GELOGE(PARAM_INVALID,
  426. "The input format of --out_nodes is invalid, the correct format is "
  427. "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.",
  428. node.c_str());
  429. return PARAM_INVALID;
  430. }
  431. auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]);
  432. // stoi: The method may throw an exception: invalid_argument/out_of_range
  433. if (!CheckDigitStr(key_value_v[1])) {
  434. ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
  435. {"--out_nodes", out_nodes, "is not positive integer"});
  436. GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str());
  437. return PARAM_INVALID;
  438. }
  439. int32_t index = stoi(StringUtils::Trim(key_value_v[1]));
  440. if (iter != domi::GetContext().out_nodes_map.end()) {
  441. iter->second.emplace_back(index);
  442. } else {
  443. std::vector<int32_t> index_v;
  444. index_v.emplace_back(index);
  445. domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v);
  446. }
  447. domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index));
  448. }
  449. }
  450. } catch (std::invalid_argument &) {
  451. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  452. ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes});
  453. return PARAM_INVALID;
  454. } catch (std::out_of_range &) {
  455. GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str());
  456. ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes});
  457. return PARAM_INVALID;
  458. }
  459. return SUCCESS;
  460. }
  461. FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOpConf(const char *op_conf) {
  462. if (op_conf != nullptr && *op_conf != '\0') {
  463. // divided by ":"
  464. PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER);
  465. // Parsing the op_conf configuration item file
  466. if (!PropertiesManager::Instance().Init(op_conf)) {
  467. GELOGE(FAILED, "op_name_map init failed!");
  468. return FAILED;
  469. }
  470. // Return map and put it into ATC global variable
  471. domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap();
  472. }
  473. return SUCCESS;
  474. }
  475. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。