You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ge_generator.cc 50 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "generator/ge_generator.h"
  17. #include <atomic>
  18. #include "common/ge/ge_util.h"
  19. #include "common/ge/plugin_manager.h"
  20. #include "common/helper/model_helper.h"
  21. #include "common/helper/om_file_helper.h"
  22. #include "common/util.h"
  23. #include "common/util/error_manager/error_manager.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "framework/common/debug/log.h"
  26. #include "ge/ge_api.h"
  27. #include "graph/debug/ge_attr_define.h"
  28. #include "graph/ge_context.h"
  29. #include "graph/manager/graph_manager.h"
  30. #include "graph/manager/util/rt_context_util.h"
  31. #include "graph/operator_factory_impl.h"
  32. #include "graph/opsproto_manager.h"
  33. #include "graph/utils/graph_utils.h"
  34. #include "graph/utils/type_utils.h"
  35. #include "init/gelib.h"
  36. #include "model/ge_model.h"
  37. #include "analyzer/analyzer.h"
  38. using std::map;
  39. using std::string;
  40. using std::vector;
  41. namespace {
  42. const char *const kAttrOpType = "op_type";
  43. const char *const kEngineNameDefault = "default";
  44. const char *const kVectorEngine = "VectorEngine";
  45. const char *const kAIcoreEngine = "AIcoreEngine";
  46. const char *const kFileNameSuffix = "online";
  47. const char *const kAicpuAllshape = "_AllShape";
  48. constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape";
  49. const int64_t kDynamicDimValue = -2;
  50. const int kDefaultDeviceId = 0;
  51. const int kDefaultJobId = 0;
  52. const int32_t kFuzzBuildPattern = 1;
  53. std::map<ge::OpEngineType, std::string> engine_type_map{
  54. {ge::ENGINE_SYS, kEngineNameDefault},
  55. {ge::ENGINE_AICORE, kAIcoreEngine},
  56. {ge::ENGINE_VECTOR, kVectorEngine}};
  57. bool ContainsDynamicInpus(const ge::OpDesc &op_desc) {
  58. for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) {
  59. if (tensor_desc->MutableShape().IsUnknownShape()) {
  60. GELOGI("Contains unknown shape input. set is_dynamic_input to true.");
  61. return true;
  62. }
  63. }
  64. return false;
  65. }
  66. // if optional in/out, format is format_reserved and dtype is dt_undefined
  67. bool IsOptional(const ge::GeTensorDesc &tensor_desc) {
  68. return tensor_desc.GetFormat() == ge::FORMAT_RESERVED && tensor_desc.GetDataType() == ge::DT_UNDEFINED;
  69. }
  70. } // namespace
  71. namespace ge {
  72. static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type) {
  73. const OpDescPtr &op_desc = node->GetOpDesc();
  74. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  75. if (engine_type == ENGINE_SYS) {
  76. GELOGI("CheckEngineType: use default engine.");
  77. return SUCCESS;
  78. }
  79. // get op engine name
  80. string op_engine_name;
  81. auto iter = engine_type_map.find(engine_type);
  82. if (iter != engine_type_map.end()) {
  83. op_engine_name = iter->second;
  84. GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type));
  85. } else {
  86. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  87. {op_desc->GetName(), op_desc->GetType(), "engine type",
  88. "it only support default/AIcoreEngine/VectorEngine"});
  89. GELOGE(FAILED, "[Check][Param] value:%d not support, "
  90. "only support default/AIcoreEngine/VectorEngine now", static_cast<int>(engine_type));
  91. return FAILED;
  92. }
  93. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  94. op_desc->SetOpEngineName(op_engine_name);
  95. op_desc->SetOpKernelLibName(op_engine_name);
  96. return SUCCESS;
  97. }
  98. // set op engine name and opkernelLib. when engine support
  99. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  100. if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) {
  101. REPORT_INNER_ERROR("E19999", "get gelib failed, as get instance failed or initflag failed.");
  102. GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Get][GELib] CheckEngineType failed, as get gelib failed.");
  103. return FAILED;
  104. }
  105. OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj();
  106. std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
  107. if (op_infos.empty()) {
  108. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  109. {op_desc->GetName(), op_desc->GetType(), "optype", "it can not find"});
  110. GELOGE(FAILED, "[Get][OpInfo] by op type %s failed.", op_desc->GetType().c_str());
  111. return FAILED;
  112. }
  113. string kernel_name;
  114. for (const auto &it : op_infos) {
  115. if (it.engine == op_engine_name) {
  116. kernel_name = it.opKernelLib;
  117. break;
  118. }
  119. }
  120. if (kernel_name.empty()) {
  121. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  122. {op_desc->GetName(), op_desc->GetType(), "engine name" + FmtToStr(op_engine_name), "it can not find"});
  123. GELOGE(FAILED, "[Check][Param] Can not find ops kernel, engine name:%s. op:%s(%s)",
  124. op_engine_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str());
  125. return FAILED;
  126. }
  127. auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores();
  128. auto kernel_info_store = kernel_map.find(kernel_name);
  129. if (kernel_info_store != kernel_map.end()) {
  130. std::string unsupported_reason;
  131. if (kernel_info_store->second->CheckSupported(node, unsupported_reason)) {
  132. op_desc->SetOpEngineName(op_engine_name);
  133. op_desc->SetOpKernelLibName(kernel_name);
  134. GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(),
  135. op_engine_name.c_str(), op_desc->GetName().c_str());
  136. return SUCCESS;
  137. } else {
  138. ErrorManager::GetInstance().ATCReportErrMessage(
  139. "E13002", {"optype", "opskernel", "reason"}, {op_desc->GetType(), kernel_name, unsupported_reason});
  140. GELOGE(FAILED, "[Call][CheckSupported] failed, Op type %s of ops kernel %s is unsupported, reason:%s",
  141. op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str());
  142. return FAILED;
  143. }
  144. } else {
  145. ErrorManager::GetInstance().ATCReportErrMessage(
  146. "E13003", {"opname", "optype"}, {op_desc->GetName(), op_desc->GetType()});
  147. GELOGE(FAILED, "[Check][Param] Can not find any supported ops kernel info store by kernel_name %s,"
  148. "op type is %s, op name is %s",
  149. kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str());
  150. }
  151. return FAILED;
  152. }
  153. static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index,
  154. bool attr, int32_t &data_index) {
  155. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  156. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  157. auto format = tensor.GetFormat();
  158. auto data_type = tensor.GetDataType();
  159. if (format == FORMAT_RESERVED && data_type == DT_UNDEFINED) {
  160. return SUCCESS;
  161. }
  162. string op_type;
  163. bool is_const = false;
  164. (void)AttrUtils::GetBool(tensor, CONST_ATTR_NAME_INPUT, is_const);
  165. if (is_const) {
  166. GELOGD("Get input[%d] is const", index);
  167. op_type = CONSTANTOP;
  168. } else if (!AttrUtils::GetStr(tensor, kAttrOpType, op_type) || op_type.empty()) {
  169. op_type = DATA;
  170. }
  171. string op_name = node->GetName() + "_in_" + std::to_string(index);
  172. OpDescPtr data_op = MakeShared<ge::OpDesc>(op_name, op_type);
  173. if (data_op == nullptr) {
  174. REPORT_CALL_ERROR("E19999", "create OpDesc failed, name:%s", op_name.c_str());
  175. GELOGE(FAILED, "[Create][OpDesc] failed, name:%s", op_name.c_str());
  176. return FAILED;
  177. }
  178. if (is_const) {
  179. ConstGeTensorPtr tensor_value;
  180. if (!AttrUtils::GetTensor(tensor, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  181. REPORT_CALL_ERROR("E19999", "get attr %s failed, tensor:%s.",
  182. ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  183. GELOGE(FAILED, "[Get][Attr] %s failed, tensor:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), tensor.GetName().c_str());
  184. return FAILED;
  185. }
  186. if (!AttrUtils::SetTensor(data_op, ge::ATTR_NAME_WEIGHTS, tensor_value)) {
  187. REPORT_CALL_ERROR("E19999", "set attr %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  188. GELOGE(FAILED, "[Set][Attr] %s failed, op:%s.", ge::ATTR_NAME_WEIGHTS.c_str(), op_name.c_str());
  189. return FAILED;
  190. }
  191. }
  192. (void)AttrUtils::SetBool(data_op, "_is_single_op", true);
  193. GE_CHK_BOOL_EXEC(data_op->AddInputDesc(tensor) == GRAPH_SUCCESS,
  194. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", data_op->GetName().c_str());
  195. return FAILED, "[Add][InputDesc] fail for node:%s", data_op->GetName().c_str());
  196. GE_CHK_BOOL_EXEC(data_op->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  197. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", data_op->GetName().c_str());
  198. return FAILED, "[Add][OutputDesc] fail for node:%s", data_op->GetName().c_str());
  199. if (attr && !is_const) {
  200. GE_CHK_BOOL_EXEC(AttrUtils::SetInt(data_op, ATTR_NAME_INDEX, data_index),
  201. REPORT_CALL_ERROR("E19999", "set attr %s failed for node:%s",
  202. ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  203. return FAILED,
  204. "[Set][Attr:%s] fail for node:%s", ATTR_NAME_INDEX.c_str(), data_op->GetName().c_str());
  205. ++data_index;
  206. }
  207. ge::NodePtr arg_node = graph->AddNode(data_op);
  208. GE_CHK_BOOL_EXEC(arg_node != nullptr,
  209. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%s failed", data_op->GetName().c_str(),
  210. graph->GetName().c_str());
  211. return FAILED, "[Add][Node] Insert Data node:%s fail", data_op->GetName().c_str());
  212. GE_CHK_STATUS(GraphUtils::AddEdge(arg_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)),
  213. "[Add][Edge]fail from node:%s to node:%s", data_op->GetName().c_str(), node->GetName().c_str());
  214. return SUCCESS;
  215. }
  216. static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, const vector<GeTensor> &outputs) {
  217. OpDescPtr op_desc = MakeShared<ge::OpDesc>(graph->GetName() + "_" + NODE_NAME_NET_OUTPUT, NETOUTPUT);
  218. if (op_desc == nullptr) {
  219. REPORT_CALL_ERROR("E19999", "create OpDesc failed, graph:%s", graph->GetName().c_str());
  220. GELOGE(FAILED, "[Create][OpDesc] failed, graph:%s", graph->GetName().c_str());
  221. return FAILED;
  222. }
  223. (void)AttrUtils::SetBool(op_desc, "_is_single_op", true);
  224. int32_t count = 0;
  225. for (const auto &out_desc : outputs) {
  226. GeTensorDesc tensor = out_desc.GetTensorDesc();
  227. TensorUtils::SetInputTensor(tensor, true);
  228. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(tensor) == GRAPH_SUCCESS,
  229. REPORT_CALL_ERROR("E19999", "AddInputDesc failed for node:%s", op_desc->GetName().c_str());
  230. return FAILED, "[Add][InputDesc]fail for node:%s", op_desc->GetName().c_str());
  231. TensorUtils::SetInputTensor(tensor, false);
  232. TensorUtils::SetOutputTensor(tensor, true);
  233. GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(tensor) == GRAPH_SUCCESS,
  234. REPORT_CALL_ERROR("E19999", "AddOutputDesc failed for node:%s", op_desc->GetName().c_str());
  235. return FAILED, "[Add][OutputDesc]fail for node:%s", op_desc->GetName().c_str());
  236. count++;
  237. }
  238. GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID);
  239. ge::NodePtr out_node = graph->AddNode(op_desc);
  240. GE_CHK_BOOL_EXEC(out_node != nullptr,
  241. REPORT_CALL_ERROR("E19999", "add node:%s to graph:%u failed.",
  242. op_desc->GetName().c_str(), graph->GetGraphID());
  243. return FAILED,
  244. "[Add][Node:%s]fail in graph:%u", op_desc->GetName().c_str(), graph->GetGraphID());
  245. GE_CHECK_NOTNULL_EXEC(node, return PARAM_INVALID);
  246. for (int32_t i = 0; i < count; ++i) {
  247. GE_CHK_STATUS(GraphUtils::AddEdge(node->GetOutDataAnchor(i), out_node->GetInDataAnchor(i)),
  248. "[Add][Edge]fail from node:%s to node:%s", node->GetName().c_str(), out_node->GetName().c_str());
  249. }
  250. return SUCCESS;
  251. }
  252. static void GetOpsProtoPath(string &opsproto_path) {
  253. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  254. if (path_env != nullptr) {
  255. string path = path_env;
  256. string file_path = RealPath(path.c_str());
  257. if (file_path.empty()) {
  258. REPORT_CALL_ERROR("E19999", "File path %s is invalid.", path.c_str());
  259. GELOGE(FAILED, "[Call][RealPath] File path %s is invalid.", path.c_str());
  260. return;
  261. }
  262. opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/");
  263. GELOGI("Get opsproto so path from env : %s", path.c_str());
  264. return;
  265. }
  266. string path_base = PluginManager::GetPath();
  267. GELOGI("path_base is %s", path_base.c_str());
  268. path_base = path_base.substr(0, path_base.rfind('/'));
  269. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  270. opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/");
  271. }
  272. static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) {
  273. for (auto input : inputs) {
  274. auto input_desc = input.GetTensorDesc();
  275. GeShape shape_ori = input_desc.GetShape();
  276. std::vector<int64_t> dynamic_shape_dims = {kDynamicDimValue};
  277. GeShape dynamic_shape(dynamic_shape_dims);
  278. std::vector<std::pair<int64_t, int64_t>> dynamic_shape_range;
  279. ge::GeTensor inputTensor;
  280. ge::GeTensorDesc desc(input_desc);
  281. bool is_const = false;
  282. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  283. if (!is_const) {
  284. int64_t storage_format = FORMAT_NCHW;
  285. if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) &&
  286. !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) {
  287. REPORT_CALL_ERROR("E19999", "Set attr ATTR_NAME_STORAGE_SHAPE failed to op:%s.", desc.GetName().c_str());
  288. GELOGE(FAILED, "[Set][Attr] ATTR_NAME_STORAGE_SHAPE fail.");
  289. return FAILED;
  290. }
  291. desc.SetShape(dynamic_shape);
  292. desc.SetShapeRange(dynamic_shape_range);
  293. }
  294. inputTensor.SetTensorDesc(desc);
  295. inputs_dynamic.push_back(inputTensor);
  296. }
  297. return SUCCESS;
  298. }
  299. static Status GetFuzzBuildAttrs(const OpDescPtr &op_desc, const GeRootModelPtr &ge_root_model,
  300. GeAttrValue::LIST_NAMED_ATTRS &fuzz_build_attrs) {
  301. GELOGD("Start get fuzz build attrs of %s.", op_desc->GetName().c_str());
  302. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  303. for (const auto &node : ge_root_model->GetRootGraph()->GetAllNodes()) {
  304. GE_CHECK_NOTNULL(node);
  305. GE_CHECK_NOTNULL(node->GetOpDesc());
  306. GELOGD("Delete fuzz build attr of %s after build.", node->GetName().c_str());
  307. node->GetOpDesc()->DelAttr(ATTR_NAME_FUZZ_BUILD);
  308. }
  309. (void)AttrUtils::GetListNamedAttrs(op_desc, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs);
  310. if (!fuzz_build_attrs.empty()) {
  311. GELOGD("%s has split, get ATTR_NAME_FUZZ_BUILD_RES_ATTRS directly.", op_desc->GetName().c_str());
  312. return SUCCESS;
  313. } else {
  314. GELOGW("%s build with fuzz build pattern, but not set ATTR_NAME_FUZZ_BUILD_RES_ATTRS.", op_desc->GetName().c_str());
  315. }
  316. return SUCCESS;
  317. }
  318. static bool HasShapeRange(const vector<GeTensor> &inputs) {
  319. for (const auto &input : inputs) {
  320. vector<pair<int64_t, int64_t>> shape_range;
  321. (void)input.GetTensorDesc().GetShapeRange(shape_range);
  322. if (!shape_range.empty()) {
  323. GELOGD("Has set shape range.");
  324. return true;
  325. }
  326. }
  327. return false;
  328. }
  329. class GeGenerator::Impl {
  330. public:
  331. Impl(OmgContext &omg_context) : omg_context_(omg_context) {}
  332. ~Impl() = default;
  333. Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models);
  334. Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
  335. Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
  336. Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  337. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
  338. Status GenerateInfershapeGraph(const Graph &graph);
  339. OmgContext &omg_context_;
  340. GraphManager graph_manager_;
  341. SaveParam save_param_;
  342. bool is_offline_ = true;
  343. bool is_singleop_unregistered_ = false;
  344. std::string build_mode_;
  345. std::string build_step_;
  346. static std::mutex mutex_;
  347. private:
  348. static std::string Trim(const std::string &str);
  349. bool ParseVersion(const std::string &line, std::string &version);
  350. bool GetVersionFromPath(const std::string &file_path, std::string &version);
  351. bool SetAtcVersionInfo(AttrHolder &obj);
  352. bool SetOppVersionInfo(AttrHolder &obj);
  353. bool SetOmSystemInfo(AttrHolder &obj);
  354. };
  355. Status GeGenerator::Initialize(const map<string, string> &options) {
  356. return Initialize(options, domi::GetContext());
  357. }
  358. Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) {
  359. impl_ = ge::MakeShared<Impl>(omg_context);
  360. if (impl_ == nullptr) {
  361. REPORT_CALL_ERROR("E19999", "create Impl failed.");
  362. GELOGE(MEMALLOC_FAILED, "[Create][Impl] Make shared failed");
  363. return MEMALLOC_FAILED;
  364. }
  365. ErrorManager::GetInstance().SetStage(error_message::kInitialize, error_message::kOpsProtoInit);
  366. string opsproto_path;
  367. GetOpsProtoPath(opsproto_path);
  368. GELOGI("Get opsproto path is %s", opsproto_path.c_str());
  369. OpsProtoManager *manager = OpsProtoManager::Instance();
  370. map<string, string> option_tmp;
  371. option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path));
  372. (void)manager->Initialize(option_tmp);
  373. Status ret = impl_->graph_manager_.Initialize(options);
  374. if (ret != SUCCESS) {
  375. GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "[Call][Initialize] Graph manager initialize failed.");
  376. return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED;
  377. }
  378. // get ek file
  379. auto iter = options.find(EK_FILE);
  380. if (iter != options.end()) {
  381. impl_->save_param_.ek_file = iter->second;
  382. }
  383. // get cert file
  384. iter = options.find(CERT_FILE);
  385. if (iter != options.end()) {
  386. impl_->save_param_.cert_file = iter->second;
  387. }
  388. // get hw key file
  389. iter = options.find(HW_KEY_FILE);
  390. if (iter != options.end()) {
  391. impl_->save_param_.hw_key_file = iter->second;
  392. }
  393. // get private file
  394. iter = options.find(PRIVATE_KEY_FILE);
  395. if (iter != options.end()) {
  396. impl_->save_param_.pri_key_file = iter->second;
  397. }
  398. // get build mode
  399. iter = options.find(BUILD_MODE);
  400. if (iter != options.end()) {
  401. impl_->build_mode_ = iter->second;
  402. }
  403. // get build step
  404. iter = options.find(BUILD_STEP);
  405. if (iter != options.end()) {
  406. impl_->build_step_ = iter->second;
  407. }
  408. return SUCCESS;
  409. }
  410. Status GeGenerator::Finalize() {
  411. ErrorManager::GetInstance().SetStage(error_message::kFinalize, error_message::kFinalize);
  412. if (impl_ == nullptr) {
  413. return SUCCESS;
  414. }
  415. Status ret = impl_->graph_manager_.Finalize();
  416. if (ret != SUCCESS) {
  417. GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "[Call][Finalize] Graph manager finalize failed.");
  418. return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED;
  419. }
  420. return SUCCESS;
  421. }
  422. Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix,
  423. const vector<GeTensor> &inputs) {
  424. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  425. GELOGI("Start to generate offline model.");
  426. ModelBufferData model;
  427. return GenerateModel(graph, file_name_prefix, inputs, model, true);
  428. }
  429. Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ModelBufferData &model) {
  430. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  431. return GenerateModel(graph, "online", inputs, model, false);
  432. }
  433. Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) {
  434. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  435. Status ret = impl_->GenerateInfershapeGraph(graph);
  436. if (ret != SUCCESS) {
  437. GELOGE(ret, "[Call][GenerateInfershapeGraph] Dump infershape json failed");
  438. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  439. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  440. }
  441. return ret;
  442. }
  443. GELOGI("Generate infer shape graph success");
  444. return SUCCESS;
  445. }
  // Out-of-class definition of the static mutex declared in GeGenerator::Impl.
  std::mutex GeGenerator::Impl::mutex_;
  447. // Remove the space and tab before and after the string
  448. std::string GeGenerator::Impl::Trim(const std::string &str) {
  449. if (str.empty()) {
  450. return str;
  451. }
  452. std::string::size_type start = str.find_first_not_of(" \t\r\n");
  453. if (start == std::string::npos) {
  454. return str;
  455. }
  456. std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1;
  457. return str.substr(start, end);
  458. }
  459. // Parsing the command line
  460. bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) {
  461. std::string flag = "Version=";
  462. std::string temp = Trim(line);
  463. if (temp.empty()) {
  464. GELOGW("line is empty.");
  465. return false;
  466. }
  467. std::string::size_type pos = temp.find(flag);
  468. if (pos == std::string::npos) {
  469. GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str());
  470. return false;
  471. }
  472. if (temp.size() == flag.size()) {
  473. GELOGW("version information is empty. %s", line.c_str());
  474. return false;
  475. }
  476. version = temp.substr(pos + flag.size());
  477. return true;
  478. }
  479. bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) {
  480. // Normalize the path
  481. string resolved_file_path = RealPath(file_path.c_str());
  482. if (resolved_file_path.empty()) {
  483. GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str());
  484. return false;
  485. }
  486. std::ifstream fs(resolved_file_path, std::ifstream::in);
  487. if (!fs.is_open()) {
  488. GELOGW("Open %s failed.", file_path.c_str());
  489. return false;
  490. }
  491. std::string line;
  492. if (getline(fs, line)) {
  493. if (!ParseVersion(line, version)) {
  494. GELOGW("Parse version failed. content is [%s].", line.c_str());
  495. fs.close();
  496. return false;
  497. }
  498. } else {
  499. GELOGW("No version information found in the file path:%s", file_path.c_str());
  500. fs.close();
  501. return false;
  502. }
  503. fs.close(); // close the file
  504. return true;
  505. }
  506. // Set package version information in the model
  507. bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) {
  508. std::string path_base = ge::GELib::GetPath();
  509. path_base = path_base.substr(0, path_base.rfind('/'));
  510. path_base = path_base.substr(0, path_base.rfind('/') + 1);
  511. std::string version_path = path_base + "version.info";
  512. std::string version;
  513. if (!GetVersionFromPath(version_path, version)) {
  514. GELOGW("Get atc version information failed!");
  515. return false;
  516. }
  517. // set version info
  518. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) {
  519. GELOGW("Ge model set atc version failed!");
  520. return false;
  521. }
  522. return true;
  523. }
  524. // Set package version information in the model
  525. bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
  526. const char *path_env = std::getenv("ASCEND_OPP_PATH");
  527. if (path_env == nullptr) {
  528. GELOGW("Get environment variable ASCEND_OPP_PATH failed!");
  529. return false;
  530. }
  531. std::string version_path = path_env;
  532. version_path += "/version.info";
  533. std::string version;
  534. if (!GetVersionFromPath(version_path, version)) {
  535. GELOGW("Get opp version information failed!");
  536. return false;
  537. }
  538. // set version info
  539. if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) {
  540. GELOGW("Ge model set opp version failed!");
  541. return false;
  542. }
  543. return true;
  544. }
  545. bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
  546. std::string soc_version;
  547. (void)ge::GetContext().GetOption(ge::SOC_VERSION, soc_version);
  548. GELOGI("SetOmSystemInfo soc_version: %s", soc_version.c_str());
  549. if (!ge::AttrUtils::SetStr(obj, "soc_version", soc_version)) {
  550. GELOGW("SetStr of soc_version failed.");
  551. return false;
  552. }
  553. std::string framework_type;
  554. (void)ge::GetContext().GetOption(ge::FRAMEWORK_TYPE, framework_type);
  555. GELOGI("SetOmSystemInfo framework_type: %s", framework_type.c_str());
  556. auto iter = ge::kFwkTypeToStr.find(framework_type);
  557. if (iter == ge::kFwkTypeToStr.end()) {
  558. GELOGW("Can not find framework_type in the map.");
  559. return false;
  560. }
  561. if (!ge::AttrUtils::SetStr(obj, "framework_type", iter->second)) {
  562. GELOGW("SetStr of framework_type failed.");
  563. return false;
  564. }
  565. return true;
  566. }
  567. Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
  568. bool is_unknown_shape = false;
  569. Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  570. if (ret != SUCCESS) {
  571. GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
  572. ge_root_model->GetModelId());
  573. REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%u",
  574. ge_root_model->GetModelId());
  575. return FAILED;
  576. }
  577. GeModelPtr model_root = nullptr;
  578. if (is_unknown_shape) {
  579. model_root = MakeShared<GeModel>();
  580. GE_CHECK_NOTNULL(model_root);
  581. model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
  582. ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
  583. }
  584. ModelHelper model_helper;
  585. string model_name;
  586. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  587. Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
  588. model_name);
  589. if (name_ret != SUCCESS) {
  590. ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
  591. GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
  592. ge_root_model->GetRootGraph()->GetName().c_str());
  593. return PARAM_INVALID;
  594. }
  595. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  596. GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  597. GE_CHECK_NOTNULL(ge_model);
  598. ge_model->SetName(model_name);
  599. return SUCCESS;
  600. }
  601. Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
  602. ModelBufferData &model, bool is_offline) {
  603. rtContext_t ctx = nullptr;
  604. auto rt = rtCtxGetCurrent(&ctx);
  605. if (rt != RT_ERROR_NONE) {
  606. GELOGD("Current ctx is null.");
  607. ctx = nullptr;
  608. }
  609. std::function<void()> callback = [&]() {
  610. if (ctx != nullptr) {
  611. (void)rtCtxSetCurrent(ctx);
  612. }
  613. };
  614. GE_MAKE_GUARD(restore, callback);
  615. GeRootModelPtr ge_root_model = nullptr;
  616. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  617. impl_->is_offline_ = is_offline;
  618. Status ret = impl_->BuildModel(graph, inputs, ge_root_model);
  619. if (ret != SUCCESS) {
  620. GELOGE(ret, "[Build][Model] failed, ret:%d.", ret);
  621. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  622. GELOGE(FAILED, "[Call][Finalize] graph_manager finalize fail.");
  623. }
  624. return ret;
  625. }
  626. /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH no need save model;
  627. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER no need save model;
  628. /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB no need save model.
  629. if ((impl_->build_mode_ == BUILD_MODE_TUNING) &&
  630. (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER ||
  631. impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) {
  632. GELOGI("Build mode:%s with step:%s no need SaveModel.",
  633. impl_->build_mode_.c_str(),
  634. impl_->build_step_.c_str());
  635. return SUCCESS;
  636. }
  637. GE_CHECK_NOTNULL(ge_root_model);
  638. ret = SetModelNameForDump(ge_root_model);
  639. if (ret != SUCCESS) {
  640. return ret;
  641. }
  642. ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
  643. if (ret != SUCCESS) {
  644. GELOGE(ret, "[Save][RootModel] failed, ret:%d, file:%s", ret, file_name_prefix.c_str());
  645. if (impl_->graph_manager_.Finalize() != SUCCESS) {
  646. GELOGE(FAILED, "graph_manager finalize fail.");
  647. }
  648. return ret;
  649. }
  650. return SUCCESS;
  651. }
  652. namespace {
  653. bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) {
  654. bool is_need = true;
  655. // format and dtype is all reserved, stand for Optional input. When singleop scene
  656. if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) {
  657. is_need = false;
  658. }
  659. return is_need;
  660. }
  661. Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) {
  662. bool support_dynamic = true;
  663. bool is_dynamic = false;
  664. for (const auto &node : graph->GetDirectNode()) {
  665. GE_CHECK_NOTNULL(node);
  666. auto op_desc = node->GetOpDesc();
  667. GE_CHECK_NOTNULL(op_desc);
  668. if (op_desc->GetOpEngineName() != kAIcoreEngine) {
  669. continue;
  670. }
  671. if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) {
  672. is_dynamic = true;
  673. (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic);
  674. if (!support_dynamic) {
  675. GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str());
  676. break;
  677. }
  678. }
  679. }
  680. if (is_dynamic) {
  681. (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic);
  682. }
  683. return SUCCESS;
  684. }
  685. }
  686. bool GeGenerator::CheckNoAicore(const ComputeGraphPtr &graph) {
  687. for (const auto &node : graph->GetDirectNode()) {
  688. if (node == nullptr) {
  689. continue;
  690. }
  691. auto op_desc = node->GetOpDesc();
  692. if (op_desc == nullptr) {
  693. continue;
  694. }
  695. if (op_desc->GetOpEngineName() == kAIcoreEngine) {
  696. return false;
  697. }
  698. }
  699. return true;
  700. }
  701. void GeGenerator::RemoveConst(const vector<GeTensor> &inputs, vector<GeTensor> &outputs) {
  702. for (auto &input : inputs) {
  703. GeTensorDesc input_desc = input.GetTensorDesc();
  704. bool is_const = false;
  705. (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const);
  706. bool is_optional = IsOptional(input_desc);
  707. if (!is_optional && !is_const) {
  708. outputs.emplace_back(input);
  709. }
  710. }
  711. }
  712. Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  713. const vector<GeTensor> &outputs) {
  714. GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID);
  715. if (!inputs.empty() && (inputs.size() != op_desc->GetAllInputsSize())) {
  716. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  717. {op_desc->GetName(), op_desc->GetType(), "inputs size" + FmtToStr(op_desc->GetAllInputsSize()),
  718. "tensor size is " + FmtToStr(inputs.size())});
  719. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Inputs size: %zu, not equal",
  720. inputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetAllInputsSize());
  721. return PARAM_INVALID;
  722. }
  723. if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) {
  724. ErrorManager::GetInstance().ATCReportErrMessage("E14001", {"opname", "optype", "value", "reason"},
  725. {op_desc->GetName(), op_desc->GetType(), "outputs size" + FmtToStr(op_desc->GetOutputsSize()),
  726. "tensor size is " + FmtToStr(outputs.size())});
  727. GELOGE(PARAM_INVALID, "[Check][Param] Tensor size: %zu, op:%s(%s) Outputs size: %zu, not equal",
  728. outputs.size(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), op_desc->GetOutputsSize());
  729. return PARAM_INVALID;
  730. }
  731. return SUCCESS;
  732. }
  733. Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc, Graph &graph) {
  734. GE_CHECK_NOTNULL(op_desc);
  735. if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) {
  736. auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType());
  737. if (node_op.IsEmpty()) {
  738. GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str());
  739. } else {
  740. GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str());
  741. auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op);
  742. if (temp_op_desc == nullptr) {
  743. REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s",
  744. op_desc->GetType().c_str());
  745. GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str());
  746. return FAILED;
  747. }
  748. if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) {
  749. GELOGW("InferFormatForSingleOp UpdateInputName failed");
  750. }
  751. if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) {
  752. GELOGW("InferFormatForSingleOp UpdateOutputName failed");
  753. }
  754. }
  755. node_op.BreakConnect();
  756. }
  757. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  758. GE_CHECK_NOTNULL(comp_graph);
  759. auto node = comp_graph->FindNode(op_desc->GetName());
  760. GE_CHECK_NOTNULL(node);
  761. auto op = OpDescUtils::CreateOperatorFromNode(node);
  762. auto ret = op_desc->CallInferFormatFunc(op);
  763. if (ret != GRAPH_SUCCESS) {
  764. REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail",
  765. op_desc->GetName().c_str());
  766. GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str());
  767. return FAILED;
  768. }
  769. return SUCCESS;
  770. }
  771. Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
  772. const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
  773. bool is_offline, int32_t compile_flag) {
  774. GELOGD("Inputs size is %zu, outputs size is %zu.", inputs.size(), outputs.size());
  775. GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
  776. impl_->is_offline_ = is_offline;
  777. (void)AttrUtils::SetBool(op_desc, ATTR_SINGLE_OP_SCENE, true);
  778. if (CheckForSingleOp(op_desc, inputs, outputs) != SUCCESS) {
  779. GELOGE(PARAM_INVALID, "[Check][Param] input param is invalid when build single op:%s!",
  780. op_desc->GetName().c_str());
  781. return PARAM_INVALID;
  782. }
  783. OmgContext &omg_context = impl_->omg_context_;
  784. omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc);
  785. if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) {
  786. impl_->is_singleop_unregistered_ = true;
  787. }
  788. // 0. Save original attributes.
  789. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc);
  790. GE_CHECK_NOTNULL(op_desc_tmp);
  791. bool fuzz_compile_flag = false;
  792. if (!HasShapeRange(inputs) && compile_flag == kFuzzBuildPattern) {
  793. fuzz_compile_flag = true;
  794. }
  795. (void)AttrUtils::SetBool(op_desc, ATTR_NAME_FUZZ_BUILD, fuzz_compile_flag);
  796. impl_->omg_context_.fuzz_compile_flag = fuzz_compile_flag;
  797. // 1. Create ComputeGraph.
  798. string name = ge::CurrentTimeInStr() + "_" + model_file_name;
  799. Graph graph;
  800. GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph),
  801. "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str());
  802. GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc, graph));
  803. // 2. check engine type when compile online
  804. if (model_file_name == kFileNameSuffix) {
  805. auto comp_graph = GraphUtils::GetComputeGraph(graph);
  806. GE_CHECK_NOTNULL(comp_graph);
  807. auto node = comp_graph->FindNode(op_desc->GetName());
  808. Status ret = CheckEngineTypeSupport(node, engine_type);
  809. if (ret != SUCCESS) {
  810. GELOGE(ret, "[Check][EngineType]not support node:%s with engine of %d.", node->GetName().c_str(), engine_type);
  811. return ret;
  812. }
  813. }
  814. GELOGI("ATC parser success in single op build.");
  815. GeRootModelPtr ge_root_model = nullptr;
  816. vector<GeTensor> data_inputs;
  817. RemoveConst(inputs, data_inputs);
  818. GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, data_inputs, ge_root_model));
  819. map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs();
  820. GE_CHECK_NOTNULL(ge_root_model);
  821. GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
  822. map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  823. if (name_to_ge_model.empty()) {
  824. REPORT_CALL_ERROR("E19999", "GetSubgraphInstanceNameToModel failed.");
  825. GELOGE(PARAM_INVALID, "[Get][Name] GetSubgraphInstanceNameToModel is empty.");
  826. return PARAM_INVALID;
  827. }
  828. const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph();
  829. GeModelPtr &ge_model = name_to_ge_model.begin()->second;
  830. GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph));
  831. GELOGI("After build model, The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str());
  832. bool all_shape = false;
  833. (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape);
  834. GELOGD("Node: %s, all_shape is %d, compile_flag is %d.", op_desc->GetName().c_str(), all_shape, compile_flag);
  835. (void)AttrUtils::SetInt(ge_model, ATTR_NAME_BUILD_MODE, fuzz_compile_flag);
  836. if (all_shape) {
  837. (void)AttrUtils::SetBool(ge_model, kAicpuAllshape, all_shape);
  838. }
  839. if (all_shape && CheckNoAicore(root_graph)) {
  840. GELOGD("Get aicpu all_shape kernel!");
  841. vector<GeTensor> inputs_dynamic;
  842. vector<GeTensor> outputs_dynamic;
  843. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(inputs, inputs_dynamic));
  844. GE_CHK_STATUS_RET_NOLOG(ResetTensorVecShape(outputs, outputs_dynamic));
  845. GE_CHK_STATUS_RET_NOLOG(
  846. impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs_dynamic, outputs_dynamic));
  847. } else if (fuzz_compile_flag) {
  848. GeAttrValue::LIST_NAMED_ATTRS fuzz_build_attrs;
  849. if (GetFuzzBuildAttrs(op_desc, ge_root_model, fuzz_build_attrs) != SUCCESS) {
  850. GELOGE(FAILED, "[Get][FuzzRet]Failed to get fuzz build result of %s.", op_desc->GetName().c_str());
  851. return FAILED;
  852. }
  853. if (!fuzz_build_attrs.empty()) {
  854. GE_CHK_BOOL_EXEC(AttrUtils::SetListNamedAttrs(ge_model, ATTR_NAME_FUZZ_BUILD_RES_ATTRS, fuzz_build_attrs),
  855. REPORT_CALL_ERROR("E19999", "Set model:%s(id:%u) attr:%s failed.",
  856. ge_model->GetName().c_str(), ge_model->GetModelId(),
  857. ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  858. return FAILED, "Set model:%s(id:%u) attr:%s failed.",
  859. ge_model->GetName().c_str(), ge_model->GetModelId(), ATTR_NAME_FUZZ_BUILD_RES_ATTRS.c_str());
  860. }
  861. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  862. } else {
  863. GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs));
  864. }
  865. GELOGI("Start save GeModel to Model buffer");
  866. GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff));
  867. return SUCCESS;
  868. }
  869. /**
  870. * @ingroup ge
  871. * @brief Compiling a single operator into an offline model
  872. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  873. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  874. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  875. * @param [in] const string &model_file_name: Offline model filename.
  876. * @param [in] compile_flag: op build flag from atc
  877. * @return SUCCESS handle successfully / others handle failed
  878. */
  879. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  880. const vector<GeTensor> &outputs, const string &model_file_name,
  881. int32_t compile_flag) {
  882. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  883. GELOGI("Start to build single op offline model, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  884. ModelBufferData model_buff;
  885. OpEngineType engine_type = ENGINE_SYS;
  886. Status status = BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true, compile_flag);
  887. GELOGI("Finish build single offline model, status: %u", status);
  888. return status;
  889. }
  890. /**
  891. * @ingroup ge
  892. * @brief Compiling a single operator into online buffer
  893. * @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file
  894. * @param [in] vector<GeTensor> &inputs: Operator input data description information.
  895. * @param [in] vector<GeTensor> &outputs: Operator output data description information.
  896. * @param [in] engine_type: specific engine.
  897. * @param [in] compile_flag: op build flag, compile flag by acl
  898. * @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op.
  899. * @return SUCCESS handle successfully / others handle failed
  900. */
  901. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  902. const vector<GeTensor> &outputs, OpEngineType engine_type,
  903. ModelBufferData &model_buff) {
  904. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  905. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  906. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false);
  907. GELOGI("Finish build single online model, status: %u", status);
  908. return status;
  909. }
  910. Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  911. const vector<GeTensor> &outputs, OpEngineType engine_type, int32_t compile_flag,
  912. ModelBufferData &model_buff) {
  913. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  914. GELOGI("Start to build single op online, input size: %zu, output size: %zu", inputs.size(), outputs.size());
  915. Status status = BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false,
  916. compile_flag);
  917. GELOGI("Finish build single online model, status: %u", status);
  918. return status;
  919. }
  920. Status GeGenerator::BuildSingleOpGraph(OpDescPtr &op_desc, const vector<GeTensor> &inputs,
  921. const vector<GeTensor> &outputs, std::string graph_name, Graph &graph) {
  922. ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(graph_name);
  923. GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR);
  924. // 1. Add Node to ComputeGraph.
  925. NodePtr op_node = compute_graph->AddNode(op_desc);
  926. GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR);
  927. // 2. Create InputData node.
  928. int32_t arg_index = 0;
  929. int32_t data_index = 0;
  930. if (inputs.empty()) {
  931. for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
  932. GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR);
  933. if (!IsNeedConnectInputOpForSingleOp(*input_desc)) {
  934. continue;
  935. }
  936. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false, data_index));
  937. arg_index++;
  938. }
  939. } else {
  940. for (const auto &in_desc : inputs) {
  941. GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, in_desc.GetTensorDesc(), arg_index, true, data_index));
  942. arg_index++;
  943. }
  944. }
  945. // 3. Create Output node.
  946. if (!outputs.empty()) {
  947. GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs));
  948. }
  949. // dump ComputeGraph node.
  950. compute_graph->Dump();
  951. graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
  952. return SUCCESS;
  953. }
  954. Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
  955. const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) {
  956. GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID);
  957. GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS,
  958. (void)graph_manager_.Finalize();
  959. return FAILED);
  960. return SUCCESS;
  961. }
  962. Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) {
  963. // set atc version
  964. if (!SetAtcVersionInfo(*(model.get()))) {
  965. GELOGW("SetPackageVersionInfo of atc failed!");
  966. }
  967. // set opp version
  968. if (!SetOppVersionInfo(*(model.get()))) {
  969. GELOGW("SetPackageVersionInfo of ops failed!");
  970. }
  971. ModelHelper model_helper;
  972. model_helper.SetSaveMode(is_offline_);
  973. Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff);
  974. if (ret != SUCCESS) {
  975. GELOGE(ret, "[Call][SaveToOmModel] Save to om model failed");
  976. return ret;
  977. }
  978. return SUCCESS;
  979. }
  980. Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
  981. ModelBufferData &model_buff) {
  982. bool is_unknown_shape = false;
  983. auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
  984. if (ret != SUCCESS) {
  985. REPORT_CALL_ERROR("E19999", "root model(id:%u) CheckIsUnknownShape failed, ret:%d",
  986. ge_root_model->GetModelId(), ret);
  987. GELOGE(FAILED, "[Check][RootModel] is unkonwn shape failed, ret:%d", ret);
  988. return FAILED;
  989. }
  990. GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
  991. GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(),
  992. REPORT_CALL_ERROR("E19999", "root model(id:%u) has no sub model.", ge_root_model->GetModelId());
  993. return FAILED, "[Get][SubModel] ge root model has no sub model")
  994. GeModelPtr model_root = nullptr;
  995. if (is_unknown_shape) {
  996. auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
  997. model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
  998. } else {
  999. model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
  1000. }
  1001. GE_CHECK_NOTNULL(model_root);
  1002. // set atc version
  1003. if (!SetAtcVersionInfo(*(model_root.get()))) {
  1004. GELOGW("SetPackageVersionInfo of atc failed!");
  1005. }
  1006. // set opp version
  1007. if (!SetOppVersionInfo(*(model_root.get()))) {
  1008. GELOGW("SetPackageVersionInfo of ops failed!");
  1009. }
  1010. if (!SetOmSystemInfo(*(model_root.get()))) {
  1011. GELOGW("SetOmsystemInfo failed!");
  1012. }
  1013. ModelHelper model_helper;
  1014. model_helper.SetSaveMode(is_offline_);
  1015. ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
  1016. if (ret != SUCCESS) {
  1017. REPORT_CALL_ERROR("E19999", "SaveToOmRootModel failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1018. GELOGE(ret, "[Call][SaveToOmRootModel] failed, ret:%d, model id:%u", ret, ge_root_model->GetModelId());
  1019. return ret;
  1020. }
  1021. return SUCCESS;
  1022. }
  1023. Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
  1024. GeRootModelPtr &ge_root_model) {
  1025. static std::atomic<GraphId> atomic_graph_id(0);
  1026. auto graph_id = atomic_graph_id.fetch_add(1);
  1027. const std::map<std::string, std::string> options;
  1028. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1029. if (ret != SUCCESS) {
  1030. REPORT_CALL_ERROR("E19999", "add graph(id:%u) failed, ret:%d", graph_id, ret);
  1031. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] fail, graph id: %u", graph_id);
  1032. (void)graph_manager_.Finalize();
  1033. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1034. }
  1035. graph_manager_.SetOptionsRunGraphFlag(false);
  1036. static std::atomic<uint64_t> atomic_session_id(0);
  1037. auto session_id = atomic_session_id.fetch_add(1);
  1038. // This is a temporary add for graph with variable
  1039. auto version = static_cast<int32_t>(SessionVersion::ClOUD_VERSION);
  1040. ret = VarManager::Instance(session_id)->Init(version, session_id, kDefaultDeviceId, kDefaultJobId);
  1041. GELOGI("Start init var instance, session_id %lu", session_id);
  1042. if (ret != SUCCESS) {
  1043. GELOGW("Failed init var instance, session_id %lu", session_id);
  1044. }
  1045. if (is_singleop_unregistered_) {
  1046. ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id);
  1047. } else {
  1048. ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id);
  1049. }
  1050. ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kOther);
  1051. if (ret != SUCCESS) {
  1052. REPORT_CALL_ERROR("E19999", "build graph failed, graph id:%u, ret:%d", graph_id, ret);
  1053. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "[Build][Graph] fail, graph id: %u", graph_id);
  1054. ret = GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  1055. }
  1056. RtContextUtil::GetInstance().DestroyRtContexts(session_id);
  1057. Analyzer::GetInstance()->DestroySessionJsonObject(session_id);
  1058. VarManagerPool::Instance().RemoveVarManager(session_id);
  1059. return ret;
  1060. }
  1061. Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) {
  1062. static std::atomic<GraphId> atomic_graph_id(0);
  1063. auto graph_id = atomic_graph_id.fetch_add(1);
  1064. const std::map<std::string, std::string> options;
  1065. Status ret = graph_manager_.AddGraph(graph_id, graph, options, omg_context_);
  1066. if (ret != SUCCESS) {
  1067. REPORT_CALL_ERROR("E19999", "add graph failed, graph id:%u, ret:%d", graph_id, ret);
  1068. GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "[Add][Graph] failed, graph id: %u", graph_id);
  1069. (void)graph_manager_.Finalize();
  1070. return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED;
  1071. }
  1072. ret = graph_manager_.GenerateInfershapeGraph(graph_id);
  1073. if (ret != SUCCESS) {
  1074. REPORT_CALL_ERROR("E19999", "GenerateInfershapeGraph failed, graph id:%u, ret:%d", graph_id, ret);
  1075. GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED,
  1076. "[Generate][Graph] failed, graph id:%u, ret:%d", graph_id, ret);
  1077. return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED;
  1078. }
  1079. return SUCCESS;
  1080. }
  1081. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：