You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ge_generator.cc 50 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "framework/generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "framework/common/helper/model_helper.h"
  21. #include "framework/common/helper/om_file_helper.h"
  22. #include "framework/common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "external/ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/graph_var_manager.h"
  31. #include "graph/manager/util/rt_context_util.h"
  32. #include "graph/operator_factory_impl.h"
  33. #include "graph/opsproto_manager.h"
  34. #include "graph/utils/graph_utils.h"
  35. #include "graph/utils/type_utils.h"
  36. #include "init/gelib.h"
  37. #include "common/model/ge_model.h"
  38. #include "analyzer/analyzer.h"
  39. using std::map;
  40. using std::string;
  41. using std::vector;
  42. namespace {
  43. const char *const kAttrOpType = "op_type";
  44. const char *const kEngineNameDefault = "default";
  45. const char *const kVectorEngine = "VectorEngine";
  46. const char *const kAIcoreEngine = "AIcoreEngine";
  47. const char *const kFileNameSuffix = "online";
  48. const char *const kAicpuAllshape = "_AllShape";
  49. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  50. const int64_t kDynamicDimValue = -2;
  51. const int kDefaultDeviceId = 0;
  52. const int kDefaultJobId = 0;
  53. const int32_t kFuzzBuildPattern = 1;
  54. std::map<ge::OpEngineType, std::string> engine_type_map{
  55. {ge::ENGINE_SYS, kEngineNameDefault},
  56. {ge::ENGINE_AICORE, kAIcoreEngine},
  57. {ge::ENGINE_VECTOR, kVectorEngine}};
  58. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  59. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  60. if (tensor_desc->MutableShape().IsUnknownShape()) {
  61. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  62. return true;
  63. }
  64. }
  65. return false;
  66. }
  67. // if optional in/out, format is format_reserved and dtype is dt_undefined
  68. bool IsOptional(const ge::GeTensorDesc &tensor_desc) {
  69. return tensor_desc.GetFormat() == ge::FORMAT_RESERVED && tensor_desc.GetDataType() == ge::DT_UNDEFINED;
  70. }
  71. } // namespace
  72. namespace ge {
  73. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  74. const OpDescPtr &op_desc = node->GetOpDesc();
  75. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  76. if (engine_type == ENGINE_SYS) {
  77. GELOGI("CheckEngineType: use default engine.");
  78. return SUCCESS;
  79. }
  80. // get op engine name
  81. string op_engine_name;
  82. auto iter = engine_type_map.find(engine_type);
  83. if (iter != engine_type_map.end()) {
  84. op_engine_name = iter->second;
  85. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  86. } else {
  87. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  88. {op_desc->GetName(), op_desc->GetType(), "engine type",
  89. "it only support default/AIcoreEngine/VectorEngine"});
  90. GELOGE(FAILED, "[Check][Param] value:%d not support, "
  91. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  92. return FAILED;
  93. }
  94. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  95. op_desc->SetOpEngineName(op_engine_name);
  96. op_desc->SetOpKernelLibName(op_engine_name);
  97. return SUCCESS;
  98. }
  99. // set op engine name and opkernelLib. when engine support
  100. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  101. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  102. REPORT_INNER_ERROR("E19999", "get gelib failed, as get instance failed or initflag failed.");
  103. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] CheckEngineType failed, as get gelib failed.");
  104. return FAILED;
  105. }
  106. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  107. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  108. if (op_infos.empty()) {
  109. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  110. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  111. GELOGE(FAILED, "[Get][OpInfo] by op type %s failed.", op_desc->GetType().c_str());
  112. return FAILED;
  113. }
  114. string kernel_name;
  115. for (const auto &it : op_infos) {
  116. if (it.engine == op_engine_name) {
  117. kernel_name = it.opKernelLib;
  118. break;
  119. }
  120. }
  121. if (kernel_name.empty()) {
  122. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  123. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  124. GELOGE(FAILED, "[Check][Param] Can not find ops kernel, engine name:%s. op:%s(%s)",
  125. op_engine_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str());
  126. return FAILED;
  127. }
  128. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  129. auto kernel_info_store = kernel_map.find(kernel_name);
  130. if (kernel_info_store != kernel_map.end()) {
  131. std::string unsupported_reason;
  132. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  133. op_desc->SetOpEngineName(op_engine_name);
  134. op_desc->SetOpKernelLibName(kernel_name);
  135. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  136. op_engine_name.c_str(), op_desc->GetName().c_str());
  137. return SUCCESS;
  138. } else {
  139. ErrorManager::GetInstance().ATCReportErrMessage(
  140. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  141. GELOGE(FAILED, "[Call][CheckSupported] failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  142. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  143. return FAILED;
  144. }
  145. } else {
  146. ErrorManager::GetInstance().ATCReportErrMessage(
  147. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  148. GELOGE(FAILED, "[Check][Param] Can not find any supported ops kernel info store by kernel_name %s,"
  149. "op type is %s, op name is %s",
  150. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  151. }
  152. return FAILED;
  153. }
  154. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  155. bool attr, int32_t &data_index) {
  156. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  157. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  158. auto format = tensor.GetFormat();
  159. auto data_type = tensor.GetDataType();
  160. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  161. return SUCCESS;
  162. }
  163. string op_type;
  164. bool is_const = false;
  165. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  166. if (is_const) {
  167. GELOGD("Get input[%d] is const", index);
  168. op_type = CONSTANTOP;
  169. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  170. op_type = DATA;
  171. }
  172. string op_name = node->GetName() + "_in_" + std::to_string(index);
  173. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  174. if (data_op == nullptr) {
  175. REPORT_CALL_ERROR("E19999", "create OpDesc failed, name:%s", op_name.c_str());
  176. GELOGE(FAILED, "[Create][OpDesc] failed, name:%s", op_name.c_str());
  177. return FAILED;
  178. }
  179. if (is_const) {
  180. ConstGeTensorPtr tensor_value;
  181. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  182. REPORT_CALL_ERROR("E19999", "get attr %s failed, tensor:%s.",
  183. ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  184. GELOGE(FAILED, "[Get][Attr] %s failed, tensor:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  185. return FAILED;
  186. }
  187. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  188. REPORT_CALL_ERROR("E19999", "set attr %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  189. GELOGE(FAILED, "[Set][Attr] %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  190. return FAILED;
  191. }
  192. }
  193. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  194. (void)AttrUtils::SetBool(data_op, ATTR_NAME_IS_ORIGINAL_INPUT, true);
  195. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS,
  196. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", data_op->GetName().c_str());
  197. return FAILED, "[Add][InputDesc] fail for node:%s", data_op->GetName().c_str());
  198. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  199. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", data_op->GetName().c_str());
  200. return FAILED, "[Add][OutputDesc] fail for node:%s", data_op->GetName().c_str());
  201. if (attr && !is_const) {
  202. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index),
  203. REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s",
  204. ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  205. return FAILED,
  206. "[Set][Attr:%s] fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  207. ++data_index;
  208. }
  209. ge::NodePtr arg_node = graph->AddNode(data_op);
  210. GE_CHK_BOOL_EXEC(arg_node != nullptr,
  211. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%s failed", data_op->GetName().c_str(),
  212. graph->GetName().c_str());
  213. return FAILED, "[Add][Node] Insert Data node:%s fail", data_op->GetName().c_str());
  214. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  215. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  216. return SUCCESS;
  217. }
  218. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  219. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  220. if (op_desc == nullptr) {
  221. REPORT_CALL_ERROR("E19999", "create OpDesc failed, graph:%s", graph->GetName().c_str());
  222. GELOGE(FAILED, "[Create][OpDesc] failed, graph:%s", graph->GetName().c_str());
  223. return FAILED;
  224. }
  225. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  226. int32_t count = 0;
  227. for (const auto &out_desc : outputs) {
  228. GeTensorDesc tensor = out_desc.GetTensorDesc();
  229. TensorUtils::SetInputTensor(tensor, true);
  230. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS,
  231. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", op_desc->GetName().c_str());
  232. return FAILED, "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  233. TensorUtils::SetInputTensor(tensor, false);
  234. TensorUtils::SetOutputTensor(tensor, true);
  235. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  236. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", op_desc->GetName().c_str());
  237. return FAILED, "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  238. count++;
  239. }
  240. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  241. ge::NodePtr out_node = graph->AddNode(op_desc);
  242. GE_CHK_BOOL_EXEC(out_node != nullptr,
  243. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%u failed.",
  244. op_desc->GetName().c_str(), graph->GetGraphID());
  245. return FAILED,
  246. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  247. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  248. for (int32_t i = 0; i < count; ++i) {
  249. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  250. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  251. }
  252. return SUCCESS;
  253. }
  254. static void GetOpsProtoPath(string &opsproto_path) {
  255. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  256. if (path_env != nullptr) {
  257. string path = path_env;
  258. string file_path = RealPath(path.c_str());
  259. if (file_path.empty()) {
  260. REPORT_CALL_ERROR("E19999", "File path %s is invalid.", path.c_str());
  261. GELOGE(FAILED, "[Call][RealPath] File path %s is invalid.", path.c_str());
  262. return;
  263. }
  264. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  265. GELOGI("Get opsproto so path from env : %s", path.c_str());
  266. return;
  267. }
  268. string path_base = PluginManager::GetPath();
  269. GELOGI("path_base is %s", path_base.c_str());
  270. path_base = path_base.substr(0, path_base.rfind('/'));
  271. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  272. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  273. }
  274. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  275. for (auto input : inputs) {
  276. auto input_desc = input.GetTensorDesc();
  277. GeShape shape_ori = input_desc.GetShape();
  278. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  279. GeShape dynamic_shape(dynamic_shape_dims);
  280. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  281. ge::GeTensor inputTensor;
  282. ge::GeTensorDesc desc(input_desc);
  283. bool is_const = false;
  284. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  285. if (!is_const) {
  286. int64_t storage_format = FORMAT_NCHW;
  287. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  288. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  289. REPORT_CALL_ERROR("E19999", "Set attr ATTR_NAME_STORAGE_SHAPE failed to op:%s.", desc.GetName().c_str());
  290. GELOGE(FAILED, "[Set][Attr] ATTR_NAME_STORAGE_SHAPE fail.");
  291. return FAILED;
  292. }
  293. desc.SetShape(dynamic_shape);
  294. desc.SetShapeRange(dynamic_shape_range);
  295. }
  296. inputTensor.SetTensorDesc(desc);
  297. inputs_dynamic.push_back(inputTensor);
  298. }
  299. return SUCCESS;
  300. }
  301. static Status GetFuzzBuildAttrs(const OpDescPtr &op_desc, const GeRootModelPtr &ge_root_model,
  302. GeAttrValue::LIST_NAMED_ATTRS &fuzz_build_attrs) {
  303. GELOGD("Start get fuzz build attrs of %s.", op_desc->GetName().c_str());
  304. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  305. for (const auto &node : ge_root_model->GetRootGraph()->GetAllNodes()) {
  306. GE_CHECK_NOTNULL(node);
  307. GE_CHECK_NOTNULL(node->GetOpDesc());
  308. GELOGD("Delete fuzz build attr of %s after build.", node->GetName().c_str());
  309. node->GetOpDesc()->DelAttr(ATTR_NAME_FUZZ_BUILD);
  310. }
  311. (void)AttrUtils::GetListNamedAttrs(op_desc, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs);
  312. if (!fuzz_build_attrs.empty()) {
  313. GELOGD("%s has split, get ATTR_NAME_FUZZ_BUILD_RES_ATTRS directly.", op_desc->GetName().c_str());
  314. return SUCCESS;
  315. } else {
  316. GELOGW("%s build with fuzz build pattern, but not set ATTR_NAME_FUZZ_BUILD_RES_ATTRS.", op_desc->GetName().c_str());
  317. }
  318. return SUCCESS;
  319. }
  320. static bool HasShapeRange(const vector<GeTensor> &inputs) {
  321. for (const auto &input : inputs) {
  322. vector<pair<int64_t, int64_t>> shape_range;
  323. (void)input.GetTensorDesc().GetShapeRange(shape_range);
  324. if (!shape_range.empty()) {
  325. GELOGD("Has set shape range.");
  326. return true;
  327. }
  328. }
  329. return false;
  330. }
  331. class GeGenerator::Impl {
  332. public:
  333. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  334. ~Impl() = default;
  335. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  336. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  337. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  338. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  339. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  340. Status GenerateInfershapeGraph(const Graph &graph);
  341. OmgContext &omg_context_;
  342. GraphManager graph_manager_;
  343. SaveParam save_param_;
  344. bool is_offline_ = true;
  345. bool is_singleop_unregistered_ = false;
  346. std::string build_mode_;
  347. std::string build_step_;
  348. static std::mutex mutex_;
  349. private:
  350. static std::string Trim(const std::string &str);
  351. bool ParseVersion(const std::string &line, std::string &version);
  352. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  353. bool SetAtcVersionInfo(AttrHolder &obj);
  354. bool SetOppVersionInfo(AttrHolder &obj);
  355. bool SetOmSystemInfo(AttrHolder &obj);
  356. };
  357. Status GeGenerator::Initialize(const map<string, string> &options) {
  358. return Initialize(options, domi::GetContext());
  359. }
  360. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  361. impl_ = ge::MakeShared<Impl>(omg_context);
  362. if (impl_ == nullptr) {
  363. REPORT_CALL_ERROR("E19999", "create Impl failed.");
  364. GELOGE(MEMALLOC_FAILED, "[Create][Impl] Make shared failed");
  365. return MEMALLOC_FAILED;
  366. }
  367. ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsProtoInit);
  368. string opsproto_path;
  369. GetOpsProtoPath(opsproto_path);
  370. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  371. OpsProtoManager *manager = OpsProtoManager::Instance();
  372. map<string, string> option_tmp;
  373. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  374. (void)manager->Initialize(option_tmp);
  375. Status ret = impl_->graph_manager_.Initialize(options);
  376. if (ret != SUCCESS) {
  377. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "[Call][Initialize] Graph manager initialize failed.");
  378. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  379. }
  380. // get ek file
  381. auto iter = options.find(EK_FILE);
  382. if (iter != options.end()) {
  383. impl_->save_param_.ek_file = iter->second;
  384. }
  385. // get cert file
  386. iter = options.find(CERT_FILE);
  387. if (iter != options.end()) {
  388. impl_->save_param_.cert_file = iter->second;
  389. }
  390. // get hw key file
  391. iter = options.find(HW_KEY_FILE);
  392. if (iter != options.end()) {
  393. impl_->save_param_.hw_key_file = iter->second;
  394. }
  395. // get private file
  396. iter = options.find(PRIVATE_KEY_FILE);
  397. if (iter != options.end()) {
  398. impl_->save_param_.pri_key_file = iter->second;
  399. }
  400. // get build mode
  401. iter = options.find(BUILD_MODE);
  402. if (iter != options.end()) {
  403. impl_->build_mode_ = iter->second;
  404. }
  405. // get build step
  406. iter = options.find(BUILD_STEP);
  407. if (iter != options.end()) {
  408. impl_->build_step_ = iter->second;
  409. }
  410. return SUCCESS;
  411. }
  412. Status GeGenerator::Finalize() {
  413. ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
  414. if (impl_ == nullptr) {
  415. return SUCCESS;
  416. }
  417. Status ret = impl_->graph_manager_.Finalize();
  418. if (ret != SUCCESS) {
  419. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed.");
  420. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  421. }
  422. return SUCCESS;
  423. }
  424. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  425. const vector<GeTensor> &inputs) {
  426. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  427. GELOGI("Start to generate offline model.");
  428. ModelBufferData model;
  429. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  430. }
  431. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  432. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  433. return GenerateModel(graph, "online", inputs, model, false);
  434. }
  435. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  436. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  437. Status ret = impl_->GenerateInfershapeGraph(graph);
  438. if (ret != SUCCESS) {
  439. GELOGE(ret, "[Call][GenerateInfershapeGraph] Dump infershape json failed");
  440. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  441. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  442. }
  443. return ret;
  444. }
  445. GELOGI("Generate infer shape graph success");
  446. return SUCCESS;
  447. }
  448. std::mutex GeGenerator::Impl::mutex_;
  449. // Remove the space and tab before and after the string
  450. std::string GeGenerator::Impl::Trim(const std::string &str) {
  451. if (str.empty()) {
  452. return str;
  453. }
  454. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  455. if (start == std::string::npos) {
  456. return str;
  457. }
  458. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  459. return str.substr(start, end);
  460. }
  461. // Parsing the command line
  462. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  463. std::string flag = "Version=";
  464. std::string temp = Trim(line);
  465. if (temp.empty()) {
  466. GELOGW("line is empty.");
  467. return false;
  468. }
  469. std::string::size_type pos = temp.find(flag);
  470. if (pos == std::string::npos) {
  471. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  472. return false;
  473. }
  474. if (temp.size() == flag.size()) {
  475. GELOGW("version information is empty. %s", line.c_str());
  476. return false;
  477. }
  478. version = temp.substr(pos + flag.size());
  479. return true;
  480. }
  481. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  482. // Normalize the path
  483. string resolved_file_path = RealPath(file_path.c_str());
  484. if (resolved_file_path.empty()) {
  485. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  486. return false;
  487. }
  488. std::ifstream fs(resolved_file_path, std::ifstream::in);
  489. if (!fs.is_open()) {
  490. GELOGW("Open %s failed.", file_path.c_str());
  491. return false;
  492. }
  493. std::string line;
  494. if (getline(fs, line)) {
  495. if (!ParseVersion(line, version)) {
  496. GELOGW("Parse version failed. content is [%s].", line.c_str());
  497. fs.close();
  498. return false;
  499. }
  500. } else {
  501. GELOGW("No version information found in the file path:%s", file_path.c_str());
  502. fs.close();
  503. return false;
  504. }
  505. fs.close(); // close the file
  506. return true;
  507. }
  508. // Set package version information in the model
  509. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  510. std::string path_base = ge::GELib::GetPath();
  511. path_base = path_base.substr(0, path_base.rfind('/'));
  512. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  513. std::string version_path = path_base + "version.info";
  514. std::string version;
  515. if (!GetVersionFromPath(version_path, version)) {
  516. GELOGW("Get atc version information failed!");
  517. return false;
  518. }
  519. // set version info
  520. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  521. GELOGW("Ge model set atc version failed!");
  522. return false;
  523. }
  524. return true;
  525. }
  526. // Set package version information in the model
  527. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  528. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  529. if (path_env == nullptr) {
  530. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  531. return false;
  532. }
  533. std::string version_path = path_env;
  534. version_path += "/version.info";
  535. std::string version;
  536. if (!GetVersionFromPath(version_path, version)) {
  537. GELOGW("Get opp version information failed!");
  538. return false;
  539. }
  540. // set version info
  541. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  542. GELOGW("Ge model set opp version failed!");
  543. return false;
  544. }
  545. return true;
  546. }
  547. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  548. std::string soc_version;
  549. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  550. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  551. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  552. GELOGW("SetStr of soc_version failed.");
  553. return false;
  554. }
  555. std::string framework_type;
  556. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  557. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  558. auto iter = ge::kFwkTypeToStr.find(framework_type);
  559. if (iter == ge::kFwkTypeToStr.end()) {
  560. GELOGW("Can not find framework_type in the map.");
  561. return false;
  562. }
  563. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  564. GELOGW("SetStr of framework_type failed.");
  565. return false;
  566. }
  567. return true;
  568. }
  569. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  570. bool is_unknown_shape = false;
  571. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  572. if (ret != SUCCESS) {
  573. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  574. ge_root_model->GetModelId());
  575. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%u",
  576. ge_root_model->GetModelId());
  577. return FAILED;
  578. }
  579. GeModelPtr model_root = nullptr;
  580. if (is_unknown_shape) {
  581. model_root = MakeShared<GeModel>();
  582. GE_CHECK_NOTNULL(model_root);
  583. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  584. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  585. }
  586. ModelHelper model_helper;
  587. string model_name;
  588. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  589. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  590. model_name);
  591. if (name_ret != SUCCESS) {
  592. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  593. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  594. ge_root_model->GetRootGraph()->GetName().c_str());
  595. return PARAM_INVALID;
  596. }
  597. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  598. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  599. GE_CHECK_NOTNULL(ge_model);
  600. ge_model->SetName(model_name);
  601. return SUCCESS;
  602. }
  603. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  604. ModelBufferData &model, bool is_offline) {
  605. rtContext_t ctx = nullptr;
  606. auto rt = rtCtxGetCurrent(&ctx);
  607. if (rt != RT_ERROR_NONE) {
  608. GELOGD("Current ctx is null.");
  609. ctx = nullptr;
  610. }
  611. std::function<void()> callback = [&]() {
  612. if (ctx != nullptr) {
  613. (void)rtCtxSetCurrent(ctx);
  614. }
  615. };
  616. GE_MAKE_GUARD(restore, callback);
  617. GeRootModelPtr ge_root_model = nullptr;
  618. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  619. impl_->is_offline_ = is_offline;
  620. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  621. if (ret != SUCCESS) {
  622. GELOGE(ret, "[Build][Model] failed, ret:%d.", ret);
  623. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  624. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  625. }
  626. return ret;
  627. }
  628. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  629. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  630. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  631. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  632. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  633. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  634. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  635. impl_->build_mode_.c_str(),
  636. impl_->build_step_.c_str());
  637. return SUCCESS;
  638. }
  639. GE_CHECK_NOTNULL(ge_root_model);
  640. ret = SetModelNameForDump(ge_root_model);
  641. if (ret != SUCCESS) {
  642. return ret;
  643. }
  644. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  645. if (ret != SUCCESS) {
  646. GELOGE(ret, "[Save][RootModel] failed, ret:%d, file:%s", ret, file_name_prefix.c_str());
  647. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  648. GELOGE(FAILED, "graph_manager finalize fail.");
  649. }
  650. return ret;
  651. }
  652. return SUCCESS;
  653. }
  654. namespace {
  655. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  656. bool is_need = true;
  657. // format and dtype is all reserved, stand for Optional input. When singleop scene
  658. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  659. is_need = false;
  660. }
  661. return is_need;
  662. }
  663. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  664. bool support_dynamic = true;
  665. bool is_dynamic = false;
  666. for (const auto &node : graph->GetDirectNode()) {
  667. GE_CHECK_NOTNULL(node);
  668. auto op_desc = node->GetOpDesc();
  669. GE_CHECK_NOTNULL(op_desc);
  670. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  671. continue;
  672. }
  673. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  674. is_dynamic = true;
  675. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  676. if (!support_dynamic) {
  677. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  678. break;
  679. }
  680. }
  681. }
  682. if (is_dynamic) {
  683. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  684. }
  685. return SUCCESS;
  686. }
  687. }
  688. bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
  689. for (const auto &node : graph->GetDirectNode()) {
  690. if (node == nullptr) {
  691. continue;
  692. }
  693. auto op_desc = node->GetOpDesc();
  694. if (op_desc == nullptr) {
  695. continue;
  696. }
  697. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  698. return false;
  699. }
  700. }
  701. return true;
  702. }
  703. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  704. for (auto &input : inputs) {
  705. GeTensorDesc input_desc = input.GetTensorDesc();
  706. bool is_const = false;
  707. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  708. bool is_optional = IsOptional(input_desc);
  709. if (!is_optional && !is_const) {
  710. outputs.emplace_back(input);
  711. }
  712. }
  713. }
  714. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  715. const vector<GeTensor> &outputs) {
  716. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  717. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  718. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  719. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  720. "tensor size is " + FmtToStr(inputs.size())});
  721. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Inputs size: %zu, not equal",
  722. inputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetAllInputsSize());
  723. return PARAM_INVALID;
  724. }
  725. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  726. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  727. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  728. "tensor size is " + FmtToStr(outputs.size())});
  729. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Outputs size: %zu, not equal",
  730. outputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetOutputsSize());
  731. return PARAM_INVALID;
  732. }
  733. return SUCCESS;
  734. }
  735. Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) {
  736. GE_CHECK_NOTNULL(op_desc);
  737. if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
  738. auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
  739. if (node_op.IsEmpty()) {
  740. GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str());
  741. } else {
  742. GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str());
  743. auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  744. if (temp_op_desc == nullptr) {
  745. REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s",
  746. op_desc->GetType().c_str());
  747. GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str());
  748. return FAILED;
  749. }
  750. if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
  751. GELOGW("InferFormatForSingleOp UpdateInputName failed");
  752. }
  753. if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
  754. GELOGW("InferFormatForSingleOp UpdateOutputName failed");
  755. }
  756. }
  757. node_op.BreakConnect();
  758. }
  759. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  760. GE_CHECK_NOTNULL(comp_graph);
  761. auto node = comp_graph->FindNode(op_desc->GetName());
  762. GE_CHECK_NOTNULL(node);
  763. auto op = OpDescUtils::CreateOperatorFromNode(node);
  764. auto ret = op_desc->CallInferFormatFunc(op);
  765. if (ret != GRAPH_SUCCESS) {
  766. REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
  767. op_desc->GetName().c_str());
  768. GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str());
  769. return FAILED;
  770. }
  771. return SUCCESS;
  772. }
  773. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  774. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  775. bool is_offline, int32_t compile_flag) {
  776. GELOGD("Inputs size is %zu, outputs size is %zu.", inputs.size(), outputs.size());
  777. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  778. impl_->is_offline_ = is_offline;
  779. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  780. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  781. GELOGE(PARAM_INVALID, "[Check][Param] input param is invalid when build single op:%s!",
  782. op_desc->GetName().c_str());
  783. return PARAM_INVALID;
  784. }
  785. OmgContext &omg_context = impl_->omg_context_;
  786. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  787. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  788. impl_->is_singleop_unregistered_ = true;
  789. }
  790. // 0. Save original attributes.
  791. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  792. GE_CHECK_NOTNULL(op_desc_tmp);
  793. bool fuzz_compile_flag = false;
  794. if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
  795. fuzz_compile_flag = true;
  796. }
  797. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
  798. impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;
  799. // 1. Create ComputeGraph.
  800. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  801. Graph graph;
  802. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
  803. "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
  804. GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph));
  805. // 2. check engine type when compile online
  806. if (model_file_name == kFileNameSuffix) {
  807. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  808. GE_CHECK_NOTNULL(comp_graph);
  809. auto node = comp_graph->FindNode(op_desc->GetName());
  810. Status ret = CheckEngineTypeSupport(node, engine_type);
  811. if (ret != SUCCESS) {
  812. GELOGE(ret, "[Check][EngineType]not support node:%s with engine of %d.", node->GetName().c_str(), engine_type);
  813. return ret;
  814. }
  815. }
  816. GELOGI("ATC parser success in single op build.");
  817. GeRootModelPtr ge_root_model = nullptr;
  818. vector<GeTensor> data_inputs;
  819. RemoveConst(inputs, data_inputs);
  820. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  821. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  822. GE_CHECK_NOTNULL(ge_root_model);
  823. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  824. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  825. if (name_to_ge_model.empty()) {
  826. REPORT_CALL_ERROR("E19999", "GetSubgraphInstanceNameToModel failed.");
  827. GELOGE(PARAM_INVALID, "[Get][Name] GetSubgraphInstanceNameToModel is empty.");
  828. return PARAM_INVALID;
  829. }
  830. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  831. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  832. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  833. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  834. bool all_shape = false;
  835. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  836. GELOGD("Node: %s, all_shape is %d, compile_flag is %d.", op_desc->GetName().c_str(), all_shape, compile_flag);
  837. (void)AttrUtils::SetInt(ge_model, ATTR_NAME_BUILD_MODE, fuzz_compile_flag);
  838. if (all_shape) {
  839. (void)AttrUtils::SetBool(ge_model, kAicpuAllshape, all_shape);
  840. }
  841. if (all_shape && CheckNoAicore(root_graph)) {
  842. GELOGD("Get aicpu all_shape kernel!");
  843. vector<GeTensor> inputs_dynamic;
  844. vector<GeTensor> outputs_dynamic;
  845. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  846. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  847. GE_CHK_STATUS_RET_NOLOG(
  848. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  849. } else if (fuzz_compile_flag) {
  850. GeAttrValue::LIST_NAMED_ATTRS fuzz_build_attrs;
  851. if (GetFuzzBuildAttrs(op_desc, ge_root_model, fuzz_build_attrs) != SUCCESS) {
  852. GELOGE(FAILED, "[Get][FuzzRet]Failed to get fuzz build result of %s.", op_desc->GetName().c_str());
  853. return FAILED;
  854. }
  855. if (!fuzz_build_attrs.empty()) {
  856. GE_CHK_BOOL_EXEC(AttrUtils::SetListNamedAttrs(ge_model, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs),
  857. REPORT_CALL_ERROR("E19999", "Set model:%s(id:%u) attr:%s failed.",
  858. ge_model->GetName().c_str(), ge_model->GetModelId(),
  859. ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  860. return FAILED, "Set model:%s(id:%u) attr:%s failed.",
  861. ge_model->GetName().c_str(), ge_model->GetModelId(), ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  862. }
  863. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  864. } else {
  865. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  866. }
  867. GELOGI("Start save GeModel to Model buffer");
  868. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  869. return SUCCESS;
  870. }
  871. /**
  872. * @ingroup ge
  873. * @brief Compiling a single operator into an offline model
  874. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  875. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  876. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  877. * @param [in] const string &model_file_name: Offline model filename.
  878. * @param [in] compile_flag: op build flag from atc
  879. * @return SUCCESS handle successfully / others handle failed
  880. */
  881. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  882. const vector<GeTensor> &outputs, const string &model_file_name,
  883. int32_t compile_flag) {
  884. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  885. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  886. ModelBufferData model_buff;
  887. OpEngineType engine_type = ENGINE_SYS;
  888. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true, compile_flag);
  889. GELOGI("Finish build single offline model, status: %u", status);
  890. return status;
  891. }
  892. /**
  893. * @ingroup ge
  894. * @brief Compiling a single operator into online buffer
  895. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  896. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  897. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  898. * @param [in] engine_type: specific engine.
  899. * @param [in] compile_flag: op build flag, compile flag by acl
  900. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  901. * @return SUCCESS handle successfully / others handle failed
  902. */
  903. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  904. const vector<GeTensor> &outputs, OpEngineType engine_type,
  905. ModelBufferData &model_buff) {
  906. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  907. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  908. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  909. GELOGI("Finish build single online model, status: %u", status);
  910. return status;
  911. }
  912. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  913. const vector<GeTensor> &outputs, OpEngineType engine_type, int32_t compile_flag,
  914. ModelBufferData &model_buff) {
  915. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  916. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  917. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false,
  918. compile_flag);
  919. GELOGI("Finish build single online model, status: %u", status);
  920. return status;
  921. }
  922. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  923. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  924. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  925. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  926. // 1. Add Node to ComputeGraph.
  927. NodePtr op_node = compute_graph->AddNode(op_desc);
  928. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  929. // 2. Create InputData node.
  930. int32_t arg_index = 0;
  931. int32_t data_index = 0;
  932. if (inputs.empty()) {
  933. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  934. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  935. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  936. continue;
  937. }
  938. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  939. arg_index++;
  940. }
  941. } else {
  942. for (const auto &in_desc : inputs) {
  943. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  944. arg_index++;
  945. }
  946. }
  947. // 3. Create Output node.
  948. if (!outputs.empty()) {
  949. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  950. }
  951. // dump ComputeGraph node.
  952. compute_graph->Dump();
  953. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  954. return SUCCESS;
  955. }
  956. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  957. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  958. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  959. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  960. (void)graph_manager_.Finalize();
  961. return FAILED);
  962. return SUCCESS;
  963. }
  964. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  965. // set atc version
  966. if (!SetAtcVersionInfo(*(model.get()))) {
  967. GELOGW("SetPackageVersionInfo of atc failed!");
  968. }
  969. // set opp version
  970. if (!SetOppVersionInfo(*(model.get()))) {
  971. GELOGW("SetPackageVersionInfo of ops failed!");
  972. }
  973. ModelHelper model_helper;
  974. model_helper.SetSaveMode(is_offline_);
  975. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  976. if (ret != SUCCESS) {
  977. GELOGE(ret, "[Call][SaveToOmModel] Save to om model failed");
  978. return ret;
  979. }
  980. return SUCCESS;
  981. }
  982. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  983. ModelBufferData &model_buff) {
  984. bool is_unknown_shape = false;
  985. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  986. if (ret != SUCCESS) {
  987. REPORT_CALL_ERROR("E19999", "root model(id:%u) CheckIsUnknownShape failed, ret:%d",
  988. ge_root_model->GetModelId(), ret);
  989. GELOGE(FAILED, "[Check][RootModel] is unkonwn shape failed, ret:%d", ret);
  990. return FAILED;
  991. }
  992. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  993. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(),
  994. REPORT_CALL_ERROR("E19999", "root model(id:%u) has no sub model.", ge_root_model->GetModelId());
  995. return FAILED, "[Get][SubModel] ge root model has no sub model")
  996. GeModelPtr model_root = nullptr;
  997. if (is_unknown_shape) {
  998. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  999. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  1000. } else {
  1001. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  1002. }
  1003. GE_CHECK_NOTNULL(model_root);
  1004. // set atc version
  1005. if (!SetAtcVersionInfo(*(model_root.get()))) {
  1006. GELOGW("SetPackageVersionInfo of atc failed!");
  1007. }
  1008. // set opp version
  1009. if (!SetOppVersionInfo(*(model_root.get()))) {
  1010. GELOGW("SetPackageVersionInfo of ops failed!");
  1011. }
  1012. if (!SetOmSystemInfo(*(model_root.get()))) {
  1013. GELOGW("SetOmsystemInfo failed!");
  1014. }
  1015. ModelHelper model_helper;
  1016. model_helper.SetSaveMode(is_offline_);
  1017. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  1018. if (ret != SUCCESS) {
  1019. REPORT_CALL_ERROR("E19999", "SaveToOmRootModel failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1020. GELOGE(ret, "[Call][SaveToOmRootModel] failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1021. return ret;
  1022. }
  1023. return SUCCESS;
  1024. }
  1025. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  1026. GeRootModelPtr &ge_root_model) {
  1027. static std::atomic<GraphId> atomic_graph_id(0);
  1028. auto graph_id = atomic_graph_id.fetch_add(1);
  1029. const std::map<std::string, std::string> options;
  1030. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1031. if (ret != SUCCESS) {
  1032. REPORT_CALL_ERROR("E19999", "add graph(id:%u) failed, ret:%d", graph_id, ret);
  1033. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] fail, graph id: %u", graph_id);
  1034. (void)graph_manager_.Finalize();
  1035. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1036. }
  1037. graph_manager_.SetOptionsRunGraphFlag(false);
  1038. static std::atomic<uint64_t> atomic_session_id(0);
  1039. auto session_id = atomic_session_id.fetch_add(1);
  1040. // This is a temporary add for graph with variable
  1041. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  1042. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  1043. GELOGI("Start init var instance, session_id %lu", session_id);
  1044. if (ret != SUCCESS) {
  1045. GELOGW("Failed init var instance, session_id %lu", session_id);
  1046. }
  1047. if (is_singleop_unregistered_) {
  1048. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  1049. } else {
  1050. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  1051. }
  1052. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1053. if (ret != SUCCESS) {
  1054. REPORT_CALL_ERROR("E19999", "build graph failed, graph id:%u, ret:%d", graph_id, ret);
  1055. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "[Build][Graph] fail, graph id: %u", graph_id);
  1056. }
  1057. RtContextUtil::GetInstance().DestroyRtContexts(session_id);
  1058. Analyzer::GetInstance()->DestroySessionJsonObject(session_id);
  1059. VarManagerPool::Instance().RemoveVarManager(session_id);
  1060. return ret;
  1061. }
  1062. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  1063. static std::atomic<GraphId> atomic_graph_id(0);
  1064. auto graph_id = atomic_graph_id.fetch_add(1);
  1065. const std::map<std::string, std::string> options;
  1066. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1067. if (ret != SUCCESS) {
  1068. REPORT_CALL_ERROR("E19999", "add graph failed, graph id:%u, ret:%d", graph_id, ret);
  1069. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] failed, graph id: %u", graph_id);
  1070. (void)graph_manager_.Finalize();
  1071. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1072. }
  1073. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  1074. if (ret != SUCCESS) {
  1075. REPORT_CALL_ERROR("E19999", "GenerateInfershapeGraph failed, graph id:%u, ret:%d", graph_id, ret);
  1076. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED,
  1077. "[Generate][Graph] failed, graph id:%u, ret:%d", graph_id, ret);
  1078. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  1079. }
  1080. return SUCCESS;
  1081. }
  1082. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：