You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_generator.cc 41 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/util/rt_context_util.h"
  31. #include "graph/opsproto_manager.h"
  32. #include "graph/utils/graph_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. #include "model/ge_model.h"
  36. using std::map;
  37. using std::string;
  38. using std::vector;
  39. namespace {
  40. const char *const kAttrOpType = "op_type";
  41. const char *const kEngineNameDefault = "default";
  42. const char *const kVectorEngine = "VectorEngine";
  43. const char *const kAIcoreEngine = "AIcoreEngine";
  44. const char *const kFileNameSuffix = "online";
  45. const char *const kAicpuAllshape = "_AllShape";
  46. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  47. const int64_t kDynamicDimValue = -2;
  48. const int kDefaultDeviceId = 0;
  49. const int kDefaultJobId = 0;
  50. std::map<ge::OpEngineType, std::string> engine_type_map{
  51. {ge::ENGINE_SYS, kEngineNameDefault},
  52. {ge::ENGINE_AICORE, kAIcoreEngine},
  53. {ge::ENGINE_VECTOR, kVectorEngine}};
  54. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  55. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  56. if (tensor_desc->MutableShape().IsUnknownShape()) {
  57. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  58. return true;
  59. }
  60. }
  61. return false;
  62. }
  63. bool IsOptional(const ge::GeTensorDesc &tensor_desc) {
  64. return tensor_desc.GetFormat() == ge::FORMAT_RESERVED && tensor_desc.GetDataType() == ge::DT_UNDEFINED;
  65. }
  66. } // namespace
  67. namespace ge {
  68. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  69. const OpDescPtr &op_desc = node->GetOpDesc();
  70. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  71. if (engine_type == ENGINE_SYS) {
  72. GELOGI("CheckEngineType: use default engine.");
  73. return SUCCESS;
  74. }
  75. // get op engine name
  76. string op_engine_name;
  77. auto iter = engine_type_map.find(engine_type);
  78. if (iter != engine_type_map.end()) {
  79. op_engine_name = iter->second;
  80. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  81. } else {
  82. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  83. {op_desc->GetName(), op_desc->GetType(), "engine type",
  84. "it only support default/AIcoreEngine/VectorEngine"});
  85. GELOGE(FAILED, "[Check][EngineType]value:%d not support, "
  86. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  87. return FAILED;
  88. }
  89. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  90. op_desc->SetOpEngineName(op_engine_name);
  91. op_desc->SetOpKernelLibName(op_engine_name);
  92. return SUCCESS;
  93. }
  94. // set op engine name and opkernelLib. when engine support
  95. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  96. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  97. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed.");
  98. return FAILED;
  99. }
  100. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  101. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  102. if (op_infos.empty()) {
  103. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  104. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  105. GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str());
  106. return FAILED;
  107. }
  108. string kernel_name;
  109. for (const auto &it : op_infos) {
  110. if (it.engine == op_engine_name) {
  111. kernel_name = it.opKernelLib;
  112. break;
  113. }
  114. }
  115. if (kernel_name.empty()) {
  116. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  117. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  118. GELOGE(FAILED, "CheckEngineType:Can not find ops kernel, engine name: %s.", op_engine_name.c_str());
  119. return FAILED;
  120. }
  121. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  122. auto kernel_info_store = kernel_map.find(kernel_name);
  123. if (kernel_info_store != kernel_map.end()) {
  124. std::string unsupported_reason;
  125. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  126. op_desc->SetOpEngineName(op_engine_name);
  127. op_desc->SetOpKernelLibName(kernel_name);
  128. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  129. op_engine_name.c_str(), op_desc->GetName().c_str());
  130. return SUCCESS;
  131. } else {
  132. ErrorManager::GetInstance().ATCReportErrMessage(
  133. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  134. GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  135. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  136. return FAILED;
  137. }
  138. } else {
  139. ErrorManager::GetInstance().ATCReportErrMessage(
  140. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  141. GELOGE(FAILED,
  142. "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s,"
  143. "op type is %s, op name is %s",
  144. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  145. }
  146. return FAILED;
  147. }
  148. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  149. bool attr, int32_t &data_index) {
  150. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  151. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  152. auto format = tensor.GetFormat();
  153. auto data_type = tensor.GetDataType();
  154. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  155. return SUCCESS;
  156. }
  157. string op_type;
  158. bool is_const = false;
  159. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  160. if (is_const) {
  161. GELOGD("Get input[%d] is const", index);
  162. op_type = CONSTANTOP;
  163. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  164. op_type = DATA;
  165. }
  166. string op_name = node->GetName() + "_in_" + std::to_string(index);
  167. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  168. if (data_op == nullptr) {
  169. return FAILED;
  170. }
  171. if (is_const) {
  172. ConstGeTensorPtr tensor_value;
  173. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  174. GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
  175. return FAILED;
  176. }
  177. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  178. GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
  179. return FAILED;
  180. }
  181. }
  182. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  183. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  184. "[Add][InputDesc]fail for node:%s", data_op->GetName().c_str());
  185. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  186. "[Add][OutputDesc]fail for node:%s", data_op->GetName().c_str());
  187. if (attr && !is_const) {
  188. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index), return FAILED,
  189. "[Set][Attr:%s]fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  190. ++data_index;
  191. }
  192. ge::NodePtr arg_node = graph->AddNode(data_op);
  193. GE_CHK_BOOL_EXEC(arg_node != nullptr, return FAILED, "Insert Data node fail");
  194. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  195. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  196. return SUCCESS;
  197. }
  198. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  199. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  200. if (op_desc == nullptr) {
  201. return FAILED;
  202. }
  203. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  204. int32_t count = 0;
  205. for (const auto &out_desc : outputs) {
  206. GeTensorDesc tensor = out_desc.GetTensorDesc();
  207. TensorUtils::SetInputTensor(tensor, true);
  208. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  209. "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  210. TensorUtils::SetInputTensor(tensor, false);
  211. TensorUtils::SetOutputTensor(tensor, true);
  212. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED,
  213. "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  214. count++;
  215. }
  216. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  217. ge::NodePtr out_node = graph->AddNode(op_desc);
  218. GE_CHK_BOOL_EXEC(out_node != nullptr, return FAILED,
  219. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  220. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  221. for (int32_t i = 0; i < count; ++i) {
  222. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  223. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  224. }
  225. return SUCCESS;
  226. }
  227. static void GetOpsProtoPath(string &opsproto_path) {
  228. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  229. if (path_env != nullptr) {
  230. string path = path_env;
  231. string file_path = RealPath(path.c_str());
  232. if (file_path.empty()) {
  233. GELOGE(FAILED, "File path %s is invalid.", path.c_str());
  234. return;
  235. }
  236. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  237. GELOGI("Get opsproto so path from env : %s", path.c_str());
  238. return;
  239. }
  240. string path_base = PluginManager::GetPath();
  241. GELOGI("path_base is %s", path_base.c_str());
  242. path_base = path_base.substr(0, path_base.rfind('/'));
  243. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  244. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  245. }
  246. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  247. for (auto input : inputs) {
  248. auto input_desc = input.GetTensorDesc();
  249. GeShape shape_ori = input_desc.GetShape();
  250. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  251. GeShape dynamic_shape(dynamic_shape_dims);
  252. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  253. ge::GeTensor inputTensor;
  254. ge::GeTensorDesc desc(input_desc);
  255. bool is_const = false;
  256. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  257. if (!is_const) {
  258. int64_t storage_format = FORMAT_NCHW;
  259. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  260. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  261. GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail.");
  262. return FAILED;
  263. }
  264. desc.SetShape(dynamic_shape);
  265. desc.SetShapeRange(dynamic_shape_range);
  266. }
  267. inputTensor.SetTensorDesc(desc);
  268. inputs_dynamic.push_back(inputTensor);
  269. }
  270. return SUCCESS;
  271. }
  272. class GeGenerator::Impl {
  273. public:
  274. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  275. ~Impl() = default;
  276. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  277. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  278. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  279. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  280. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  281. Status GenerateInfershapeGraph(const Graph &graph);
  282. OmgContext &omg_context_;
  283. GraphManager graph_manager_;
  284. SaveParam save_param_;
  285. bool is_offline_ = true;
  286. bool is_singleop_unregistered_ = false;
  287. std::string build_mode_;
  288. std::string build_step_;
  289. static std::mutex mutex_;
  290. private:
  291. static std::string Trim(const std::string &str);
  292. bool ParseVersion(const std::string &line, std::string &version);
  293. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  294. bool SetAtcVersionInfo(AttrHolder &obj);
  295. bool SetOppVersionInfo(AttrHolder &obj);
  296. bool SetOmSystemInfo(AttrHolder &obj);
  297. };
  298. Status GeGenerator::Initialize(const map<string, string> &options) {
  299. return Initialize(options, domi::GetContext());
  300. }
  301. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  302. impl_ = ge::MakeShared<Impl>(omg_context);
  303. if (impl_ == nullptr) {
  304. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  305. return MEMALLOC_FAILED;
  306. }
  307. ErrorManager::GetInstance().SetStage(ErrorMessage::kInitialize, ErrorMessage::kOpsProtoInit);
  308. string opsproto_path;
  309. GetOpsProtoPath(opsproto_path);
  310. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  311. OpsProtoManager *manager = OpsProtoManager::Instance();
  312. map<string, string> option_tmp;
  313. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  314. (void)manager->Initialize(option_tmp);
  315. Status ret = impl_->graph_manager_.Initialize(options);
  316. if (ret != SUCCESS) {
  317. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed.");
  318. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  319. }
  320. // get ek file
  321. auto iter = options.find(EK_FILE);
  322. if (iter != options.end()) {
  323. impl_->save_param_.ek_file = iter->second;
  324. }
  325. // get cert file
  326. iter = options.find(CERT_FILE);
  327. if (iter != options.end()) {
  328. impl_->save_param_.cert_file = iter->second;
  329. }
  330. // get hw key file
  331. iter = options.find(HW_KEY_FILE);
  332. if (iter != options.end()) {
  333. impl_->save_param_.hw_key_file = iter->second;
  334. }
  335. // get private file
  336. iter = options.find(PRIVATE_KEY_FILE);
  337. if (iter != options.end()) {
  338. impl_->save_param_.pri_key_file = iter->second;
  339. }
  340. // get build mode
  341. iter = options.find(BUILD_MODE);
  342. if (iter != options.end()) {
  343. impl_->build_mode_ = iter->second;
  344. }
  345. // get build step
  346. iter = options.find(BUILD_STEP);
  347. if (iter != options.end()) {
  348. impl_->build_step_ = iter->second;
  349. }
  350. return SUCCESS;
  351. }
  352. Status GeGenerator::Finalize() {
  353. ErrorManager::GetInstance().SetStage(ErrorMessage::kFinalize, ErrorMessage::kFinalize);
  354. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  355. Status ret = impl_->graph_manager_.Finalize();
  356. if (ret != SUCCESS) {
  357. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed.");
  358. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  359. }
  360. return SUCCESS;
  361. }
  362. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  363. const vector<GeTensor> &inputs) {
  364. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  365. GELOGI("Start to generate offline model.");
  366. ModelBufferData model;
  367. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  368. }
  369. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  370. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  371. return GenerateModel(graph, "online", inputs, model, false);
  372. }
  373. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  374. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  375. Status ret = impl_->GenerateInfershapeGraph(graph);
  376. if (ret != SUCCESS) {
  377. GELOGE(ret, "Dump infershape json failed");
  378. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  379. GELOGE(FAILED, "graph_manager finalize fail.");
  380. }
  381. return ret;
  382. }
  383. GELOGI("Generate infer shape graph success");
  384. return SUCCESS;
  385. }
  386. std::mutex GeGenerator::Impl::mutex_;
  387. // Remove the space and tab before and after the string
  388. std::string GeGenerator::Impl::Trim(const std::string &str) {
  389. if (str.empty()) {
  390. return str;
  391. }
  392. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  393. if (start == std::string::npos) {
  394. return str;
  395. }
  396. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  397. return str.substr(start, end);
  398. }
  399. // Parsing the command line
  400. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  401. std::string flag = "Version=";
  402. std::string temp = Trim(line);
  403. if (temp.empty()) {
  404. GELOGW("line is empty.");
  405. return false;
  406. }
  407. std::string::size_type pos = temp.find(flag);
  408. if (pos == std::string::npos) {
  409. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  410. return false;
  411. }
  412. if (temp.size() == flag.size()) {
  413. GELOGW("version information is empty. %s", line.c_str());
  414. return false;
  415. }
  416. version = temp.substr(pos + flag.size());
  417. return true;
  418. }
  419. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  420. // Normalize the path
  421. string resolved_file_path = RealPath(file_path.c_str());
  422. if (resolved_file_path.empty()) {
  423. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  424. return false;
  425. }
  426. std::ifstream fs(resolved_file_path, std::ifstream::in);
  427. if (!fs.is_open()) {
  428. GELOGW("Open %s failed.", file_path.c_str());
  429. return false;
  430. }
  431. std::string line;
  432. if (getline(fs, line)) {
  433. if (!ParseVersion(line, version)) {
  434. GELOGW("Parse version failed. content is [%s].", line.c_str());
  435. fs.close();
  436. return false;
  437. }
  438. } else {
  439. GELOGW("No version information found in the file path:%s", file_path.c_str());
  440. fs.close();
  441. return false;
  442. }
  443. fs.close(); // close the file
  444. return true;
  445. }
  446. // Set package version information in the model
  447. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  448. std::string path_base = ge::GELib::GetPath();
  449. path_base = path_base.substr(0, path_base.rfind('/'));
  450. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  451. std::string version_path = path_base + "version.info";
  452. std::string version;
  453. if (!GetVersionFromPath(version_path, version)) {
  454. GELOGW("Get atc version information failed!");
  455. return false;
  456. }
  457. // set version info
  458. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  459. GELOGW("Ge model set atc version failed!");
  460. return false;
  461. }
  462. return true;
  463. }
  464. // Set package version information in the model
  465. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  466. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  467. if (path_env == nullptr) {
  468. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  469. return false;
  470. }
  471. std::string version_path = path_env;
  472. version_path += "/version.info";
  473. std::string version;
  474. if (!GetVersionFromPath(version_path, version)) {
  475. GELOGW("Get opp version information failed!");
  476. return false;
  477. }
  478. // set version info
  479. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  480. GELOGW("Ge model set opp version failed!");
  481. return false;
  482. }
  483. return true;
  484. }
  485. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  486. std::string soc_version;
  487. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  488. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  489. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  490. GELOGW("SetStr of soc_version failed.");
  491. return false;
  492. }
  493. std::string framework_type;
  494. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  495. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  496. auto iter = ge::kFwkTypeToStr.find(framework_type);
  497. if (iter == ge::kFwkTypeToStr.end()) {
  498. GELOGW("Can not find framework_type in the map.");
  499. return false;
  500. }
  501. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  502. GELOGW("SetStr of framework_type failed.");
  503. return false;
  504. }
  505. return true;
  506. }
  507. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  508. bool is_unknown_shape = false;
  509. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  510. if (ret != SUCCESS) {
  511. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  512. ge_root_model->GetModelId());
  513. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%zu",
  514. ge_root_model->GetModelId());
  515. return FAILED;
  516. }
  517. GeModelPtr model_root = nullptr;
  518. if (is_unknown_shape) {
  519. model_root = MakeShared<GeModel>();
  520. GE_CHECK_NOTNULL(model_root);
  521. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  522. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  523. }
  524. ModelHelper model_helper;
  525. string model_name;
  526. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  527. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  528. model_name);
  529. if (name_ret != SUCCESS) {
  530. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  531. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  532. ge_root_model->GetRootGraph()->GetName().c_str());
  533. REPORT_CALL_ERROR("E19999", "Get model_name failed. Param --output is invalid, root graph name: %s",
  534. ge_root_model->GetRootGraph()->GetName().c_str());
  535. return PARAM_INVALID;
  536. }
  537. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  538. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  539. GE_CHECK_NOTNULL(ge_model);
  540. ge_model->SetName(model_name);
  541. return SUCCESS;
  542. }
  543. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  544. ModelBufferData &model, bool is_offline) {
  545. rtContext_t ctx = nullptr;
  546. auto rt = rtCtxGetCurrent(&ctx);
  547. if (rt != RT_ERROR_NONE) {
  548. GELOGD("Current ctx is null.");
  549. ctx = nullptr;
  550. }
  551. GeRootModelPtr ge_root_model = nullptr;
  552. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  553. impl_->is_offline_ = is_offline;
  554. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  555. if (ret != SUCCESS) {
  556. GELOGE(ret, "Build model failed.");
  557. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  558. GELOGE(FAILED, "graph_manager finalize fail.");
  559. }
  560. return ret;
  561. }
  562. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  563. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  564. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  565. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  566. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  567. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  568. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  569. impl_->build_mode_.c_str(),
  570. impl_->build_step_.c_str());
  571. return SUCCESS;
  572. }
  573. GE_CHECK_NOTNULL(ge_root_model);
  574. ret = SetModelNameForDump(ge_root_model);
  575. if (ret != SUCCESS) {
  576. return ret;
  577. }
  578. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  579. if (ret != SUCCESS) {
  580. GELOGE(ret, "Save model failed");
  581. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  582. GELOGE(FAILED, "graph_manager finalize fail.");
  583. }
  584. return ret;
  585. }
  586. if (ctx != nullptr) {
  587. (void)rtCtxSetCurrent(ctx);
  588. }
  589. return SUCCESS;
  590. }
  591. namespace {
  592. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  593. bool is_need = true;
  594. // format and dtype is all reserved, stand for Optional input. When singleop scene
  595. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  596. is_need = false;
  597. }
  598. return is_need;
  599. }
  600. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  601. bool support_dynamic = true;
  602. bool is_dynamic = false;
  603. for (const auto &node : graph->GetDirectNode()) {
  604. GE_CHECK_NOTNULL(node);
  605. auto op_desc = node->GetOpDesc();
  606. GE_CHECK_NOTNULL(op_desc);
  607. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  608. continue;
  609. }
  610. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  611. is_dynamic = true;
  612. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  613. if (!support_dynamic) {
  614. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  615. break;
  616. }
  617. }
  618. }
  619. if (is_dynamic) {
  620. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  621. }
  622. return SUCCESS;
  623. }
  624. bool CheckNoAicore(const ComputeGraphPtr &graph) {
  625. for (const auto &node : graph->GetDirectNode()) {
  626. if (node == nullptr) {
  627. continue;
  628. }
  629. auto op_desc = node->GetOpDesc();
  630. if (op_desc == nullptr) {
  631. continue;
  632. }
  633. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  634. return false;
  635. }
  636. }
  637. return true;
  638. }
  639. }
  640. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  641. for (auto &input : inputs) {
  642. GeTensorDesc input_desc = input.GetTensorDesc();
  643. bool is_const = false;
  644. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  645. bool is_optional = IsOptional(input_desc);
  646. if (!is_optional && !is_const) {
  647. outputs.emplace_back(input);
  648. }
  649. }
  650. }
  651. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  652. const vector<GeTensor> &outputs) {
  653. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  654. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  655. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  656. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  657. "tensor size is " + FmtToStr(inputs.size())});
  658. GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
  659. return PARAM_INVALID;
  660. }
  661. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  662. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  663. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  664. "tensor size is " + FmtToStr(outputs.size())});
  665. GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
  666. return PARAM_INVALID;
  667. }
  668. return SUCCESS;
  669. }
  670. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  671. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  672. bool is_offline) {
  673. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  674. impl_->is_offline_ = is_offline;
  675. if (!is_offline) {
  676. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  677. }
  678. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  679. GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
  680. return PARAM_INVALID;
  681. }
  682. OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
  683. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  684. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  685. impl_->is_singleop_unregistered_ = true;
  686. }
  687. // 0. Save original attributes.
  688. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  689. GE_CHECK_NOTNULL(op_desc_tmp);
  690. // 1. Create ComputeGraph.
  691. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  692. Graph graph;
  693. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), "make graph fail.");
  694. // 2. check engine type when compile online
  695. if (model_file_name == kFileNameSuffix) {
  696. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  697. GE_CHECK_NOTNULL(comp_graph);
  698. auto node = comp_graph->FindNode(op_desc->GetName());
  699. Status ret = CheckEngineTypeSupport(node, engine_type);
  700. if (ret != SUCCESS) {
  701. GELOGE(ret, "[Check][EngineType]value:%d for node:%s not support", engine_type, node->GetName().c_str());
  702. return ret;
  703. }
  704. }
  705. GELOGI("ATC parser success in single op build.");
  706. GeRootModelPtr ge_root_model = nullptr;
  707. vector<GeTensor> data_inputs;
  708. RemoveConst(inputs, data_inputs);
  709. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  710. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  711. GE_CHECK_NOTNULL(ge_root_model);
  712. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  713. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  714. if (name_to_ge_model.empty()) {
  715. GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty.");
  716. return PARAM_INVALID;
  717. }
  718. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  719. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  720. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  721. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  722. bool all_shape = false;
  723. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  724. if (all_shape && CheckNoAicore(root_graph)) {
  725. GELOGD("Get aicpu all_shape kernel!");
  726. vector<GeTensor> inputs_dynamic;
  727. vector<GeTensor> outputs_dynamic;
  728. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  729. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  730. GE_CHK_STATUS_RET_NOLOG(
  731. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  732. } else {
  733. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  734. }
  735. GELOGI("Start save GeModel to Model buffer");
  736. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  737. return SUCCESS;
  738. }
  739. /**
  740. * @ingroup ge
  741. * @brief Compiling a single operator into an offline model
  742. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  743. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  744. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  745. * @param [in] const string &model_file_name: Offline model filename.
  746. * @return SUCCESS handle successfully / others handle failed
  747. */
  748. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  749. const vector<GeTensor> &outputs, const string &model_file_name) {
  750. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  751. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  752. ModelBufferData model_buff;
  753. OpEngineType engine_type = ENGINE_SYS;
  754. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true);
  755. GELOGI("Finish build single offline model, status: %u", status);
  756. return status;
  757. }
  758. /**
  759. * @ingroup ge
  760. * @brief Compiling a single operator into online buffer
  761. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  762. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  763. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  764. * @param [in] engine_type: specific engine.
  765. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  766. * @return SUCCESS handle successfully / others handle failed
  767. */
  768. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  769. const vector<GeTensor> &outputs, OpEngineType engine_type,
  770. ModelBufferData &model_buff) {
  771. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  772. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  773. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  774. GELOGI("Finish build single online model, status: %u", status);
  775. return status;
  776. }
  777. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  778. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  779. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  780. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  781. // 1. Add Node to ComputeGraph.
  782. NodePtr op_node = compute_graph->AddNode(op_desc);
  783. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  784. // 2. Create InputData node.
  785. int32_t arg_index = 0;
  786. int32_t data_index = 0;
  787. if (inputs.empty()) {
  788. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  789. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  790. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  791. continue;
  792. }
  793. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  794. arg_index++;
  795. }
  796. } else {
  797. for (const auto &in_desc : inputs) {
  798. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  799. arg_index++;
  800. }
  801. }
  802. // 3. Create Output node.
  803. if (!outputs.empty()) {
  804. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  805. }
  806. // dump ComputeGraph node.
  807. compute_graph->Dump();
  808. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  809. return SUCCESS;
  810. }
  811. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  812. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  813. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  814. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  815. (void)graph_manager_.Finalize();
  816. return FAILED);
  817. return SUCCESS;
  818. }
  819. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  820. // set atc version
  821. if (!SetAtcVersionInfo(*(model.get()))) {
  822. GELOGW("SetPackageVersionInfo of atc failed!");
  823. }
  824. // set opp version
  825. if (!SetOppVersionInfo(*(model.get()))) {
  826. GELOGW("SetPackageVersionInfo of ops failed!");
  827. }
  828. ModelHelper model_helper;
  829. model_helper.SetSaveMode(is_offline_);
  830. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  831. if (ret != SUCCESS) {
  832. GELOGE(ret, "Save to om model failed");
  833. return ret;
  834. }
  835. return SUCCESS;
  836. }
  837. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  838. ModelBufferData &model_buff) {
  839. bool is_unknown_shape = false;
  840. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  841. if (ret != SUCCESS) {
  842. GELOGE(FAILED, "Check root model is unkonwn shape failed");
  843. return FAILED;
  844. }
  845. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  846. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
  847. "ge root model has no sub model")
  848. GeModelPtr model_root = nullptr;
  849. if (is_unknown_shape) {
  850. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  851. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  852. } else {
  853. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  854. }
  855. GE_CHECK_NOTNULL(model_root);
  856. // set atc version
  857. if (!SetAtcVersionInfo(*(model_root.get()))) {
  858. GELOGW("SetPackageVersionInfo of atc failed!");
  859. }
  860. // set opp version
  861. if (!SetOppVersionInfo(*(model_root.get()))) {
  862. GELOGW("SetPackageVersionInfo of ops failed!");
  863. }
  864. if (!SetOmSystemInfo(*(model_root.get()))) {
  865. GELOGW("SetOmsystemInfo failed!");
  866. }
  867. ModelHelper model_helper;
  868. model_helper.SetSaveMode(is_offline_);
  869. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  870. if (ret != SUCCESS) {
  871. GELOGE(ret, "Save to om model failed");
  872. return ret;
  873. }
  874. return SUCCESS;
  875. }
  876. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  877. GeRootModelPtr &ge_root_model) {
  878. static std::atomic<GraphId> atomic_graph_id(0);
  879. auto graph_id = atomic_graph_id.fetch_add(1);
  880. const std::map<std::string, std::string> options;
  881. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  882. if (ret != SUCCESS) {
  883. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id);
  884. (void)graph_manager_.Finalize();
  885. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  886. }
  887. graph_manager_.SetOptionsRunGraphFlag(false);
  888. static std::atomic<uint64_t> atomic_session_id(0);
  889. auto session_id = atomic_session_id.fetch_add(1);
  890. // This is a temporary add for graph with variable
  891. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  892. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  893. GELOGI("Start init var instance, session_id %lu", session_id);
  894. if (ret != SUCCESS) {
  895. GELOGW("Failed init var instance, session_id %lu", session_id);
  896. }
  897. if (is_singleop_unregistered_) {
  898. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  899. } else {
  900. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  901. }
  902. ErrorManager::GetInstance().SetStage(ErrorMessage::kModelCompile, ErrorMessage::kOther);
  903. if (ret != SUCCESS) {
  904. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id);
  905. VarManagerPool::Instance().RemoveVarManager(session_id);
  906. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  907. }
  908. VarManagerPool::Instance().RemoveVarManager(session_id);
  909. return SUCCESS;
  910. }
  911. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  912. static std::atomic<GraphId> atomic_graph_id(0);
  913. auto graph_id = atomic_graph_id.fetch_add(1);
  914. const std::map<std::string, std::string> options;
  915. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  916. if (ret != SUCCESS) {
  917. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id);
  918. (void)graph_manager_.Finalize();
  919. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  920. }
  921. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  922. if (ret != SUCCESS) {
  923. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed");
  924. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  925. }
  926. return SUCCESS;
  927. }
  928. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，经过一系列深度图优化操作，最终输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：