You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ge_generator.cc 35 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/util/rt_context_util.h"
  31. #include "graph/opsproto_manager.h"
  32. #include "graph/utils/graph_utils.h"
  33. #include "graph/utils/type_utils.h"
  34. #include "init/gelib.h"
  35. #include "model/ge_model.h"
  36. using std::map;
  37. using std::string;
  38. using std::vector;
  39. namespace {
  40. const char *const kAttrOpType = "op_type";
  41. const char *const kEngineNameDefault = "default";
  42. const char *const kVectorEngine = "VectorEngine";
  43. const char *const kAIcoreEngine = "AIcoreEngine";
  44. const char *const kFileNameSuffix = "online";
  45. const size_t kDynamicDimSize = 1;
  46. const int64_t kDynamicDimValue = -2;
  47. std::map<ge::OpEngineType, std::string> engine_type_map{
  48. {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}};
  49. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  50. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  51. if (tensor_desc->MutableShape().IsUnknownShape()) {
  52. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  53. return true;
  54. }
  55. }
  56. return false;
  57. }
  58. } // namespace
  59. namespace ge {
  60. static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) {
  61. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  62. if (engine_type == ENGINE_SYS) {
  63. GELOGI("CheckEngineType: use default engine.");
  64. return SUCCESS;
  65. }
  66. // get op engine name
  67. string op_engine_name;
  68. auto iter = engine_type_map.find(engine_type);
  69. if (iter != engine_type_map.end()) {
  70. op_engine_name = iter->second;
  71. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  72. } else {
  73. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  74. {op_desc->GetName(), op_desc->GetType(), "engine type",
  75. "it only support kEngineNameDefault/kAIcoreEngine/kVectorEngine"});
  76. GELOGE(FAILED, "CheckEngineType: engine type: %d not support", static_cast<int>(engine_type));
  77. return FAILED;
  78. }
  79. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  80. op_desc->SetOpEngineName(op_engine_name);
  81. op_desc->SetOpKernelLibName(op_engine_name);
  82. return SUCCESS;
  83. }
  84. // set op engine name and opkernelLib. when engine support
  85. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  86. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  87. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed.");
  88. return FAILED;
  89. }
  90. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  91. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  92. if (op_infos.empty()) {
  93. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  94. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  95. GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str());
  96. return FAILED;
  97. }
  98. string kernel_name;
  99. for (const auto &it : op_infos) {
  100. if (it.engine == op_engine_name) {
  101. kernel_name = it.opKernelLib;
  102. break;
  103. }
  104. }
  105. if (kernel_name.empty()) {
  106. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  107. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  108. GELOGE(FAILED, "CheckEngineType:Can not find ops kernel, engine name: %s.", op_engine_name.c_str());
  109. return FAILED;
  110. }
  111. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  112. auto kernel_info_store = kernel_map.find(kernel_name);
  113. if (kernel_info_store != kernel_map.end()) {
  114. std::string unsupported_reason;
  115. if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) {
  116. op_desc->SetOpEngineName(op_engine_name);
  117. op_desc->SetOpKernelLibName(kernel_name);
  118. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  119. op_engine_name.c_str(), op_desc->GetName().c_str());
  120. return SUCCESS;
  121. } else {
  122. ErrorManager::GetInstance().ATCReportErrMessage(
  123. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  124. GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  125. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  126. return FAILED;
  127. }
  128. } else {
  129. ErrorManager::GetInstance().ATCReportErrMessage(
  130. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  131. GELOGE(FAILED,
  132. "CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s,"
  133. "op type is %s, op name is %s",
  134. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  135. }
  136. return FAILED;
  137. }
  138. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, GeTensorDesc &tensor, int32_t index,
  139. bool attr) {
  140. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  141. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  142. auto format = tensor.GetFormat();
  143. auto data_type = tensor.GetDataType();
  144. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  145. return SUCCESS;
  146. }
  147. string op_type;
  148. bool is_const = false;
  149. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  150. if (is_const) {
  151. GELOGD("Get input[%d] is const", index);
  152. op_type = CONSTANTOP;
  153. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  154. op_type = DATA;
  155. }
  156. string op_name = node->GetName() + "_in_" + std::to_string(index);
  157. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  158. if (data_op == nullptr) {
  159. return FAILED;
  160. }
  161. if (is_const) {
  162. ConstGeTensorPtr tensor_value;
  163. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  164. GELOGE(FAILED, "Get value failed, node name:%s.", tensor.GetName().c_str());
  165. return FAILED;
  166. }
  167. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  168. GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS fail.");
  169. return FAILED;
  170. }
  171. }
  172. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  173. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail.");
  174. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add output desc fail.");
  175. if (attr) {
  176. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, index), return FAILED, "Set index fail.");
  177. }
  178. ge::NodePtr arg_node = graph->AddNode(data_op);
  179. GE_CHK_BOOL_EXEC(arg_node != nullptr, return FAILED, "Insert Data node fail.");
  180. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  181. "Add edge[%s->%s] fail.", data_op->GetName().c_str(), node->GetName().c_str());
  182. return SUCCESS;
  183. }
  184. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  185. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  186. if (op_desc == nullptr) {
  187. return FAILED;
  188. }
  189. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  190. int32_t count = 0;
  191. for (const auto &out_desc : outputs) {
  192. GeTensorDesc tensor = out_desc.GetTensorDesc();
  193. TensorUtils::SetInputTensor(tensor, true);
  194. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add input desc fail");
  195. TensorUtils::SetInputTensor(tensor, false);
  196. TensorUtils::SetOutputTensor(tensor, true);
  197. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS, return FAILED, "Add output desc fail");
  198. count++;
  199. }
  200. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  201. ge::NodePtr out_node = graph->AddNode(op_desc);
  202. GE_CHK_BOOL_EXEC(out_node != nullptr, return FAILED, "Insert Output node fail.");
  203. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  204. for (int32_t i = 0; i < count; ++i) {
  205. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  206. "Add edge[%s->%s] fail.", node->GetName().c_str(), out_node->GetName().c_str());
  207. }
  208. return SUCCESS;
  209. }
  210. static void GetOpsProtoPath(string &opsproto_path) {
  211. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  212. if (path_env != nullptr) {
  213. string path = path_env;
  214. string file_path = RealPath(path.c_str());
  215. if (file_path.empty()) {
  216. GELOGE(FAILED, "File path %s is invalid.", path.c_str());
  217. return;
  218. }
  219. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  220. GELOGI("Get opsproto so path from env : %s", path.c_str());
  221. return;
  222. }
  223. string path_base = PluginManager::GetPath();
  224. GELOGI("path_base is %s", path_base.c_str());
  225. path_base = path_base.substr(0, path_base.rfind('/'));
  226. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  227. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  228. }
  229. static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) {
  230. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  231. change_shape_flag = false;
  232. for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) {
  233. auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i));
  234. GE_CHECK_NOTNULL(input_desc);
  235. // pass scalar input desc
  236. auto dims = input_desc->GetShape().GetDims();
  237. if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) {
  238. change_shape_flag = true;
  239. }
  240. }
  241. for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) {
  242. auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i));
  243. GE_CHECK_NOTNULL(output_desc);
  244. // pass scalar output desc
  245. auto dims = output_desc->GetShape().GetDims();
  246. if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) {
  247. change_shape_flag = true;
  248. }
  249. }
  250. return SUCCESS;
  251. }
  252. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  253. for (auto input : inputs) {
  254. auto input_desc = input.GetTensorDesc();
  255. GeShape shape_ori = input_desc.GetShape();
  256. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  257. GeShape dynamic_shape(dynamic_shape_dims);
  258. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  259. ge::GeTensor inputTensor;
  260. ge::GeTensorDesc desc(input_desc);
  261. bool is_const = false;
  262. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  263. if (!is_const && shape_ori.GetDims().size() > 0) {
  264. int64_t storage_format = FORMAT_NCHW;
  265. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  266. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  267. GELOGE(FAILED, "Set attr ATTR_NAME_STORAGE_SHAPE fail.");
  268. return FAILED;
  269. }
  270. desc.SetShape(dynamic_shape);
  271. desc.SetShapeRange(dynamic_shape_range);
  272. }
  273. inputTensor.SetTensorDesc(desc);
  274. inputs_dynamic.push_back(inputTensor);
  275. }
  276. return SUCCESS;
  277. }
  278. class GeGenerator::Impl {
  279. public:
  280. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  281. ~Impl() = default;
  282. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  283. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  284. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  285. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  286. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  287. Status GenerateInfershapeGraph(const Graph &graph);
  288. OmgContext &omg_context_;
  289. GraphManager graph_manager_;
  290. SaveParam save_param_;
  291. bool is_offline_ = true;
  292. bool is_singleop_unregistered_ = false;
  293. std::string build_mode_;
  294. std::string build_step_;
  295. static std::mutex mutex_;
  296. private:
  297. static std::string Trim(const std::string &str);
  298. bool ParseVersion(const std::string &line, std::string &version);
  299. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  300. bool SetAtcVersionInfo(AttrHolder &obj);
  301. bool SetOppVersionInfo(AttrHolder &obj);
  302. };
  303. Status GeGenerator::Initialize(const map<string, string> &options) {
  304. return Initialize(options, domi::GetContext());
  305. }
  306. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  307. impl_ = ge::MakeShared<Impl>(omg_context);
  308. if (impl_ == nullptr) {
  309. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  310. return MEMALLOC_FAILED;
  311. }
  312. string opsproto_path;
  313. GetOpsProtoPath(opsproto_path);
  314. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  315. OpsProtoManager *manager = OpsProtoManager::Instance();
  316. map<string, string> option_tmp;
  317. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  318. (void)manager->Initialize(option_tmp);
  319. Status ret = impl_->graph_manager_.Initialize(options);
  320. if (ret != SUCCESS) {
  321. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed.");
  322. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  323. }
  324. // get ek file
  325. auto iter = options.find(EK_FILE);
  326. if (iter != options.end()) {
  327. impl_->save_param_.ek_file = iter->second;
  328. }
  329. // get cert file
  330. iter = options.find(CERT_FILE);
  331. if (iter != options.end()) {
  332. impl_->save_param_.cert_file = iter->second;
  333. }
  334. // get hw key file
  335. iter = options.find(HW_KEY_FILE);
  336. if (iter != options.end()) {
  337. impl_->save_param_.hw_key_file = iter->second;
  338. }
  339. // get private file
  340. iter = options.find(PRIVATE_KEY_FILE);
  341. if (iter != options.end()) {
  342. impl_->save_param_.pri_key_file = iter->second;
  343. }
  344. // get build mode
  345. iter = options.find(BUILD_MODE);
  346. if (iter != options.end()) {
  347. impl_->build_mode_ = iter->second;
  348. }
  349. // get build step
  350. iter = options.find(BUILD_STEP);
  351. if (iter != options.end()) {
  352. impl_->build_step_ = iter->second;
  353. }
  354. return SUCCESS;
  355. }
  356. Status GeGenerator::Finalize() {
  357. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  358. Status ret = impl_->graph_manager_.Finalize();
  359. if (ret != SUCCESS) {
  360. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed.");
  361. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  362. }
  363. return SUCCESS;
  364. }
  365. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  366. const vector<GeTensor> &inputs) {
  367. GELOGI("Start to generate offline model.");
  368. ModelBufferData model;
  369. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  370. }
  371. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  372. return GenerateModel(graph, "online", inputs, model, false);
  373. }
  374. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  375. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  376. Status ret = impl_->GenerateInfershapeGraph(graph);
  377. if (ret != SUCCESS) {
  378. GELOGE(ret, "Dump infershape json failed");
  379. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  380. GELOGE(FAILED, "graph_manager finalize fail.");
  381. }
  382. return ret;
  383. }
  384. GELOGI("Generate infer shape graph success");
  385. return SUCCESS;
  386. }
  387. std::mutex GeGenerator::Impl::mutex_;
  388. // Remove the space and tab before and after the string
  389. std::string GeGenerator::Impl::Trim(const std::string &str) {
  390. if (str.empty()) {
  391. return str;
  392. }
  393. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  394. if (start == std::string::npos) {
  395. return str;
  396. }
  397. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  398. return str.substr(start, end);
  399. }
  400. // Parsing the command line
  401. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  402. std::string flag = "Version=";
  403. std::string temp = Trim(line);
  404. if (temp.empty()) {
  405. GELOGW("line is empty.");
  406. return false;
  407. }
  408. std::string::size_type pos = temp.find(flag);
  409. if (pos == std::string::npos) {
  410. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  411. return false;
  412. }
  413. if (temp.size() == flag.size()) {
  414. GELOGW("version information is empty. %s", line.c_str());
  415. return false;
  416. }
  417. version = temp.substr(pos + flag.size());
  418. return true;
  419. }
  420. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  421. // Normalize the path
  422. string resolved_file_path = RealPath(file_path.c_str());
  423. if (resolved_file_path.empty()) {
  424. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  425. return false;
  426. }
  427. std::ifstream fs(resolved_file_path, std::ifstream::in);
  428. if (!fs.is_open()) {
  429. GELOGW("Open %s failed.", file_path.c_str());
  430. return false;
  431. }
  432. std::string line;
  433. if (getline(fs, line)) {
  434. if (!ParseVersion(line, version)) {
  435. GELOGW("Parse version failed. content is [%s].", line.c_str());
  436. fs.close();
  437. return false;
  438. }
  439. } else {
  440. GELOGW("No version information found in the file path:%s", file_path.c_str());
  441. fs.close();
  442. return false;
  443. }
  444. fs.close(); // close the file
  445. return true;
  446. }
  447. // Set package version information in the model
  448. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  449. std::string path_base = ge::GELib::GetPath();
  450. path_base = path_base.substr(0, path_base.rfind('/'));
  451. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  452. std::string version_path = path_base + "version.info";
  453. std::string version;
  454. if (!GetVersionFromPath(version_path, version)) {
  455. GELOGW("Get atc version information failed!");
  456. return false;
  457. }
  458. // set version info
  459. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  460. GELOGW("Ge model set atc version failed!");
  461. return false;
  462. }
  463. return true;
  464. }
  465. // Set package version information in the model
  466. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  467. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  468. if (path_env == nullptr) {
  469. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  470. return false;
  471. }
  472. std::string version_path = path_env;
  473. version_path += "/version.info";
  474. std::string version;
  475. if (!GetVersionFromPath(version_path, version)) {
  476. GELOGW("Get opp version information failed!");
  477. return false;
  478. }
  479. // set version info
  480. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  481. GELOGW("Ge model set opp version failed!");
  482. return false;
  483. }
  484. return true;
  485. }
  486. static Status SetModelNameForDump(GeRootModelPtr ge_root_model) {
  487. ModelHelper model_helper;
  488. string model_name = "";
  489. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  490. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  491. model_name);
  492. if (name_ret != SUCCESS) {
  493. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  494. GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
  495. return PARAM_INVALID;
  496. }
  497. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  498. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  499. GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null");
  500. ge_model->SetName(model_name);
  501. return SUCCESS;
  502. }
  503. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  504. ModelBufferData &model, bool is_offline) {
  505. rtContext_t ctx = nullptr;
  506. auto rt = rtCtxGetCurrent(&ctx);
  507. if (rt != RT_ERROR_NONE) {
  508. GELOGD("Current ctx is null.");
  509. ctx = nullptr;
  510. }
  511. GeRootModelPtr ge_root_model = nullptr;
  512. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  513. impl_->is_offline_ = is_offline;
  514. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  515. if (ret != SUCCESS) {
  516. GELOGE(ret, "Build model failed.");
  517. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  518. GELOGE(FAILED, "graph_manager finalize fail.");
  519. }
  520. return ret;
  521. }
  522. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  523. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  524. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  525. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  526. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  527. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  528. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  529. impl_->build_mode_.c_str(),
  530. impl_->build_step_.c_str());
  531. return SUCCESS;
  532. }
  533. GE_CHECK_NOTNULL(ge_root_model);
  534. ret = SetModelNameForDump(ge_root_model);
  535. if (ret != SUCCESS) {
  536. return ret;
  537. }
  538. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  539. if (ret != SUCCESS) {
  540. GELOGE(ret, "Save model failed");
  541. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  542. GELOGE(FAILED, "graph_manager finalize fail.");
  543. }
  544. return ret;
  545. }
  546. if (ctx != nullptr) {
  547. (void)rtCtxSetCurrent(ctx);
  548. }
  549. return SUCCESS;
  550. }
  551. namespace {
  552. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  553. bool is_need = true;
  554. // format and dtype is all reserved, stand for Optional input. When singleop scene
  555. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  556. is_need = false;
  557. }
  558. return is_need;
  559. }
  560. }
  561. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  562. const vector<GeTensor> &outputs) {
  563. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  564. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  565. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  566. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  567. "tensor size is " + FmtToStr(inputs.size())});
  568. GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size: %zu", inputs.size(), op_desc->GetAllInputsSize());
  569. return PARAM_INVALID;
  570. }
  571. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  572. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  573. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  574. "tensor size is " + FmtToStr(outputs.size())});
  575. GELOGE(PARAM_INVALID, "Tensor size: %zu, Outputs size: %zu", outputs.size(), op_desc->GetOutputsSize());
  576. return PARAM_INVALID;
  577. }
  578. return SUCCESS;
  579. }
  580. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  581. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  582. bool is_offline) {
  583. if (!is_offline) {
  584. (void)AttrUtils::SetBool(op_desc, ATTR_DYNAMIC_SHAPE_SINGLE_AICPU, true);
  585. }
  586. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  587. GELOGE(PARAM_INVALID, "input param is invalid when build single op!");
  588. return PARAM_INVALID;
  589. }
  590. OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_;
  591. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  592. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  593. impl_->is_singleop_unregistered_ = true;
  594. }
  595. // 0. Save original attributes.
  596. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  597. GE_CHECK_NOTNULL(op_desc_tmp);
  598. // 1. check engine type when compile online
  599. if (model_file_name == kFileNameSuffix) {
  600. Status ret = CheckEngineTypeSupport(op_desc, engine_type);
  601. if (ret != SUCCESS) {
  602. GELOGE(ret, "check engine type failed.");
  603. return ret;
  604. }
  605. }
  606. // 2. Create ComputeGraph.
  607. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  608. Graph graph;
  609. if (BuildSingleOpGraph(op_desc, inputs, outputs, name, graph) != ge::SUCCESS) {
  610. GELOGE(GRAPH_FAILED, "make graph fail.");
  611. return GRAPH_FAILED;
  612. }
  613. GELOGI("ATC parser success in single op build.");
  614. GeRootModelPtr ge_root_model = nullptr;
  615. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  616. impl_->is_offline_ = is_offline;
  617. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model));
  618. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  619. GE_CHECK_NOTNULL(ge_root_model);
  620. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  621. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  622. if (name_to_ge_model.empty()) {
  623. GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty.");
  624. return PARAM_INVALID;
  625. }
  626. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  627. GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  628. bool dynamic_flag = false;
  629. if (CheckShapeReset(op_desc, dynamic_flag) == SUCCESS && dynamic_flag) {
  630. vector<GeTensor> inputs_dynamic;
  631. vector<GeTensor> outputs_dynamic;
  632. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  633. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  634. GE_CHK_STATUS_RET_NOLOG(
  635. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  636. } else {
  637. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  638. }
  639. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  640. return SUCCESS;
  641. }
  642. /**
  643. * @ingroup ge
  644. * @brief Compiling a single operator into an offline model
  645. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  646. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  647. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  648. * @param [in] const string &model_file_name: Offline model filename.
  649. * @return SUCCESS handle successfully / others handle failed
  650. */
  651. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  652. const vector<GeTensor> &outputs, const string &model_file_name) {
  653. GELOGI("Start to build single op offline model.");
  654. ModelBufferData model_buff;
  655. OpEngineType engine_type = ENGINE_SYS;
  656. return BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true);
  657. }
  658. /**
  659. * @ingroup ge
  660. * @brief Compiling a single operator into online buffer
  661. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  662. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  663. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  664. * @param [in] engine_type: specific engine.
  665. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  666. * @return SUCCESS handle successfully / others handle failed
  667. */
  668. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  669. const vector<GeTensor> &outputs, OpEngineType engine_type,
  670. ModelBufferData &model_buff) {
  671. GELOGI("Start to build single op online");
  672. return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  673. }
  674. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  675. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  676. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  677. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  678. // 1. Add Node to ComputeGraph.
  679. NodePtr op_node = compute_graph->AddNode(op_desc);
  680. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  681. // 2. Create InputData node.
  682. int32_t arg_index = 0;
  683. if (inputs.empty()) {
  684. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  685. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  686. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  687. continue;
  688. }
  689. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false));
  690. arg_index++;
  691. }
  692. } else {
  693. for (const auto &in_desc : inputs) {
  694. GeTensorDesc input_desc = in_desc.GetTensorDesc();
  695. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true));
  696. arg_index++;
  697. }
  698. }
  699. // 3. Create Output node.
  700. if (!outputs.empty()) {
  701. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  702. }
  703. // dump ComputeGraph node.
  704. compute_graph->Dump();
  705. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  706. return SUCCESS;
  707. }
  708. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  709. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  710. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  711. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  712. (void)graph_manager_.Finalize();
  713. return FAILED);
  714. return SUCCESS;
  715. }
  716. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  717. // set atc version
  718. if (!SetAtcVersionInfo(*(model.get()))) {
  719. GELOGW("SetPackageVersionInfo of atc failed!");
  720. }
  721. // set opp version
  722. if (!SetOppVersionInfo(*(model.get()))) {
  723. GELOGW("SetPackageVersionInfo of ops failed!");
  724. }
  725. ModelHelper model_helper;
  726. model_helper.SetSaveMode(is_offline_);
  727. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  728. if (ret != SUCCESS) {
  729. GELOGE(ret, "Save to om model failed");
  730. return ret;
  731. }
  732. return SUCCESS;
  733. }
  734. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  735. ModelBufferData &model_buff) {
  736. bool is_unknown_shape = false;
  737. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  738. if (ret != SUCCESS) {
  739. GELOGE(FAILED, "Check root model is unkonwn shape failed");
  740. return FAILED;
  741. }
  742. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  743. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
  744. "ge root model has no sub model")
  745. GeModelPtr model_root = nullptr;
  746. if (is_unknown_shape) {
  747. model_root = make_shared<GeModel>();
  748. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  749. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  750. model_root->SetName(ge_root_model->GetRootGraph()->GetName());
  751. } else {
  752. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  753. }
  754. // set atc version
  755. if (!SetAtcVersionInfo(*(model_root.get()))) {
  756. GELOGW("SetPackageVersionInfo of atc failed!");
  757. }
  758. // set opp version
  759. if (!SetOppVersionInfo(*(model_root.get()))) {
  760. GELOGW("SetPackageVersionInfo of ops failed!");
  761. }
  762. ModelHelper model_helper;
  763. model_helper.SetSaveMode(is_offline_);
  764. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  765. if (ret != SUCCESS) {
  766. GELOGE(ret, "Save to om model failed");
  767. return ret;
  768. }
  769. return SUCCESS;
  770. }
  771. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  772. GeRootModelPtr &ge_root_model) {
  773. static std::atomic<GraphId> atomic_graph_id(0);
  774. auto graph_id = atomic_graph_id.fetch_add(1);
  775. const std::map<std::string, std::string> options;
  776. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  777. if (ret != SUCCESS) {
  778. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id);
  779. (void)graph_manager_.Finalize();
  780. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  781. }
  782. graph_manager_.SetOptionsRunGraphFlag(false);
  783. static std::atomic<uint64_t> atomic_session_id(0);
  784. auto session_id = atomic_session_id.fetch_add(1);
  785. if (is_singleop_unregistered_) {
  786. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  787. } else {
  788. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  789. }
  790. if (ret != SUCCESS) {
  791. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id);
  792. VarManagerPool::Instance().RemoveVarManager(session_id);
  793. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  794. }
  795. VarManagerPool::Instance().RemoveVarManager(session_id);
  796. return SUCCESS;
  797. }
  798. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  799. static std::atomic<GraphId> atomic_graph_id(0);
  800. auto graph_id = atomic_graph_id.fetch_add(1);
  801. const std::map<std::string, std::string> options;
  802. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  803. if (ret != SUCCESS) {
  804. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id);
  805. (void)graph_manager_.Finalize();
  806. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  807. }
  808. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  809. if (ret != SUCCESS) {
  810. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed");
  811. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  812. }
  813. return SUCCESS;
  814. }
  815. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：