You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_generator.cc 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "ge/ge_api.h"
  26. #include "graph/debug/ge_attr_define.h"
  27. #include "graph/ge_context.h"
  28. #include "graph/manager/graph_manager.h"
  29. #include "graph/manager/util/rt_context_util.h"
  30. #include "graph/opsproto_manager.h"
  31. #include "graph/utils/graph_utils.h"
  32. #include "graph/utils/type_utils.h"
  33. #include "init/gelib.h"
  34. #include "model/ge_model.h"
  35. using std::map;
  36. using std::string;
  37. using std::vector;
  38. namespace {
  39. const char *const kAttrOpType = "op_type";
  40. const char *const kEngineNameDefault = "default";
  41. const char *const kVectorEngine = "VectorEngine";
  42. const char *const kAIcoreEngine = "AIcoreEngine";
  43. const char *const kFileNameSuffix = "online";
  44. std::map<ge::OpEngineType, std::string> engine_type_map{
  45. {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}};
  46. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  47. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  48. if (tensor_desc->MutableShape().IsUnknownShape()) {
  49. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  50. return true;
  51. }
  52. }
  53. return false;
  54. }
  55. } // namespace
  56. namespace ge {
  57. static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) {
  58. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  59. if (engine_type == ENGINE_SYS) {
  60. GELOGI("CheckEngineType: use default engine.");
  61. return SUCCESS;
  62. }
  63. // get op engine name
  64. string op_engine_name;
  65. auto iter = engine_type_map.find(engine_type);
  66. if (iter != engine_type_map.end()) {
  67. op_engine_name = iter->second;
  68. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  69. } else {
  70. GELOGE(FAILED, "CheckEngineType: engine type: %d not support", static_cast<int>(engine_type));
  71. return FAILED;
  72. }
  73. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  74. op_desc->SetOpEngineName(op_engine_name);
  75. op_desc->SetOpKernelLibName(op_engine_name);
  76. return SUCCESS;
  77. }
  78. // set op engine name and opkernelLib. when engine support
  79. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  80. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  81. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed.");
  82. return FAILED;
  83. }
  84. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  85. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  86. if (op_infos.empty()) {
  87. GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str());
  88. return FAILED;
  89. }
  90. string kernel_name;
  91. for (const auto &it : op_infos) {
  92. if (it.engine == op_engine_name) {
  93. kernel_name = it.opKernelLib;
  94. break;
  95. }
  96. }
  97. if (kernel_name.empty()) {
  98. GELOGE(FAILED, "CheckEngineType:Can not find ops kernel,engine name: %s.", op_engine_name.c_str());
  99. return FAILED;
  100. }
  101. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  102. auto kernel_info_store = kernel_map.find(kernel_name);
  103. if (kernel_info_store != kernel_map.end()) {
  104. std::string unsupported_reason;
  105. if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) {
  106. op_desc->SetOpEngineName(op_engine_name);
  107. op_desc->SetOpKernelLibName(kernel_name);
  108. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  109. op_engine_name.c_str(), op_desc->GetName().c_str());
  110. return SUCCESS;
  111. } else {
  112. GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  113. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  114. return FAILED;
  115. }
  116. } else {
  117. GELOGE(FAILED,
  118. "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s,"
  119. "op type is %s, op name is %s",
  120. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  121. }
  122. return FAILED;
  123. }
  124. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTensorDesc &tensor, int32_t index,
  125. bool attr) {
  126. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  127. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  128. auto format = tensor.GetFormat();
  129. auto data_type = tensor.GetDataType();
  130. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  131. return SUCCESS;
  132. }
  133. string op_type;
  134. if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  135. op_type = DATA;
  136. }
  137. string op_name = node->GetName() + "_in_" + std::to_string(index);
  138. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  139. if (data_op == nullptr) {
  140. return FAILED;
  141. }
  142. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  143. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail.");
  144. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add output desc fail.");
  145. if (attr) {
  146. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED, "Set index fail.");
  147. }
  148. ge::NodePtr arg_node = graph->AddNode(data_op);
  149. GE_CHK_BOOL_EXEC(arg_node != nullptr, return FAILED, "Insert Data node fail.");
  150. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  151. "Add edge[%s->%s] fail.", data_op->GetName().c_str(), node->GetName().c_str());
  152. return SUCCESS;
  153. }
  154. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  155. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  156. if (op_desc == nullptr) {
  157. return FAILED;
  158. }
  159. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  160. int32_t count = 0;
  161. for (const auto &out_desc : outputs) {
  162. GeTensorDesc tensor = out_desc.GetTensorDesc();
  163. TensorUtils::SetInputTensor(tensor, true);
  164. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail");
  165. TensorUtils::SetInputTensor(tensor, false);
  166. TensorUtils::SetOutputTensor(tensor, true);
  167. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add output desc fail");
  168. count++;
  169. }
  170. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  171. ge::NodePtr out_node = graph->AddNode(op_desc);
  172. GE_CHK_BOOL_EXEC(out_node != nullptr, return FAILED, "Insert Output node fail.");
  173. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  174. for (int32_t i = 0; i < count; ++i) {
  175. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  176. "Add edge[%s->%s] fail.", node->GetName().c_str(), out_node->GetName().c_str());
  177. }
  178. return SUCCESS;
  179. }
  180. static void GetOpsProtoPath(string &opsproto_path) {
  181. GELOGI("Start to get ops proto path schedule.");
  182. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  183. if (path_env != nullptr) {
  184. string path = path_env;
  185. string file_path = RealPath(path.c_str());
  186. if (file_path.empty()) {
  187. GELOGE(FAILED, "File path %s is invalid.", path.c_str());
  188. return;
  189. }
  190. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  191. GELOGI("Get opsproto so path from env : %s", path.c_str());
  192. return;
  193. }
  194. string path_base = PluginManager::GetPath();
  195. GELOGI("path_base is %s", path_base.c_str());
  196. path_base = path_base.substr(0, path_base.rfind('/'));
  197. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  198. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  199. }
  200. class GeGenerator::Impl {
  201. public:
  202. Impl(OmgContext &omg_context) : omg_context_(omg_context), graph_manager_(omg_context) {}
  203. ~Impl() = default;
  204. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  205. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  206. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  207. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  208. Status GenerateInfershapeGraph(const Graph &graph);
  209. OmgContext &omg_context_;
  210. GraphManager graph_manager_;
  211. SaveParam save_param_;
  212. bool is_offline_ = true;
  213. bool is_singleop_unregistered_ = false;
  214. std::string build_mode_;
  215. std::string build_step_;
  216. static std::mutex mutex_;
  217. private:
  218. static std::string Trim(const std::string &str);
  219. bool ParseVersion(const std::string &line, std::string &version);
  220. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  221. bool SetAtcVersionInfo(AttrHolder &obj);
  222. bool SetOppVersionInfo(AttrHolder &obj);
  223. bool SetOmSystemInfo(AttrHolder &obj);
  224. };
  225. Status GeGenerator::Initialize(const map<string, string> &options) { return Initialize(options, domi::GetContext()); }
  226. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  227. impl_ = ge::MakeShared<Impl>(omg_context);
  228. if (impl_ == nullptr) {
  229. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  230. return MEMALLOC_FAILED;
  231. }
  232. string opsproto_path;
  233. GetOpsProtoPath(opsproto_path);
  234. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  235. OpsProtoManager *manager = OpsProtoManager::Instance();
  236. map<string, string> option_tmp;
  237. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  238. (void)manager->Initialize(option_tmp);
  239. Status ret = impl_->graph_manager_.Initialize(options);
  240. if (ret != SUCCESS) {
  241. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed.");
  242. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  243. }
  244. // get ek file
  245. auto iter = options.find(EK_FILE);
  246. if (iter != options.end()) {
  247. impl_->save_param_.ek_file = iter->second;
  248. }
  249. // get cert file
  250. iter = options.find(CERT_FILE);
  251. if (iter != options.end()) {
  252. impl_->save_param_.cert_file = iter->second;
  253. }
  254. // get hw key file
  255. iter = options.find(HW_KEY_FILE);
  256. if (iter != options.end()) {
  257. impl_->save_param_.hw_key_file = iter->second;
  258. }
  259. // get private file
  260. iter = options.find(PRIVATE_KEY_FILE);
  261. if (iter != options.end()) {
  262. impl_->save_param_.pri_key_file = iter->second;
  263. }
  264. // get build mode
  265. iter = options.find(BUILD_MODE);
  266. if (iter != options.end()) {
  267. impl_->build_mode_ = iter->second;
  268. }
  269. // get build step
  270. iter = options.find(BUILD_STEP);
  271. if (iter != options.end()) {
  272. impl_->build_step_ = iter->second;
  273. }
  274. return SUCCESS;
  275. }
  276. Status GeGenerator::Finalize() {
  277. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  278. Status ret = impl_->graph_manager_.Finalize();
  279. if (ret != SUCCESS) {
  280. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed.");
  281. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  282. }
  283. return SUCCESS;
  284. }
  285. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  286. const vector<GeTensor> &inputs) {
  287. GELOGI("Start to generate offline model.");
  288. ModelBufferData model;
  289. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  290. }
  291. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  292. return GenerateModel(graph, "online", inputs, model, false);
  293. }
  294. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  295. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  296. Status ret = impl_->GenerateInfershapeGraph(graph);
  297. if (ret != SUCCESS) {
  298. GELOGE(ret, "Dump infershape json failed");
  299. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  300. GELOGE(FAILED, "graph_manager finalize fail.");
  301. }
  302. return ret;
  303. }
  304. GELOGI("Generate infer shape graph success");
  305. return SUCCESS;
  306. }
  307. std::mutex GeGenerator::Impl::mutex_;
  308. // Remove the space and tab before and after the string
  309. std::string GeGenerator::Impl::Trim(const std::string &str) {
  310. if (str.empty()) {
  311. return str;
  312. }
  313. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  314. if (start == std::string::npos) {
  315. return str;
  316. }
  317. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  318. return str.substr(start, end);
  319. }
  320. // Parsing the command line
  321. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  322. std::string flag = "Version=";
  323. std::string temp = Trim(line);
  324. if (temp.empty()) {
  325. GELOGW("line is empty.");
  326. return false;
  327. }
  328. std::string::size_type pos = temp.find(flag);
  329. if (pos == std::string::npos) {
  330. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  331. return false;
  332. }
  333. if (temp.size() == flag.size()) {
  334. GELOGW("version information is empty. %s", line.c_str());
  335. return false;
  336. }
  337. version = temp.substr(pos + flag.size());
  338. GELOGI("Version=%s", version.c_str());
  339. return true;
  340. }
  341. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  342. // Normalize the path
  343. string resolved_file_path = RealPath(file_path.c_str());
  344. if (resolved_file_path.empty()) {
  345. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  346. return false;
  347. }
  348. std::ifstream fs(resolved_file_path, std::ifstream::in);
  349. if (!fs.is_open()) {
  350. GELOGW("Open %s failed.", file_path.c_str());
  351. return false;
  352. }
  353. std::string line;
  354. if (getline(fs, line)) {
  355. if (!ParseVersion(line, version)) {
  356. GELOGW("Parse version failed. content is [%s].", line.c_str());
  357. fs.close();
  358. return false;
  359. }
  360. } else {
  361. GELOGW("No version information found in the file path:%s", file_path.c_str());
  362. fs.close();
  363. return false;
  364. }
  365. fs.close(); // close the file
  366. return true;
  367. }
  368. // Set package version information in the model
  369. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  370. std::string path_base = ge::GELib::GetPath();
  371. path_base = path_base.substr(0, path_base.rfind('/'));
  372. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  373. std::string version_path = path_base + "version.info";
  374. GELOGI("version_path is %s", version_path.c_str());
  375. std::string version;
  376. if (!GetVersionFromPath(version_path, version)) {
  377. GELOGW("Get atc version information failed!");
  378. return false;
  379. }
  380. // set version info
  381. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  382. GELOGW("Ge model set atc version failed!");
  383. return false;
  384. }
  385. GELOGI("Ge model set atc version information success.");
  386. return true;
  387. }
  388. // Set package version information in the model
  389. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  390. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  391. if (path_env == nullptr) {
  392. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  393. return false;
  394. }
  395. std::string version_path = path_env;
  396. version_path += "/version.info";
  397. GELOGI("version_path is %s", version_path.c_str());
  398. std::string version;
  399. if (!GetVersionFromPath(version_path, version)) {
  400. GELOGW("Get opp version information failed!");
  401. return false;
  402. }
  403. // set version info
  404. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  405. GELOGW("Ge model set opp version failed!");
  406. return false;
  407. }
  408. GELOGI("Ge Model set opp version information success.");
  409. return true;
  410. }
  411. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  412. std::string soc_version;
  413. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  414. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  415. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  416. GELOGW("SetStr of soc_version failed.");
  417. return false;
  418. }
  419. // 0(Caffe) 1(MindSpore) 3(TensorFlow) 5(Onnx)
  420. std::map<string, string> framework_type_to_string = {
  421. {"0", "Caffe"},
  422. {"1", "MindSpore"},
  423. {"3", "TensorFlow"},
  424. {"5", "Onnx"}
  425. };
  426. std::string framework_type;
  427. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  428. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  429. if (!ge::AttrUtils::SetStr(obj, "framework_type", framework_type_to_string[framework_type.c_str()])) {
  430. GELOGW("SetStr of framework_type failed.");
  431. return false;
  432. }
  433. return true;
  434. }
  435. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  436. ModelBufferData &model, bool is_offline) {
  437. rtContext_t ctx = nullptr;
  438. auto rt = rtCtxGetCurrent(&ctx);
  439. if (rt != RT_ERROR_NONE) {
  440. GELOGW("Current ctx is null.");
  441. ctx = nullptr;
  442. }
  443. GeRootModelPtr ge_root_model = nullptr;
  444. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  445. impl_->is_offline_ = is_offline;
  446. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  447. if (ret != SUCCESS) {
  448. GELOGE(ret, "Build model failed.");
  449. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  450. GELOGE(FAILED, "graph_manager finalize fail.");
  451. }
  452. return ret;
  453. }
  454. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  455. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  456. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  457. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  458. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  459. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  460. GELOGI("Build mode:%s with step:%s no need SaveModel.", impl_->build_mode_.c_str(), impl_->build_step_.c_str());
  461. return SUCCESS;
  462. }
  463. GE_CHECK_NOTNULL(ge_root_model);
  464. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  465. ModelHelper model_helper;
  466. string model_name = "";
  467. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name);
  468. if (name_ret != SUCCESS) {
  469. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  470. GELOGE(FAILED, "Get model_name failed. Param --output is invalid");
  471. return PARAM_INVALID;
  472. }
  473. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  474. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  475. GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null");
  476. ge_model->SetName(model_name);
  477. ret = impl_->SaveModel(file_name_prefix, ge_model, model);
  478. if (ret != SUCCESS) {
  479. GELOGE(ret, "Save model failed");
  480. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  481. GELOGE(FAILED, "graph_manager finalize fail.");
  482. }
  483. return ret;
  484. }
  485. if (ctx != nullptr) {
  486. (void)rtCtxSetCurrent(ctx);
  487. }
  488. GELOGI("GenerateOfflineModel success.");
  489. return SUCCESS;
  490. }
  491. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  492. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  493. bool is_offline) {
  494. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  495. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  496. GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
  497. return PARAM_INVALID;
  498. }
  499. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  500. GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
  501. return PARAM_INVALID;
  502. }
  503. OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
  504. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  505. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  506. impl_->is_singleop_unregistered_ = true;
  507. }
  508. // 0. Save original attributes.
  509. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  510. GE_CHECK_NOTNULL(op_desc_tmp);
  511. // 1. check engine type when compile online
  512. if (model_file_name == kFileNameSuffix) {
  513. Status ret = CheckEngineTypeSupport(op_desc, engine_type);
  514. if (ret != SUCCESS) {
  515. GELOGE(ret, "check engine type failed.");
  516. return ret;
  517. }
  518. }
  519. // 2. Create ComputeGraph.
  520. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  521. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name);
  522. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  523. // 3. Add Node to ComputeGraph.
  524. NodePtr op_node = compute_graph->AddNode(op_desc);
  525. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  526. // 4. Create InputData node.
  527. int32_t arg_index = 0;
  528. if (inputs.empty()) {
  529. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  530. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  531. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
  532. arg_index++;
  533. }
  534. } else {
  535. for (const auto &in_desc : inputs) {
  536. GeTensorDesc input_desc = in_desc.GetTensorDesc();
  537. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true));
  538. arg_index++;
  539. }
  540. }
  541. // 5. Create Output node.
  542. if (!outputs.empty()) {
  543. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  544. }
  545. // dump ComputeGraph.
  546. compute_graph->Dump();
  547. Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  548. GELOGI("ATC parser success in single op build.");
  549. GeRootModelPtr ge_root_model = nullptr;
  550. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  551. impl_->is_offline_ = is_offline;
  552. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model));
  553. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  554. GE_CHECK_NOTNULL(ge_root_model);
  555. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  556. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  557. if (name_to_ge_model.empty()) {
  558. GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty.");
  559. return PARAM_INVALID;
  560. }
  561. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  562. GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  563. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  564. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  565. return SUCCESS;
  566. }
  567. /**
  568. * @ingroup ge
  569. * @brief Compiling a single operator into an offline model
  570. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  571. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  572. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  573. * @param [in] const string &model_file_name: Offline model filename.
  574. * @return SUCCESS handle successfully / others handle failed
  575. */
  576. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  577. const vector<GeTensor> &outputs, const string &model_file_name) {
  578. GELOGI("Start to build single op offline model.");
  579. ModelBufferData model_buff;
  580. OpEngineType engine_type = ENGINE_SYS;
  581. return BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true);
  582. }
  583. /**
  584. * @ingroup ge
  585. * @brief Compiling a single operator into online buffer
  586. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  587. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  588. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  589. * @param [in] engine_type: specific engine.
  590. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  591. * @return SUCCESS handle successfully / others handle failed
  592. */
  593. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  594. const vector<GeTensor> &outputs, OpEngineType engine_type,
  595. ModelBufferData &model_buff) {
  596. GELOGI("Start to build single op online");
  597. return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  598. }
  599. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  600. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  601. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  602. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  603. (void)graph_manager_.Finalize();
  604. return FAILED);
  605. return SUCCESS;
  606. }
  607. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  608. // set atc version
  609. if (!SetAtcVersionInfo(*(model.get()))) {
  610. GELOGW("SetPackageVersionInfo of atc failed!");
  611. }
  612. // set opp version
  613. if (!SetOppVersionInfo(*(model.get()))) {
  614. GELOGW("SetPackageVersionInfo of ops failed!");
  615. }
  616. if (!SetOmSystemInfo(*(model_root.get()))) {
  617. GELOGW("SetOmsystemInfo failed!");
  618. }
  619. ModelHelper model_helper;
  620. model_helper.SetSaveMode(is_offline_);
  621. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  622. if (ret != SUCCESS) {
  623. GELOGE(ret, "Save to om model failed");
  624. return ret;
  625. }
  626. return SUCCESS;
  627. }
  628. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  629. GeRootModelPtr &ge_root_model) {
  630. static std::atomic<GraphId> atomic_graph_id(0);
  631. auto graph_id = atomic_graph_id.fetch_add(1);
  632. const std::map<std::string, std::string> options;
  633. Status ret = graph_manager_.AddGraph(graph_id, graph, options);
  634. if (ret != SUCCESS) {
  635. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id);
  636. (void)graph_manager_.Finalize();
  637. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  638. }
  639. GELOGI("Model inputs size is %zu", inputs.size());
  640. graph_manager_.SetOptionsRunGraphFlag(false);
  641. static std::atomic<uint64_t> atomic_session_id(0);
  642. auto session_id = atomic_session_id.fetch_add(1);
  643. if (is_singleop_unregistered_) {
  644. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  645. } else {
  646. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  647. }
  648. if (ret != SUCCESS) {
  649. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id);
  650. VarManagerPool::Instance().RemoveVarManager(session_id);
  651. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  652. }
  653. VarManagerPool::Instance().RemoveVarManager(session_id);
  654. return SUCCESS;
  655. }
  656. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  657. static std::atomic<GraphId> atomic_graph_id(0);
  658. auto graph_id = atomic_graph_id.fetch_add(1);
  659. const std::map<std::string, std::string> options;
  660. Status ret = graph_manager_.AddGraph(graph_id, graph, options);
  661. if (ret != SUCCESS) {
  662. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id);
  663. (void)graph_manager_.Finalize();
  664. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  665. }
  666. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  667. if (ret != SUCCESS) {
  668. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed");
  669. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  670. }
  671. return SUCCESS;
  672. }
  673. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：