You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

single_op_parser.cc 26 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "offline/single_op_parser.h"
  #include <cctype>
  #include <algorithm>
  #include <fstream>
  #include <map>
  #include <sstream>
  #include <string>
  #include <utility>
  #include <vector>
  #include <nlohmann/json.hpp>
  #include "framework/common/debug/ge_log.h"
  #include "common/util/error_manager/error_manager.h"
  #include "framework/common/ge_inner_error_codes.h"
  #include "framework/common/util.h"
  #include "graph/utils/tensor_utils.h"
  #include "graph/utils/type_utils.h"
  #include "graph/utils/op_desc_utils.h"
  #include "graph/operator_factory_impl.h"
  using Json = nlohmann::json;
  using std::string;
  using std::vector;
  using std::map;
  34. namespace ge {
  35. namespace {
  36. constexpr char const *kKeyOp = "op";
  37. constexpr char const *kKeyInputDesc = "input_desc";
  38. constexpr char const *kKeyOutputDesc = "output_desc";
  39. constexpr char const *kKeyAttr = "attr";
  40. constexpr char const *kKeyName = "name";
  41. constexpr char const *kKeyType = "type";
  42. constexpr char const *kKeyShape = "shape";
  43. constexpr char const *kKeyOriginShape = "origin_shape";
  44. constexpr char const *kKeyShapeRange = "shape_range";
  45. constexpr char const *kKeyValue = "value";
  46. constexpr char const *kKeyFormat = "format";
  47. constexpr char const *kKeyOriginFormat = "origin_format";
  48. constexpr char const *kFileSuffix = ".om";
  49. constexpr char const *kKeyDynamicInput = "dynamic_input";
  50. constexpr char const *kKeyDynamicOutput = "dynamic_output";
  51. constexpr char const *kKeyCompileFlag = "compile_flag";
  52. constexpr int kDumpJsonIndent = 2;
  53. constexpr int kShapeRangePairSize = 2;
  54. constexpr int kShapeRangeLow = 0;
  55. constexpr int kShapeRangeHigh = 1;
  56. constexpr int kMaxFileNameLen = 128;
  57. map<string, GeAttrValue::ValueType> kAttrTypeDict = {
  58. {"bool", GeAttrValue::VT_BOOL},
  59. {"int", GeAttrValue::VT_INT},
  60. {"float", GeAttrValue::VT_FLOAT},
  61. {"string", GeAttrValue::VT_STRING},
  62. {"list_bool", GeAttrValue::VT_LIST_BOOL},
  63. {"list_int", GeAttrValue::VT_LIST_INT},
  64. {"list_float", GeAttrValue::VT_LIST_FLOAT},
  65. {"list_string", GeAttrValue::VT_LIST_STRING},
  66. {"list_list_int", GeAttrValue::VT_LIST_LIST_INT},
  67. {"data_type", GeAttrValue::VT_DATA_TYPE},
  68. };
  69. map<string, DataType> kDataTypeDict = {
  70. {"bool", DT_BOOL},
  71. {"int8", DT_INT8},
  72. {"uint8", DT_UINT8},
  73. {"int16", DT_INT16},
  74. {"uint16", DT_UINT16},
  75. {"int32", DT_INT32},
  76. {"uint32", DT_UINT32},
  77. {"int64", DT_INT64},
  78. {"uint64", DT_UINT64},
  79. {"float16", DT_FLOAT16},
  80. {"half", DT_FLOAT16},
  81. {"fp16", DT_FLOAT16},
  82. {"float", DT_FLOAT},
  83. {"float32", DT_FLOAT},
  84. {"double", DT_DOUBLE},
  85. {"complex64", DT_COMPLEX64},
  86. {"complex128", DT_COMPLEX128}
  87. };
  88. map<string, Format> kFormatDict = {
  89. {"nchw", FORMAT_NCHW},
  90. {"nhwc", FORMAT_NHWC},
  91. {"nd", FORMAT_ND},
  92. {"nc1hwc0", FORMAT_NC1HWC0},
  93. {"fractal_z", FORMAT_FRACTAL_Z},
  94. {"nc1c0hwpad", FORMAT_NC1C0HWPAD},
  95. {"nhwc1c0", FORMAT_NHWC1C0},
  96. {"fsr_nchw", FORMAT_FSR_NCHW},
  97. {"fractal_deconv", FORMAT_FRACTAL_DECONV},
  98. {"c1hwnc0", FORMAT_C1HWNC0},
  99. {"fractal_deconv_transpose", FORMAT_FRACTAL_DECONV_TRANSPOSE},
  100. {"fractal_deconv_sp_stride_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS},
  101. {"nc1hwc0_c04", FORMAT_NC1HWC0_C04},
  102. {"fractal_z_c04", FORMAT_FRACTAL_Z_C04},
  103. {"chwn", FORMAT_CHWN},
  104. {"deconv_sp_stride8_trans", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS},
  105. {"nc1khkwhwc0", FORMAT_NC1KHKWHWC0},
  106. {"bn_weight", FORMAT_BN_WEIGHT},
  107. {"filter_hwck", FORMAT_FILTER_HWCK},
  108. {"hwcn", FORMAT_HWCN},
  109. {"lookup_lookups", FORMAT_HASHTABLE_LOOKUP_LOOKUPS},
  110. {"lookup_keys", FORMAT_HASHTABLE_LOOKUP_KEYS},
  111. {"lookup_value", FORMAT_HASHTABLE_LOOKUP_VALUE},
  112. {"lookup_output", FORMAT_HASHTABLE_LOOKUP_OUTPUT},
  113. {"lookup_hits", FORMAT_HASHTABLE_LOOKUP_HITS},
  114. {"md", FORMAT_MD},
  115. {"c1hwncoc0", FORMAT_C1HWNCoC0},
  116. {"fractal_nz", FORMAT_FRACTAL_NZ},
  117. {"ndhwc", FORMAT_NDHWC},
  118. {"ncdhw", FORMAT_NCDHW},
  119. {"dhwcn", FORMAT_DHWCN},
  120. {"dhwnc", FORMAT_DHWNC},
  121. {"ndc1hwc0", FORMAT_NDC1HWC0},
  122. {"fractal_z_3d", FORMAT_FRACTAL_Z_3D},
  123. {"fractal_z_3d_transpose", FORMAT_FRACTAL_Z_3D_TRANSPOSE},
  124. {"cn", FORMAT_CN},
  125. {"nc", FORMAT_NC},
  126. {"fractal_zn_lstm", FORMAT_FRACTAL_ZN_LSTM},
  127. {"fractal_z_g", FORMAT_FRACTAL_Z_G}
  128. };
  129. std::string GenerateFileName(const SingleOpDesc &single_op_desc, int index) {
  130. std::stringstream file_name_ss;
  131. file_name_ss << index;
  132. file_name_ss << "_" << single_op_desc.op;
  133. for (auto &desc : single_op_desc.input_desc) {
  134. file_name_ss << "_" << desc.type << "_" << desc.format;
  135. for (auto dim : desc.dims) {
  136. file_name_ss << "_" << dim;
  137. }
  138. }
  139. for (auto &desc : single_op_desc.output_desc) {
  140. file_name_ss << "_" << desc.type << "_" << desc.format;
  141. for (auto dim : desc.dims) {
  142. file_name_ss << "_" << dim;
  143. }
  144. }
  145. std::string file_name = file_name_ss.str();
  146. if (file_name.length() > kMaxFileNameLen) {
  147. GELOGI("Trim file name for it is too long, origin file name = %s", file_name.c_str());
  148. file_name = file_name.substr(0, kMaxFileNameLen);
  149. }
  150. file_name += kFileSuffix;
  151. return file_name;
  152. }
  153. } // namespace
  154. bool AttrValueIsString(const Json &j, const string &key) {
  155. try {
  156. string tmp_str = j.at(key).get<string>();
  157. return true;
  158. } catch (Json::type_error &e) {
  159. return false;
  160. }
  161. }
  // Case-insensitive lookup: lower-cases `key` IN PLACE (callers observe the lower-cased
  // string afterwards), then returns the value mapped to it in `dict`, or `default_val`
  // when there is no such entry.
  template<typename T>
  T GetValue(const std::map<std::string, T> &dict, std::string &key, T default_val) {
    // std::tolower has undefined behaviour for negative values, which plain `char` yields
    // for non-ASCII bytes on signed-char platforms; convert through unsigned char first.
    std::transform(key.begin(), key.end(), key.begin(),
                   [](unsigned char ch) { return static_cast<char>(std::tolower(ch)); });
    const auto it = dict.find(key);
    return it == dict.end() ? default_val : it->second;
  }
  171. template<typename T>
  172. void SetAttrValue(const Json &j, SingleOpAttr &attr) {
  173. // when attr type is "data_type", we support two kinds of attr value.
  174. // 1. value: "DT_FLOAT", "DT_INT32", "DT_INT8" ...
  175. // 2. value: 1, 3 ...
  176. if (j.at(kKeyType).get<string>() == "data_type" && AttrValueIsString(j, kKeyValue)) {
  177. string type_str = j.at(kKeyValue).get<string>();
  178. DataType dtype = TypeUtils::SerialStringToDataType(type_str);
  179. attr.value.SetValue<DataType>(dtype);
  180. return;
  181. }
  182. attr.value.SetValue<T>(j.at(kKeyValue).get<T>());
  183. }
  184. void from_json(const Json &j, SingleOpTensorDesc &desc) {
  185. bool is_tensor_valid = true;
  186. desc.dims = j.at(kKeyShape).get<vector<int64_t>>();
  187. auto it = j.find(kKeyShapeRange);
  188. if (it != j.end()) {
  189. desc.dim_ranges = j.at(kKeyShapeRange).get<vector<std::vector<int64_t>>>();
  190. }
  191. it = j.find(kKeyOriginShape);
  192. if (it != j.end()) {
  193. desc.ori_dims = j.at(kKeyOriginShape).get<vector<int64_t>>();
  194. }
  195. string format_str = j.at(kKeyFormat).get<string>();
  196. string type_str = j.at(kKeyType).get<string>();
  197. desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED);
  198. desc.type = GetValue(kDataTypeDict, type_str, DT_UNDEFINED);
  199. is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(format_str);
  200. is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsDataTypeValid(type_str);
  201. it = j.find(kKeyOriginFormat);
  202. if (it != j.end()) {
  203. string origin_format_str = j.at(kKeyOriginFormat).get<string>();
  204. is_tensor_valid = is_tensor_valid && ge::TypeUtils::IsFormatValid(origin_format_str);
  205. desc.ori_format = GetValue(kFormatDict, origin_format_str, FORMAT_RESERVED);
  206. }
  207. auto tensor_name = j.find(kKeyName);
  208. if (tensor_name != j.end()) {
  209. desc.name = tensor_name->get<string>();
  210. }
  211. auto dynamic_input_name = j.find(kKeyDynamicInput);
  212. if (dynamic_input_name != j.end()) {
  213. desc.dynamic_input_name = dynamic_input_name->get<string>();
  214. }
  215. if (!is_tensor_valid) {
  216. desc.SetValidFlag(is_tensor_valid);
  217. }
  218. }
  219. void from_json(const Json &j, SingleOpAttr &attr) {
  220. attr.name = j.at(kKeyName).get<string>();
  221. attr.type = j.at(kKeyType).get<string>();
  222. auto it = kAttrTypeDict.find(attr.type);
  223. if (it == kAttrTypeDict.end()) {
  224. GELOGE(UNSUPPORTED, "[Find][JsonAttr] name=%s, type=%s failed for Unsupported type.",
  225. attr.name.c_str(), attr.type.c_str());
  226. REPORT_INNER_ERROR("E19999", "Find jsonattr name=%s, type=%s failed for Unsupported type.",
  227. attr.name.c_str(), attr.type.c_str());
  228. return;
  229. }
  230. switch (it->second) {
  231. case GeAttrValue::VT_BOOL:
  232. SetAttrValue<bool>(j, attr);
  233. break;
  234. case GeAttrValue::VT_INT:
  235. SetAttrValue<int64_t>(j, attr);
  236. break;
  237. case GeAttrValue::VT_FLOAT:
  238. SetAttrValue<float>(j, attr);
  239. break;
  240. case GeAttrValue::VT_STRING:
  241. SetAttrValue<string>(j, attr);
  242. break;
  243. case GeAttrValue::VT_LIST_BOOL:
  244. SetAttrValue<vector<bool>>(j, attr);
  245. break;
  246. case GeAttrValue::VT_LIST_INT:
  247. SetAttrValue<vector<int64_t>>(j, attr);
  248. break;
  249. case GeAttrValue::VT_LIST_FLOAT:
  250. SetAttrValue<vector<float>>(j, attr);
  251. break;
  252. case GeAttrValue::VT_LIST_STRING:
  253. SetAttrValue<vector<string>>(j, attr);
  254. break;
  255. case GeAttrValue::VT_LIST_LIST_INT:
  256. SetAttrValue<vector<vector<int64_t>>>(j, attr);
  257. break;
  258. case GeAttrValue::VT_DATA_TYPE:
  259. SetAttrValue<DataType>(j, attr);
  260. break;
  261. default:
  262. GELOGE(UNSUPPORTED, "[Find][JsonAttr] name=%s, type=%s failed for Unsupported type.",
  263. attr.name.c_str(), attr.type.c_str());
  264. REPORT_INNER_ERROR("E19999", "Find jsonattr name=%s, type=%s failed for Unsupported type.",
  265. attr.name.c_str(), attr.type.c_str());
  266. break;
  267. }
  268. }
  269. void from_json(const Json &j, SingleOpDesc &desc) {
  270. auto op = j.find(kKeyOp);
  271. if (op != j.end()) {
  272. desc.op = j.at(kKeyOp).get<string>();
  273. }
  274. auto input_desc = j.find(kKeyInputDesc);
  275. if (input_desc != j.end()) {
  276. desc.input_desc = input_desc->get<vector<SingleOpTensorDesc>>();
  277. }
  278. auto output_desc = j.find(kKeyOutputDesc);
  279. if (output_desc != j.end()) {
  280. desc.output_desc = output_desc->get<vector<SingleOpTensorDesc>>();
  281. }
  282. auto attr_field = j.find(kKeyAttr);
  283. if (attr_field != j.end()) {
  284. desc.attrs = attr_field->get<vector<SingleOpAttr>>();
  285. }
  286. auto compile_flag = j.find(kKeyCompileFlag);
  287. if (compile_flag != j.end()) {
  288. desc.compile_flag = compile_flag->get<int32_t>();
  289. }
  290. }
  291. Status SingleOpParser::ReadJsonFile(const std::string &file, Json &json_obj) {
  292. std::string real_path = RealPath(file.c_str());
  293. if (real_path.empty()) {
  294. ErrorManager::GetInstance().ATCReportErrMessage("E10023", {"value"}, {file});
  295. GELOGE(FAILED, "[Read][JsonFile]Input parameter[--singleop]'s value[%s] is not a valid path.", file.c_str());
  296. return INTERNAL_ERROR;
  297. }
  298. std::ifstream ifs(real_path);
  299. if (!ifs.is_open()) {
  300. ErrorManager::GetInstance().ATCReportErrMessage("E10024", {"value"}, {file});
  301. GELOGE(FAILED, "[Open][JsonFile] failed for file[%s] provided in input parameter[--singleop].", file.c_str());
  302. return FAILED;
  303. }
  304. try {
  305. ifs >> json_obj;
  306. } catch (const std::exception &e) {
  307. ErrorManager::GetInstance().ATCReportErrMessage("E10025", {"realpath", "errmsg"}, {real_path, e.what()});
  308. GELOGE(PARAM_INVALID,
  309. "[Parse][JsonFile] fail for file[%s] provided in input parameter[--singleop], exception = %s.",
  310. real_path.c_str(), e.what());
  311. return PARAM_INVALID;
  312. }
  313. ifs.close();
  314. return SUCCESS;
  315. }
  316. bool SingleOpParser::Validate(const SingleOpDesc &op_desc) {
  317. if (op_desc.op.empty()) {
  318. ErrorManager::GetInstance().ATCReportErrMessage("E10026");
  319. GELOGE(PARAM_INVALID, "[Check][Param] fail for name of input SingleOpDesc is empty.");
  320. return false;
  321. }
  322. int index = 0;
  323. for (auto &tensor_desc : op_desc.input_desc) {
  324. if (!tensor_desc.GetValidFlag()) {
  325. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"op_name", "input", "type", "index"},
  326. {op_desc.op, "input", "tensor", std::to_string(index)});
  327. GELOGE(PARAM_INVALID,
  328. "[Check][Param] fail for Input's dataType or format is invalid when the index is %d", index);
  329. return false;
  330. }
  331. if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) ||
  332. (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){
  333. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"op_name", "input", "type", "index"},
  334. {op_desc.op, "input", "datatype or format", std::to_string(index)});
  335. GELOGE(PARAM_INVALID, "[Check][Param]Input's dataType or format is invalid when the index is %d", index);
  336. return false;
  337. }
  338. ++index;
  339. }
  340. index = 0;
  341. for (auto &tensor_desc : op_desc.output_desc) {
  342. if (!tensor_desc.GetValidFlag()) {
  343. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"op_name", "input", "type", "index"},
  344. {op_desc.op, "output", "tensor", std::to_string(index)});
  345. GELOGE(PARAM_INVALID, "[Check][Param]fail for Output's dataType is invalid when the index is %d", index);
  346. return false;
  347. }
  348. if (tensor_desc.type == DT_UNDEFINED) {
  349. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"op_name", "input", "type", "index"},
  350. {op_desc.op, "output", "datatype", std::to_string(index)});
  351. GELOGE(PARAM_INVALID, "[Check][Param]Output's dataType is invalid when the index is %d", index);
  352. return false;
  353. }
  354. if (tensor_desc.format == FORMAT_RESERVED) {
  355. ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"op_name", "input", "type", "index"},
  356. {op_desc.op, "output", "format", std::to_string(index)});
  357. GELOGE(PARAM_INVALID, "[Check][Param]Output's format is invalid when the index is %d", index);
  358. return false;
  359. }
  360. ++index;
  361. }
  362. for (auto &attr : op_desc.attrs) {
  363. if (attr.name.empty()) {
  364. ErrorManager::GetInstance().ATCReportErrMessage("E10029", {"op_name"}, {op_desc.op});
  365. GELOGE(PARAM_INVALID, "[Parse][Attr]attr name is empty");
  366. return false;
  367. }
  368. if (attr.value.IsEmpty()) {
  369. ErrorManager::GetInstance().ATCReportErrMessage("E10030", {"op_name", "attrname"}, {op_desc.op, attr.name});
  370. GELOGE(PARAM_INVALID, "[Parse][Attr] fail for vale of attr name:\"%s\" is empty. ", attr.name.c_str());
  371. return false;
  372. }
  373. }
  374. return true;
  375. }
  376. std::unique_ptr<OpDesc> SingleOpParser::CreateOpDesc(const string &op_type) {
  377. return std::unique_ptr<OpDesc>(new(std::nothrow) OpDesc(op_type, op_type));
  378. }
  379. Status SingleOpParser::UpdateDynamicTensorName(std::vector<SingleOpTensorDesc> &desc) {
  380. std::map<std::string, int> dynamic_name_map;
  381. for (auto &tensor : desc) {
  382. if (tensor.dynamic_input_name.empty()) {
  383. continue;
  384. }
  385. if (dynamic_name_map.find(tensor.dynamic_input_name) == dynamic_name_map.end()) {
  386. dynamic_name_map[tensor.dynamic_input_name] = 0;
  387. } else {
  388. dynamic_name_map[tensor.dynamic_input_name]++;
  389. }
  390. tensor.name = tensor.dynamic_input_name + std::to_string(dynamic_name_map[tensor.dynamic_input_name]);
  391. }
  392. GELOGD("Update dynamic tensor name success!");
  393. return SUCCESS;
  394. }
  395. Status SingleOpParser::ConvertToBuildParam(int index,
  396. const SingleOpDesc &single_op_desc,
  397. SingleOpBuildParam &build_param) {
  398. auto op_desc = CreateOpDesc(single_op_desc.op);
  399. GE_CHECK_NOTNULL(op_desc);
  400. for (auto &desc : single_op_desc.input_desc) {
  401. GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
  402. desc.format,
  403. desc.type);
  404. auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
  405. auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
  406. ge_tensor_desc.SetOriginFormat(ori_format_to_set);
  407. ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
  408. GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
  409. TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
  410. TensorUtils::SetInputTensor(ge_tensor_desc, true);
  411. TensorUtils::SetOutputTensor(ge_tensor_desc, false);
  412. if (desc.name.empty()) {
  413. op_desc->AddInputDesc(ge_tensor_desc);
  414. } else {
  415. op_desc->AddInputDesc(desc.name, ge_tensor_desc);
  416. }
  417. build_param.inputs.emplace_back(ge_tensor_desc);
  418. }
  419. for (auto &desc : single_op_desc.output_desc) {
  420. GeTensorDesc ge_tensor_desc(GeShape(desc.dims),
  421. desc.format,
  422. desc.type);
  423. auto ori_format_to_set = desc.ori_format != FORMAT_RESERVED ? desc.ori_format : desc.format;
  424. auto ori_dims = !desc.ori_dims.empty() ? desc.ori_dims : desc.dims;
  425. ge_tensor_desc.SetOriginFormat(ori_format_to_set);
  426. ge_tensor_desc.SetOriginShape(GeShape(ori_dims));
  427. GE_CHK_STATUS_RET_NOLOG(SetShapeRange(op_desc->GetName(), desc, ge_tensor_desc));
  428. TensorUtils::SetRealDimCnt(ge_tensor_desc, ori_dims.size());
  429. TensorUtils::SetInputTensor(ge_tensor_desc, false);
  430. TensorUtils::SetOutputTensor(ge_tensor_desc, true);
  431. if (desc.name.empty()) {
  432. op_desc->AddOutputDesc(ge_tensor_desc);
  433. } else {
  434. op_desc->AddOutputDesc(desc.name, ge_tensor_desc);
  435. }
  436. build_param.outputs.emplace_back(ge_tensor_desc);
  437. }
  438. for (const auto &attr : single_op_desc.attrs) {
  439. op_desc->SetAttr(attr.name, attr.value);
  440. }
  441. if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) {
  442. GELOGE(PARAM_INVALID, "[Verify][OpInputOutputSize] fail for input op [%s] invalid.", op_desc->GetType().c_str());
  443. return PARAM_INVALID;
  444. }
  445. build_param.file_name = GenerateFileName(single_op_desc, index);
  446. build_param.op_desc.reset(op_desc.release());
  447. return SUCCESS;
  448. }
  449. Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc &current_op_desc) {
  450. ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType());
  451. if (!operator_ir.IsEmpty()) {
  452. auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir);
  453. GE_CHECK_NOTNULL(opdesc_ir);
  454. size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize();
  455. size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize();
  456. if (current_opdesc_inputs_num < ir_opdesc_inputs_num) {
  457. string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num);
  458. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  459. {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason});
  460. GELOGE(PARAM_INVALID,
  461. "[Verify][OpInputOutputSize]This op:%s input size %zu is smaller than the ir needed input size %zu",
  462. current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num);
  463. return PARAM_INVALID;
  464. }
  465. size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize();
  466. size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize();
  467. if (current_opdesc_outputs_num < ir_opdesc_outputs_num) {
  468. string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num);
  469. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  470. {current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason});
  471. GELOGE(PARAM_INVALID,
  472. "[Verify][OpInputOutputSize]This op:%s output size %zu is smaller than the ir needed output size %zu",
  473. current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num);
  474. return PARAM_INVALID;
  475. }
  476. }
  477. return SUCCESS;
  478. }
  479. Status SingleOpParser::SetShapeRange(const std::string &op_name,
  480. const SingleOpTensorDesc &tensor_desc,
  481. GeTensorDesc &ge_tensor_desc) {
  482. auto num_shape_ranges = tensor_desc.dim_ranges.size();
  483. GELOGD("Number of shape ranges = %zu", num_shape_ranges);
  484. auto it = std::find(tensor_desc.dims.begin(), tensor_desc.dims.end(), ge::UNKNOWN_DIM_NUM);
  485. if (it != tensor_desc.dims.end()) {
  486. if (tensor_desc.dims != ge::UNKNOWN_RANK) {
  487. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  488. {op_name,
  489. "shape",
  490. "has unknown rank but dim size is not one"});
  491. GELOGE(PARAM_INVALID, "[Set][ShapeRange]Invalid tensor shape:%s.",
  492. ge_tensor_desc.MutableShape().ToString().c_str());
  493. return PARAM_INVALID;
  494. }
  495. if (!tensor_desc.dim_ranges.empty()) {
  496. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  497. {op_name,
  498. "shape range",
  499. "is not needed while the rank the shape is unknown"});
  500. GELOGE(PARAM_INVALID, "[Set][ShapeRange]Shape range is not needed while the rank the shape is unknown.");
  501. return PARAM_INVALID;
  502. }
  503. GELOGD("Shape is unknown rank, do not set shape range");
  504. return SUCCESS;
  505. }
  506. std::vector<std::pair<int64_t, int64_t>> shape_range;
  507. size_t range_index = 0;
  508. for (auto dim : tensor_desc.dims) {
  509. if (dim >= 0) {
  510. shape_range.emplace_back(dim, dim);
  511. GELOGD("Adding shape range: [%ld, %ld]", dim, dim);
  512. } else {
  513. GELOGD("To get shape range by index = %zu", range_index);
  514. if (range_index >= num_shape_ranges) {
  515. string reason = "is smaller than the unknown dim size " + std::to_string(++range_index);
  516. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  517. {op_name,
  518. "shape range size " + std::to_string(num_shape_ranges),
  519. reason});
  520. GELOGE(PARAM_INVALID, "[Set][ShapeRange]The number of shape_range mismatches that of unknown dims.");
  521. return PARAM_INVALID;
  522. }
  523. auto &range = tensor_desc.dim_ranges[range_index];
  524. if (range.size() != kShapeRangePairSize) {
  525. string reason = "has " + std::to_string(range.size()) + " item(s)";
  526. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  527. {op_name,
  528. "shape range " + std::to_string(range_index),
  529. reason});
  530. GELOGE(PARAM_INVALID, "[Set][ShapeRange]Invalid shape range entry. index = %zu, size = %zu",
  531. range_index, range.size());
  532. return PARAM_INVALID;
  533. }
  534. shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]);
  535. GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]);
  536. ++range_index;
  537. }
  538. }
  539. if (num_shape_ranges != range_index) {
  540. string reason = "is greater than the unknown dim size " + std::to_string(range_index);
  541. ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"},
  542. {op_name,
  543. "shape range size " + std::to_string(num_shape_ranges),
  544. reason});
  545. GELOGE(PARAM_INVALID,
  546. "[Set][ShapeRange]The number of shape_range(%zu) mismatches that of unknown dims(%zu).",
  547. num_shape_ranges, range_index);
  548. return PARAM_INVALID;
  549. }
  550. if (range_index > 0) {
  551. ge_tensor_desc.SetShapeRange(shape_range);
  552. }
  553. return SUCCESS;
  554. }
  555. Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector<SingleOpBuildParam> &op_list) {
  556. int index = 0;
  557. try {
  558. Json single_op_list_json;
  559. auto ret = ReadJsonFile(file, single_op_list_json);
  560. if (ret != SUCCESS) {
  561. return ret;
  562. }
  563. int32_t compile_flag = 0;
  564. for (const Json &single_op_json : single_op_list_json) {
  565. SingleOpDesc single_op_desc;
  566. GELOGI("Parsing op[%d], jsonStr = %s", index, single_op_json.dump(kDumpJsonIndent).c_str());
  567. single_op_desc = single_op_json;
  568. GELOGD("Compile flag is %d.", single_op_desc.compile_flag);
  569. if (single_op_desc.compile_flag == 1) {
  570. compile_flag = single_op_desc.compile_flag;
  571. continue;
  572. }
  573. if (UpdateDynamicTensorName(single_op_desc.input_desc) != SUCCESS) {
  574. GELOGE(FAILED, "[Update][DynamicTensorName] failed for invalid input param!");
  575. REPORT_CALL_ERROR("E19999", "UpdateDynamicTensorName failed for invalid input param.");
  576. return FAILED;
  577. }
  578. if (!Validate(single_op_desc)) {
  579. GELOGE(PARAM_INVALID,
  580. "[Check][OpDesc]Validate the index[%d] of op failed when read json file[%s].", index, file.c_str());
  581. return PARAM_INVALID;
  582. }
  583. SingleOpBuildParam param;
  584. ret = ConvertToBuildParam(index, single_op_desc, param);
  585. if (ret != SUCCESS) {
  586. return ret;
  587. }
  588. param.compile_flag = compile_flag;
  589. op_list.emplace_back(param);
  590. GELOGI("Parse the index[%d] of op success", index);
  591. index += 1;
  592. }
  593. } catch (const nlohmann::json::exception &e) {
  594. REPORT_INNER_ERROR("E19999", "parse singleop file:%s failed, catch exception:%s, current index:%d",
  595. file.c_str(), e.what(), index);
  596. GELOGE(PARAM_INVALID, "[Parse][OpList] the index:%d of op failed when read json file:%s, exception:%s",
  597. index, file.c_str(), e.what());
  598. return PARAM_INVALID;
  599. }
  600. return SUCCESS;
  601. }
  602. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：