You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

model_cache_helper.cc 66 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <climits>
  17. #include <cstdio>
  18. #include <fstream>
  19. #include <functional>
  20. #include "common/ge/ge_util.h"
  21. #include "common/helper/model_cache_helper.h"
  22. #include "common/types.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_types.h"
  25. #include "framework/common/helper/model_helper.h"
  26. #include "framework/common/util.h"
  27. #include "graph/detail/attributes_holder.h"
  28. #include "graph/detail/model_serialize_imp.h"
  29. #include "graph/load/new_model_manager/davinci_model_parser.h"
  30. #include "graph/model.h"
  31. #include "graph/utils/graph_utils.h"
  32. #include "graph/utils/tensor_utils.h"
  33. #include "init/gelib.h"
  34. #include "proto/ge_ir.pb.h"
  35. using namespace std;
  namespace {
  // Name passed to GetOpsKernelInfoStore() when nodes are recompiled.
  constexpr const char *kTbeKernelInfoStoreName = "AIcoreEngine";
  // Placeholder graph name used while the graph is serialized for hashing.
  constexpr const char *kGraphName = "temp_name";

  // ---- Keys of json ----
  // Manifest (cache info) keys.
  constexpr const char *kNodeNum = "nodeNum";
  constexpr const char *kEdgeNum = "edgeNum";
  constexpr const char *kGraphHash = "graphHash";
  constexpr const char *kNodeHash = "nodeHash";
  constexpr const char *kHash = "hash";
  // Session / memory configuration keys.
  constexpr const char *kSessionId = "sessionId";
  constexpr const char *kDeviceId = "deviceId";
  constexpr const char *kJobId = "jobId";
  constexpr const char *kGraphMemMaxSize = "graphMemMaxSize";
  constexpr const char *kVarMemMaxSize = "varMemMaxSize";
  constexpr const char *kVarMemLogicBase = "varMemLogicBase";
  constexpr const char *kUseMaxMemSize = "useMaxMemSize";
  constexpr const char *kMemResourceMap = "memResourceMap";
  constexpr const char *kMemType = "memType";
  constexpr const char *kTotalSize = "totalSize";
  constexpr const char *kVarMemSize = "varMemSize";
  // Variable resource keys.
  constexpr const char *kVarResource = "varResource";
  constexpr const char *kVarAddrMgrMap = "varAddrMgrMap";
  constexpr const char *kName = "name";
  constexpr const char *kAddress = "address";
  constexpr const char *kOffset = "offset";
  constexpr const char *kMemoryType = "memoryType";
  // Tensor description keys.
  constexpr const char *kTensorDesc = "tensorDesc";
  constexpr const char *kDataType = "dataType";
  constexpr const char *kShape = "shape";
  constexpr const char *kLayout = "layout";
  constexpr const char *kOriginDataType = "originDataType";
  constexpr const char *kOriginShape = "originShape";
  constexpr const char *kOriginLayout = "originLayout";
  constexpr const char *kRealDimCnt = "realDimCnt";
  constexpr const char *kCurVarTensorDescMap = "curVarTensorDescMap";
  // Trans road keys.
  constexpr const char *kTransRoads = "transRoads";
  constexpr const char *kTransRoad = "transRoad";
  constexpr const char *kNodeType = "nodeType";
  constexpr const char *kInputTensorDesc = "inputTensorDesc";
  constexpr const char *kOutputTensorDesc = "outputTensorDesc";
  // Graph id keys.
  constexpr const char *kChangedGraphId = "changedGraphId";
  constexpr const char *kAllocatedGraphId = "allocatedGraphId";
  constexpr const char *kGraphId = "graphId";
  // Broadcast info keys.
  constexpr const char *kVarBroadcastInfo = "varBroadcastInfo";
  constexpr const char *kBroadcastName = "broadcastName";
  constexpr const char *kIdx = "idx";
  constexpr const char *kInputOffset = "inputOffset";
  constexpr const char *kInputSize = "inputSize";
  constexpr const char *kOutputOffset = "outputOffset";
  constexpr const char *kOutputSize = "outputSize";

  // ---- Suffix of cache files ----
  constexpr const char *kBeforeVarManagerSuffix = "_before_build_var_manager.json";
  constexpr const char *kAfterVarManagerSuffix = "_after_build_var_manager.json";
  constexpr const char *kManifestSuffix = ".manifest";
  constexpr const char *kOmSuffix = ".om";
  }  // namespace
  92. namespace ge {
  93. map<uint32_t, uint32_t> ModelCacheHelper::graph_id_run_times_;
  94. ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph)
  95. : session_id_(session_id),
  96. graph_id_(graph_id),
  97. compute_graph_(compute_graph),
  98. is_cache_path_valid_for_output(false) {
  99. if (graph_id_run_times_.count(graph_id) == 0) {
  100. graph_id_run_times_[graph_id] = 1;
  101. } else {
  102. graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1;
  103. }
  104. for (const auto &node : compute_graph_->GetDirectNode()) {
  105. bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) ||
  106. (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP);
  107. if (!is_variable) {
  108. continue;
  109. }
  110. var_names_.insert(node->GetName());
  111. }
  112. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  113. if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) {
  114. std::string cache_path = instance_ptr->GetIncreBuildCachePath();
  115. GELOGD("Incre build path conf: %s", cache_path.c_str());
  116. string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix;
  117. if (CheckOutputPathValid(fake_file_path)) {
  118. is_cache_path_valid_for_output = true;
  119. } else {
  120. GELOGW("Invalid cache path for output.");
  121. }
  122. std::string real_cache_path = RealPath(cache_path.c_str());
  123. if (real_cache_path.empty()) {
  124. GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str());
  125. return;
  126. }
  127. cache_path_ = real_cache_path + '/';
  128. GELOGD("Try to use incre build cache path: %s", cache_path_.c_str());
  129. }
  130. }
  131. ModelCacheHelper::~ModelCacheHelper() { var_names_.clear(); }
  132. bool ModelCacheHelper::IsModelCacheHit() const {
  133. CacheInfo cache_info;
  134. if (GetCacheInfo(cache_info) != SUCCESS) {
  135. GELOGI("Get cache info of graph id[%u] failed.", graph_id_);
  136. return false;
  137. }
  138. // Check number of nodes and edges first.
  139. if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) {
  140. GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_);
  141. return false;
  142. }
  143. size_t edge_num = 0;
  144. for (const auto &node : compute_graph_->GetDirectNode()) {
  145. for (const auto &anchor : node->GetAllInAnchors()) {
  146. edge_num += anchor->GetPeerAnchors().size();
  147. }
  148. }
  149. if (cache_info.edge_num != edge_num) {
  150. GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_);
  151. return false;
  152. }
  153. size_t compute_graph_hash;
  154. auto ret = GetComputeGraphHash(compute_graph_hash);
  155. if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) {
  156. GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_);
  157. return false;
  158. }
  159. if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) {
  160. GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_);
  161. return false;
  162. }
  163. string var_manager_cache =
  164. to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix;
  165. Json var_manager_json;
  166. if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) {
  167. GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str());
  168. return false;
  169. }
  170. if (!IsVarManagerSameAsCache(var_manager_json)) {
  171. GELOGI("Graph id[%u] cache miss: the VarManager does not match the cache info.", graph_id_);
  172. return false;
  173. }
  174. GELOGI("Graph id[%u] cache hit.", graph_id_);
  175. return true;
  176. }
  177. Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) {
  178. if (compute_graph->IsValid()) {
  179. compute_graph_ = compute_graph;
  180. var_names_.clear();
  181. for (const auto &node : compute_graph_->GetDirectNode()) {
  182. bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) ||
  183. (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP);
  184. if (!is_variable) {
  185. continue;
  186. }
  187. var_names_.insert(node->GetName());
  188. }
  189. return SUCCESS;
  190. } else {
  191. GELOGW("Invalid compute graph.");
  192. return FAILED;
  193. }
  194. }
  195. Status ModelCacheHelper::ClearCache(uint32_t graph_id) const {
  196. if (!is_cache_path_valid_for_output) {
  197. GELOGW("Invalid cache path.");
  198. return SUCCESS;
  199. }
  200. string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  201. string manifest_file_path = RealPath(manifest_file.c_str());
  202. int ret;
  203. if (!manifest_file_path.empty()) {
  204. ret = remove(manifest_file_path.c_str());
  205. // If remove file failed, print the warning log
  206. if (ret != 0) {
  207. GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str());
  208. }
  209. }
  210. string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  211. string before_var_manager_file_path = RealPath(before_var_manager_file.c_str());
  212. if (!before_var_manager_file_path.empty()) {
  213. ret = remove(before_var_manager_file_path.c_str());
  214. if (ret != 0) {
  215. GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str());
  216. }
  217. }
  218. string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  219. string after_var_manager_file_path = RealPath(after_var_manager_file.c_str());
  220. if (!after_var_manager_file_path.empty()) {
  221. ret = remove(after_var_manager_file_path.c_str());
  222. if (ret != 0) {
  223. GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str());
  224. }
  225. }
  226. string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  227. string om_file_path = RealPath(om_file.c_str());
  228. if (!om_file_path.empty()) {
  229. ret = remove(om_file_path.c_str());
  230. if (ret != 0) {
  231. GELOGW("Clear cache [%s] failed.", om_file_path.c_str());
  232. }
  233. }
  234. return SUCCESS;
  235. }
  236. Status ModelCacheHelper::RecoverVarManagerFromCache() const {
  237. string var_manager_cache =
  238. to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix;
  239. Json var_manager_json;
  240. if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) {
  241. GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str());
  242. return FAILED;
  243. }
  244. Json mem_resource_json = move(var_manager_json[kMemResourceMap]);
  245. auto ret = RecoverMemResource(mem_resource_json);
  246. if (ret != SUCCESS) {
  247. GELOGW("Recover VarManager from cache failed.[MemResource]");
  248. return FAILED;
  249. }
  250. Json var_resource_json = move(var_manager_json[kVarResource]);
  251. ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]);
  252. if (ret != SUCCESS) {
  253. GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]");
  254. return FAILED;
  255. }
  256. ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]);
  257. if (ret != SUCCESS) {
  258. GELOGW("Recover VarManager from cache failed.[ChangedGraphId]");
  259. return FAILED;
  260. }
  261. ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]);
  262. if (ret != SUCCESS) {
  263. GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]");
  264. return FAILED;
  265. }
  266. ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]);
  267. if (ret != SUCCESS) {
  268. GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]");
  269. return FAILED;
  270. }
  271. ret = RecoverTransRoads(var_resource_json[kTransRoads]);
  272. if (ret != SUCCESS) {
  273. GELOGW("Recover VarManager from cache failed.[TransRoads]");
  274. return FAILED;
  275. }
  276. GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str());
  277. return SUCCESS;
  278. }
  279. Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector<NodePtr> &nodes) {
  280. std::shared_ptr<GELib> instance = ge::GELib::GetInstance();
  281. if (instance == nullptr || !instance->InitFlag()) {
  282. GELOGW("RecompileNodes failed.");
  283. return ge::GE_CLI_GE_NOT_INITIALIZED;
  284. }
  285. // Collect aicore ops for recompile
  286. for (auto &node : graph->GetDirectNode()) {
  287. if (node == nullptr) {
  288. continue;
  289. }
  290. auto op_desc = node->GetOpDesc();
  291. if (op_desc == nullptr) {
  292. continue;
  293. }
  294. // Get op kernel lib name
  295. string kernel_lib_name = op_desc->GetOpKernelLibName();
  296. if (kernel_lib_name.empty()) {
  297. // reset op kernel lib
  298. (void)instance->DNNEngineManagerObj().GetDNNEngineName(node);
  299. kernel_lib_name = op_desc->GetOpKernelLibName();
  300. if (kernel_lib_name.empty()) {
  301. GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str());
  302. continue;
  303. }
  304. }
  305. }
  306. return SUCCESS;
  307. }
  308. Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) {
  309. std::shared_ptr<GELib> instance = ge::GELib::GetInstance();
  310. if (instance == nullptr || !instance->InitFlag()) {
  311. GELOGW("RecompileNodes failed.");
  312. return ge::GE_CLI_GE_NOT_INITIALIZED;
  313. }
  314. // Get aicore ops kernel info store.
  315. OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kTbeKernelInfoStoreName);
  316. if (kernel_info == nullptr) {
  317. GELOGW("Get %s ops kernel info store failed", kTbeKernelInfoStoreName);
  318. return INTERNAL_ERROR;
  319. }
  320. auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
  321. vector<NodePtr> node_vec;
  322. auto ret = GetNodesNeedRecompile(compute_graph, node_vec);
  323. GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Get nodes need recompiling failed");
  324. // Recompile aicore ops
  325. ret = kernel_info->CompileOp(node_vec);
  326. GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Recompile op failed");
  327. const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore();
  328. TBEKernelStore tbe_kernel_store;
  329. for (const ge::NodePtr &n : compute_graph->GetDirectNode()) {
  330. auto node_op_desc = n->GetOpDesc();
  331. GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
  332. TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
  333. if (tbe_kernel == nullptr) {
  334. // Load tbe kernel from tbe_kernel_store to op if op was not recompiled
  335. auto op_desc = n->GetOpDesc();
  336. tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc);
  337. GELOGD("LoadOmModelFromCache: Load tbe kernel bin to op desc[%s].", op_desc->GetName().c_str());
  338. }
  339. tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
  340. GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);
  341. // Refresh tbe kernel in tbe_kernel_store
  342. tbe_kernel_store.AddTBEKernel(tbe_kernel);
  343. GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str());
  344. }
  345. GE_CHK_BOOL_EXEC_WARN(tbe_kernel_store.Build(), return FAILED, "TBE Kernels store build failed!");
  346. ge_model->SetTBEKernelStore(tbe_kernel_store);
  347. return SUCCESS;
  348. }
  349. Status ModelCacheHelper::GetNodesHash(map<std::string, size_t> &hash_map) const {
  350. vector<NodePtr> nodes;
  351. GraphUtils::TopologicalSortingByName(compute_graph_, nodes);
  352. ModelSerializeImp model_serialize_imp;
  353. std::hash<string> node_hash;
  354. for (const auto &node : nodes) {
  355. if (node == nullptr) {
  356. continue;
  357. }
  358. proto::OpDef op_def;
  359. bool is_framework_op = (node->GetType() == FRAMEWORKOP);
  360. int32_t framework_type = 0;
  361. if (is_framework_op) {
  362. AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type);
  363. AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, 0);
  364. }
  365. bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op);
  366. op_def.set_id(0); // Id of op is not stable because of parallel parsing
  367. // Clear weights attr in constant.
  368. auto attr = op_def.mutable_attr();
  369. if (op_def.type() == CONSTANT || op_def.type() == CONSTANTOP) {
  370. attr->erase(ATTR_NAME_WEIGHTS);
  371. }
  372. if (is_framework_op) {
  373. AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type);
  374. }
  375. if (!ret) {
  376. GELOGW("Fail to serialize node[%s].", node->GetName().c_str());
  377. return INTERNAL_ERROR;
  378. }
  379. string prototxt;
  380. ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt);
  381. if (!ret) {
  382. GELOGW("Print OpDef to string failed.");
  383. hash_map.clear();
  384. return INTERNAL_ERROR;
  385. }
  386. size_t hash_code = node_hash(prototxt);
  387. hash_map[node->GetName()] = hash_code;
  388. }
  389. return SUCCESS;
  390. }
  391. Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const {
  392. proto::GraphDef graph_proto;
  393. ModelSerializeImp model_serialize_imp;
  394. // The name of compute graph may be generated randomly, so replace it temporarily.
  395. const string origin_name = compute_graph_->GetName();
  396. compute_graph_->SetName(kGraphName);
  397. bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto);
  398. graph_proto.clear_op();
  399. if (!serialize_ret) {
  400. GELOGW("Serialize graph failed.");
  401. hash = 0;
  402. return INTERNAL_ERROR;
  403. }
  404. compute_graph_->SetName(origin_name);
  405. // Generate proto text of GraphDef
  406. string prototxt;
  407. bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt);
  408. if (!print_ret) {
  409. GELOGW("Print GraphDef to string failed.");
  410. hash = 0;
  411. return INTERNAL_ERROR;
  412. }
  413. // Get the hash code of proto text
  414. std::hash<string> graph_hash;
  415. hash = graph_hash(prototxt);
  416. return SUCCESS;
  417. }
  418. Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const {
  419. if (!is_cache_path_valid_for_output) {
  420. GELOGW("Invalid cache path.");
  421. return PARAM_INVALID;
  422. }
  423. // Check whether the manifest exists, if not, create it.
  424. string real_path = RealPath(cache_path_.c_str());
  425. if (real_path.empty()) {
  426. GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str());
  427. return FAILED;
  428. }
  429. const string path = cache_path_ + file_name;
  430. const int FILE_AUTHORITY = 0600;
  431. int fd = mmOpen2(path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, FILE_AUTHORITY);
  432. if (fd < 0) {
  433. GELOGW("Fail to open the file: %s.", path.c_str());
  434. return INTERNAL_ERROR;
  435. }
  436. if (mmClose(fd) != 0) {
  437. GELOGW("Fail to close the file: %s.", path.c_str());
  438. return INTERNAL_ERROR;
  439. }
  440. // Write json into cache file
  441. ofstream ofs;
  442. ofs.open(path);
  443. if (!ofs.is_open()) {
  444. GELOGW("Fail to open the file: %s.", path.c_str());
  445. return INTERNAL_ERROR;
  446. }
  447. ofs << json << std::endl;
  448. ofs.close();
  449. return SUCCESS;
  450. }
  451. Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const {
  452. if (!json.is_null()) {
  453. GELOGW("Input param json type should be null.");
  454. return PARAM_INVALID;
  455. }
  456. string real_path = RealPath(cache_path_.c_str());
  457. if (real_path.empty()) {
  458. GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str());
  459. return FAILED;
  460. }
  461. const string path = cache_path_ + file_name;
  462. if (!CheckInputPathValid(path)) {
  463. GELOGW("Invalid cache path for input:%s.", path.c_str());
  464. return FAILED;
  465. }
  466. string cache_real_path = RealPath(path.c_str());
  467. if (cache_real_path.empty()) {
  468. GELOGI("File[%s] is not found.", path.c_str());
  469. return FAILED;
  470. }
  471. // Read json from cache file
  472. ifstream ifs;
  473. ifs.open(path);
  474. if (!ifs.is_open()) {
  475. GELOGW("Fail to open the file: %s.", path.c_str());
  476. return INTERNAL_ERROR;
  477. }
  478. try {
  479. ifs >> json;
  480. } catch (nlohmann::detail::parse_error e) {
  481. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  482. return INTERNAL_ERROR;
  483. } catch (nlohmann::detail::invalid_iterator e) {
  484. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  485. return INTERNAL_ERROR;
  486. } catch (nlohmann::detail::type_error e) {
  487. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  488. return INTERNAL_ERROR;
  489. } catch (nlohmann::detail::out_of_range e) {
  490. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  491. return INTERNAL_ERROR;
  492. } catch (nlohmann::detail::other_error e) {
  493. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  494. return INTERNAL_ERROR;
  495. }
  496. if (!json.is_object()) {
  497. GELOGW("Fail to load the json file: %s.", path.c_str());
  498. return INTERNAL_ERROR;
  499. }
  500. return SUCCESS;
  501. }
  502. Status ModelCacheHelper::SaveCacheInfoToCache() const {
  503. // Generate cache json
  504. // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356}
  505. Json cache_json;
  506. try {
  507. cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize();
  508. size_t edge_num = 0;
  509. for (const auto &node : compute_graph_->GetDirectNode()) {
  510. for (const auto &anchor : node->GetAllInAnchors()) {
  511. edge_num += anchor->GetPeerAnchors().size();
  512. }
  513. }
  514. cache_json[kEdgeNum] = edge_num;
  515. size_t hash = 0;
  516. auto ret = GetComputeGraphHash(hash);
  517. if (ret != SUCCESS) {
  518. GELOGW("Error occur when generate graph hash code.");
  519. return ret;
  520. }
  521. cache_json[kGraphHash] = hash;
  522. Json nodes_hash_json;
  523. ret = GetNodesHashMapJson(nodes_hash_json);
  524. if (ret != SUCCESS) {
  525. GELOGW("Error occur when generate nodes hash code.");
  526. return ret;
  527. }
  528. cache_json[kNodeHash] = nodes_hash_json;
  529. } catch (const std::exception &e) {
  530. GELOGW("Fail to generate cache info json. Error message: %s", e.what());
  531. return INTERNAL_ERROR;
  532. }
  533. string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix;
  534. auto ret = SaveJsonToFile(cache_manifest, cache_json);
  535. if (ret != SUCCESS) {
  536. GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str());
  537. return ret;
  538. }
  539. return SUCCESS;
  540. }
  541. Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const {
  542. string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix;
  543. Json cache_json;
  544. if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) {
  545. GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str());
  546. return INTERNAL_ERROR;
  547. }
  548. if (!cache_json.is_object()) {
  549. GELOGW("Manifest should be a json object");
  550. return INTERNAL_ERROR;
  551. }
  552. try {
  553. cache_info.node_num = cache_json[kNodeNum];
  554. cache_info.edge_num = cache_json[kEdgeNum];
  555. cache_info.graph_hash = cache_json[kGraphHash];
  556. Json nodes_hash_json = cache_json[kNodeHash];
  557. if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) {
  558. GELOGW("Nodes hash in cache should be null or array.");
  559. return FAILED;
  560. }
  561. for (const auto &iter : nodes_hash_json) {
  562. cache_info.nodes_hash[iter[kName].get<std::string>()] = iter[kHash].get<size_t>();
  563. }
  564. } catch (const std::exception &e) {
  565. GELOGW("Fail to get info from json file. Error message: %s", e.what());
  566. return INTERNAL_ERROR;
  567. }
  568. return SUCCESS;
  569. }
  570. bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const {
  571. if (!(json.is_null() || json.is_array())) {
  572. GELOGW("Input param json type should be null or array.");
  573. return false;
  574. }
  575. // Compare allocated graph id info between json and VarManager
  576. std::unordered_map<std::string, uint32_t> allocated_graph_id;
  577. auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id);
  578. if (ret != SUCCESS) {
  579. GELOGW("Fail to parse AllocatedGraphId from Json.");
  580. return false;
  581. }
  582. for (const auto &iter : allocated_graph_id) {
  583. uint32_t graph_id = 0;
  584. ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id);
  585. if (ret != SUCCESS) {
  586. GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str());
  587. return false;
  588. }
  589. if (graph_id != iter.second) {
  590. GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  591. return false;
  592. }
  593. }
  594. return true;
  595. }
  596. bool ModelCacheHelper::IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const {
  597. map<std::string, size_t> cur_hash_map;
  598. GetNodesHash(cur_hash_map);
  599. if (hash_map.size() != cur_hash_map.size()) {
  600. GELOGI("The number of hash code is different from cache info.");
  601. return false;
  602. }
  603. for (const auto &iter : cur_hash_map) {
  604. if (hash_map.count(iter.first) == 0) {
  605. GELOGI("Node[%s] is not found in cache info.", iter.first.c_str());
  606. return false;
  607. }
  608. if (hash_map.at(iter.first) != iter.second) {
  609. GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str());
  610. return false;
  611. }
  612. }
  613. return true;
  614. }
  615. bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const {
  616. if (!(json.is_null() || json.is_array())) {
  617. GELOGW("Input param json type should be null or array.");
  618. return false;
  619. }
  620. // Compare var mem size info between json and VarManager
  621. std::map<rtMemType_t, int64_t> var_mem_size;
  622. auto ret = ParseMemResourceFromJson(json, var_mem_size);
  623. if (ret != SUCCESS) {
  624. GELOGW("Fail to parse MemResource from Json.");
  625. return false;
  626. }
  627. for (const auto &iter : var_mem_size) {
  628. int64_t mem_size = VarManager::Instance(session_id_)->GetVarMemSize(iter.first);
  629. if (mem_size != iter.second) {
  630. GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first);
  631. return false;
  632. }
  633. }
  634. return true;
  635. }
  636. bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const {
  637. if (!(json.is_null() || json.is_array())) {
  638. GELOGW("Input param json type should be null or array.");
  639. return false;
  640. }
  641. // Compare variable changed graph id info between json and VarManager
  642. std::unordered_map<std::string, uint32_t> changed_graph_id;
  643. auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id);
  644. if (ret != SUCCESS) {
  645. GELOGW("Fail to parse ChangedGraphId from Json.");
  646. return false;
  647. }
  648. for (const auto &iter : changed_graph_id) {
  649. uint32_t graph_id = 0;
  650. ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id);
  651. if (ret != SUCCESS) {
  652. GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str());
  653. return false;
  654. }
  655. if (graph_id != iter.second) {
  656. GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  657. return false;
  658. }
  659. }
  660. return true;
  661. }
  662. bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const {
  663. if (!(json.is_null() || json.is_array())) {
  664. GELOGW("Input param json type should be null or array.");
  665. return false;
  666. }
  667. // Compare variable tensor desc info between json and VarManager
  668. std::unordered_map<std::string, ge::GeTensorDesc> cur_var_tensor_desc;
  669. auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc);
  670. if (ret != SUCCESS) {
  671. GELOGW("Fail to parse CurVarTensorDesc from Json.");
  672. return false;
  673. }
  674. for (const auto &iter : cur_var_tensor_desc) {
  675. GeTensorDesc tensor_desc;
  676. ret = VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc);
  677. if (ret != SUCCESS) {
  678. GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str());
  679. return false;
  680. }
  681. uint32_t l_real_dim_cnt = 0;
  682. uint32_t r_real_dim_cnt = 0;
  683. TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt);
  684. TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt);
  685. if ((tensor_desc.GetDataType() != iter.second.GetDataType()) ||
  686. (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) ||
  687. (tensor_desc.GetFormat() != iter.second.GetFormat()) ||
  688. (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) ||
  689. (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) ||
  690. (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) ||
  691. (l_real_dim_cnt != r_real_dim_cnt)) {
  692. GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  693. return false;
  694. }
  695. }
  696. return true;
  697. }
  698. bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const {
  699. if (!(json.is_null() || json.is_array())) {
  700. GELOGW("Input param json type should be null or array.");
  701. return false;
  702. }
  703. // Compare variable address info between json and VarManager
  704. std::vector<std::pair<std::string, VarAddrMgr>> var_addr_mgr_vector;
  705. std::unordered_set<uint64_t> var_offset_set;
  706. auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set);
  707. if (ret != SUCCESS) {
  708. GELOGW("Fail to parse VarAddrMgrMap from Json.");
  709. return false;
  710. }
  711. for (const auto &iter : var_addr_mgr_vector) {
  712. uint8_t *dev_ptr = nullptr;
  713. rtMemType_t memory_type;
  714. ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type);
  715. if (ret != SUCCESS) {
  716. GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str());
  717. return false;
  718. }
  719. // Compare memory type and logic address
  720. if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) {
  721. GELOGW("The VarAddrMgr of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  722. return false;
  723. }
  724. }
  725. return true;
  726. }
  727. bool ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const {
  728. if (!(json.is_null() || json.is_array())) {
  729. GELOGW("Input param json type should be null or array.");
  730. return false;
  731. }
  732. // Compare broadcast info between json and VarManager
  733. std::unordered_map<std::string, VarBroadCastInfo> var_broadcast_info;
  734. auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info);
  735. if (ret != SUCCESS) {
  736. GELOGW("Fail to parse BroadcastInfo from Json.");
  737. return false;
  738. }
  739. for (const auto &iter : var_broadcast_info) {
  740. VarBroadCastInfo broadcast_info;
  741. if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) {
  742. GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str());
  743. return false;
  744. }
  745. if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx ||
  746. iter.second.input_size != broadcast_info.input_size ||
  747. iter.second.input_offset != broadcast_info.input_offset ||
  748. iter.second.output_size != broadcast_info.output_size ||
  749. iter.second.output_offset != broadcast_info.output_offset) {
  750. GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  751. return false;
  752. }
  753. }
  754. return true;
  755. }
  756. bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const {
  757. if (!(json.is_null() || json.is_array())) {
  758. GELOGW("Input param json type should be null or array.");
  759. return false;
  760. }
  761. // Compare trans road between json and VarManager
  762. std::unordered_map<std::string, std::vector<TransNodeInfo>> trans_roads;
  763. auto ret = ParseTransRoadsFromJson(json, trans_roads);
  764. if (ret != SUCCESS) {
  765. GELOGW("Fail to parse TransRoads from Json.");
  766. return false;
  767. }
  768. for (const auto &iter : trans_roads) {
  769. VarTransRoad *trans_road;
  770. trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first);
  771. if (trans_road == nullptr) {
  772. GELOGW("Fail to find trans road of var[%s].", iter.first.c_str());
  773. return false;
  774. }
  775. if (trans_road->size() != iter.second.size()) {
  776. GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  777. return false;
  778. }
  779. // Compare every trans node in trans road.
  780. for (size_t idx = 0; idx < trans_road->size(); idx += 1) {
  781. if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type &&
  782. trans_road->at(idx).input == iter.second.at(idx).input &&
  783. trans_road->at(idx).output == iter.second.at(idx).output)) {
  784. GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  785. return false;
  786. }
  787. }
  788. }
  789. return true;
  790. }
  791. bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const {
  792. if (!json.is_object()) {
  793. GELOGW("Input param json type should be object.");
  794. return false;
  795. }
  796. try {
  797. if (json[kSessionId].get<uint64_t>() != session_id_) {
  798. GELOGW("Check VarManager cache failed.[sessionId]");
  799. return false;
  800. }
  801. if (json[kDeviceId].get<uint32_t>() != VarManager::Instance(session_id_)->DeviceId()) {
  802. GELOGW("Check VarManager cache failed.[deviceId]");
  803. return false;
  804. }
  805. if (json[kJobId].get<uint64_t>() != VarManager::Instance(session_id_)->JobId()) {
  806. GELOGW("Check VarManager cache failed.[jobId]");
  807. return false;
  808. }
  809. if (json[kGraphMemMaxSize].get<size_t>() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) {
  810. GELOGW("Check VarManager cache failed.[graphMemMaxSize]");
  811. return false;
  812. }
  813. if (json[kVarMemMaxSize].get<size_t>() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) {
  814. GELOGW("Check VarManager cache failed.[varMemMaxSize]");
  815. return false;
  816. }
  817. if (json[kVarMemLogicBase].get<size_t>() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) {
  818. GELOGW("Check VarManager cache failed.[varMemLogicBase]");
  819. return false;
  820. }
  821. if (json[kUseMaxMemSize].get<size_t>() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) {
  822. GELOGW("Check VarManager cache failed.[useMaxMemSize]");
  823. return false;
  824. }
  825. } catch (const std::exception &e) {
  826. GELOGW("Fail to check VarManager json. Error message: %s", e.what());
  827. return false;
  828. }
  829. return true;
  830. }
  831. bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const {
  832. if (!json.is_object()) {
  833. GELOGW("Input param json type should be object.");
  834. return false;
  835. }
  836. try {
  837. if (!IsVarManagerParamSameAsCache(json)) {
  838. GELOGW("Check VarManager cache failed.[Param]");
  839. return false;
  840. }
  841. Json mem_resource_json = move(json[kMemResourceMap]);
  842. auto ret = IsMemResourceSameAsCache(mem_resource_json);
  843. if (!ret) {
  844. GELOGW("Check VarManager cache failed.[MemResource]");
  845. return false;
  846. }
  847. Json var_resource_json = move(json[kVarResource]);
  848. ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]);
  849. if (!ret) {
  850. GELOGW("Check VarManager cache failed.[AllocatedGraphId]");
  851. return false;
  852. }
  853. ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]);
  854. if (!ret) {
  855. GELOGW("Check VarManager cache failed.[ChangedGraphId]");
  856. return false;
  857. }
  858. ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]);
  859. if (!ret) {
  860. GELOGW("Check VarManager cache failed.[VarBroadcastInfo]");
  861. return false;
  862. }
  863. ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]);
  864. if (!ret) {
  865. GELOGW("Check VarManager cache failed.[CurVarTensorDesc]");
  866. return false;
  867. }
  868. ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]);
  869. if (!ret) {
  870. GELOGW("Check VarManager cache failed.[VarAddrMgrMap]");
  871. return false;
  872. }
  873. ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]);
  874. if (!ret) {
  875. GELOGW("Check VarManager cache failed.[TransRoads]");
  876. return false;
  877. }
  878. } catch (const std::exception &e) {
  879. GELOGW("Fail to check VarManager json. Error message: %s", e.what());
  880. return false;
  881. }
  882. return true;
  883. }
  884. Status ModelCacheHelper::RecoverMemResource(const Json &json) const {
  885. if (!(json.is_null() || json.is_array())) {
  886. GELOGW("Input param json type should be null or array.");
  887. return PARAM_INVALID;
  888. }
  889. std::map<rtMemType_t, int64_t> var_mem_size;
  890. auto ret = ParseMemResourceFromJson(json, var_mem_size);
  891. if (ret != SUCCESS) {
  892. GELOGW("Fail to parse MemResource from Json.");
  893. return ret;
  894. }
  895. for (const auto &iter : var_mem_size) {
  896. ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second);
  897. if (ret != SUCCESS) {
  898. GELOGW("Fail to recover var mem size.");
  899. return ret;
  900. }
  901. }
  902. return SUCCESS;
  903. }
  904. Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const {
  905. if (!(json.is_null() || json.is_array())) {
  906. GELOGW("Input param json type should be null or array.");
  907. return PARAM_INVALID;
  908. }
  909. std::unordered_map<std::string, uint32_t> allocated_graph_id;
  910. auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id);
  911. if (ret != SUCCESS) {
  912. GELOGW("Fail to parse AllocatedGraphId from Json.");
  913. return ret;
  914. }
  915. for (const auto &iter : allocated_graph_id) {
  916. ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second);
  917. if (ret != SUCCESS) {
  918. GELOGW("Fail to recover allocated graph id.");
  919. return ret;
  920. }
  921. }
  922. return SUCCESS;
  923. }
  924. Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const {
  925. if (!(json.is_null() || json.is_array())) {
  926. GELOGW("Input param json type should be null or array.");
  927. return PARAM_INVALID;
  928. }
  929. std::unordered_map<std::string, uint32_t> changed_graph_id;
  930. auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id);
  931. if (ret != SUCCESS) {
  932. GELOGW("Fail to parse AllocatedGraphId from Json.");
  933. return ret;
  934. }
  935. for (const auto &iter : changed_graph_id) {
  936. ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second);
  937. if (ret != SUCCESS) {
  938. GELOGW("Fail to recover changed graph id.");
  939. return ret;
  940. }
  941. }
  942. return SUCCESS;
  943. }
  944. Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const {
  945. if (!(json.is_null() || json.is_array())) {
  946. GELOGW("Input param json type should be null or array.");
  947. return PARAM_INVALID;
  948. }
  949. std::vector<std::pair<std::string, VarAddrMgr>> var_addr_mgr_vector;
  950. std::unordered_set<uint64_t> var_offset_set;
  951. auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set);
  952. if (ret != SUCCESS) {
  953. GELOGW("Fail to parse VarAddrMgrMap from Json.");
  954. return ret;
  955. }
  956. for (const auto &iter : var_addr_mgr_vector) {
  957. const VarAddrMgr &tensor_addr_mgr = iter.second;
  958. const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc);
  959. // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager
  960. if (!var_exist) {
  961. auto logic_address = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tensor_addr_mgr.address));
  962. auto offset = (tensor_addr_mgr.offset);
  963. // Check logic address and offset
  964. if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) {
  965. GELOGW("Check logic_address[%u] and offset [%u] of %s failed, var mem logic base is %u, abandon", logic_address,
  966. offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase());
  967. return PARAM_INVALID;
  968. }
  969. // Offset is needed by SaveVarVddr instead of logic address
  970. ret = VarManager::Instance(session_id_)->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc,
  971. reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(offset)),
  972. tensor_addr_mgr.memory_type);
  973. if (ret != SUCCESS) {
  974. GELOGW("Fail to recover VarAddr or TensorDesc of var[%s].", iter.first.c_str());
  975. return ret;
  976. }
  977. }
  978. // SetVarAddr to update cur_var_tensor_desc_map_
  979. ret = VarManager::Instance(session_id_)
  980. ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type);
  981. if (ret != SUCCESS) {
  982. GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str());
  983. return ret;
  984. }
  985. }
  986. return SUCCESS;
  987. }
  988. Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const {
  989. if (!(json.is_null() || json.is_array())) {
  990. GELOGW("Input param json type should be null or array.");
  991. return PARAM_INVALID;
  992. }
  993. std::unordered_map<std::string, VarBroadCastInfo> var_broadcast_info;
  994. auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info);
  995. if (ret != SUCCESS) {
  996. GELOGW("Fail to parse BroadcastInfo from Json.");
  997. return ret;
  998. }
  999. for (const auto &iter : var_broadcast_info) {
  1000. VarBroadCastInfo broadcast_info;
  1001. ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second);
  1002. if (ret != SUCCESS) {
  1003. GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str());
  1004. return ret;
  1005. }
  1006. }
  1007. return SUCCESS;
  1008. }
  1009. Status ModelCacheHelper::RecoverTransRoads(const Json &json) const {
  1010. if (!(json.is_null() || json.is_array())) {
  1011. GELOGW("Input param json type should be null or array.");
  1012. return PARAM_INVALID;
  1013. }
  1014. std::unordered_map<std::string, std::vector<TransNodeInfo>> trans_roads;
  1015. auto ret = ParseTransRoadsFromJson(json, trans_roads);
  1016. if (ret != SUCCESS) {
  1017. GELOGW("Fail to parse TransRoads from Json.");
  1018. return ret;
  1019. }
  1020. for (const auto &iter : trans_roads) {
  1021. ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second);
  1022. if (ret != SUCCESS) {
  1023. GELOGW("Fail to find trans road of var[%s].", iter.first.c_str());
  1024. return ret;
  1025. }
  1026. }
  1027. return SUCCESS;
  1028. }
  1029. Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) {
  1030. if (!(json.is_null() || json.is_object())) {
  1031. GELOGW("Input param json type should be null or object.");
  1032. return PARAM_INVALID;
  1033. }
  1034. try {
  1035. json[kDataType] = static_cast<int>(ge_tensor_desc.GetDataType());
  1036. json[kOriginDataType] = static_cast<int>(ge_tensor_desc.GetOriginDataType());
  1037. json[kLayout] = static_cast<int>(ge_tensor_desc.GetFormat());
  1038. json[kOriginLayout] = static_cast<int>(ge_tensor_desc.GetOriginFormat());
  1039. json[kShape] = ge_tensor_desc.GetShape().GetDims();
  1040. json[kOriginShape] = ge_tensor_desc.GetOriginShape().GetDims();
  1041. uint32_t real_dim_cnt = 0;
  1042. (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value]
  1043. json[kRealDimCnt] = real_dim_cnt;
  1044. } catch (const std::exception &e) {
  1045. GELOGW("Fail to trans GeTensorDesc to json. Error message: %s", e.what());
  1046. return INTERNAL_ERROR;
  1047. }
  1048. return SUCCESS;
  1049. }
  1050. Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) {
  1051. if (!json.is_object()) {
  1052. GELOGW("Input param json type should be object.");
  1053. return PARAM_INVALID;
  1054. }
  1055. try {
  1056. ge_tensor_desc.SetDataType(static_cast<DataType>(json[kDataType].get<int>()));
  1057. ge_tensor_desc.SetOriginDataType(static_cast<DataType>(json[kOriginDataType].get<int>()));
  1058. ge_tensor_desc.SetFormat(static_cast<Format>(json[kLayout].get<int>()));
  1059. ge_tensor_desc.SetOriginFormat(static_cast<Format>(json[kOriginLayout].get<int>()));
  1060. GeShape shape(json[kShape].get<std::vector<int64_t>>());
  1061. ge_tensor_desc.SetShape(shape);
  1062. GeShape origin_shape(json[kOriginShape].get<std::vector<int64_t>>());
  1063. ge_tensor_desc.SetOriginShape(origin_shape);
  1064. auto real_dim_cnt = json[kRealDimCnt].get<uint32_t>();
  1065. (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value]
  1066. } catch (const std::exception &e) {
  1067. GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what());
  1068. return INTERNAL_ERROR;
  1069. }
  1070. return SUCCESS;
  1071. }
  1072. Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const {
  1073. if (!(json.is_null() || json.is_array())) {
  1074. GELOGW("Input param json type should be null or array.");
  1075. return PARAM_INVALID;
  1076. }
  1077. map<std::string, size_t> hash_map;
  1078. GetNodesHash(hash_map);
  1079. for (const auto &iter : hash_map) {
  1080. Json node_hash_json;
  1081. try {
  1082. node_hash_json[kName] = iter.first;
  1083. node_hash_json[kHash] = iter.second;
  1084. json.emplace_back(move(node_hash_json));
  1085. } catch (const std::exception &e) {
  1086. GELOGW("Fail to trans node cache to json. Error message: %s", e.what());
  1087. return INTERNAL_ERROR;
  1088. }
  1089. }
  1090. return SUCCESS;
  1091. }
  1092. Status ModelCacheHelper::GetMemResourceMap(Json &json) const {
  1093. if (!(json.is_null() || json.is_array())) {
  1094. GELOGW("Input param json type should be null or array.");
  1095. return PARAM_INVALID;
  1096. }
  1097. const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize();
  1098. const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM);
  1099. Json mem_resource_json;
  1100. try {
  1101. mem_resource_json[kMemType] = RT_MEMORY_HBM;
  1102. mem_resource_json[kTotalSize] = total_size;
  1103. mem_resource_json[kVarMemSize] = var_mem_size;
  1104. json.emplace_back(move(mem_resource_json));
  1105. } catch (const std::exception &e) {
  1106. GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what());
  1107. return INTERNAL_ERROR;
  1108. }
  1109. return SUCCESS;
  1110. }
  1111. Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const {
  1112. if (!(json.is_null() || json.is_array())) {
  1113. GELOGW("Input param json type should be null or array.");
  1114. return PARAM_INVALID;
  1115. }
  1116. std::unordered_map<std::string, VarAddrMgr> var_addr_mgr_map;
  1117. VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map);
  1118. try {
  1119. for (const auto &iter : var_addr_mgr_map) {
  1120. Json var_addr_json;
  1121. string name;
  1122. GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name);
  1123. var_addr_json[kName] = name;
  1124. var_addr_json[kAddress] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(iter.second.address));
  1125. var_addr_json[kMemoryType] = iter.second.memory_type;
  1126. var_addr_json[kOffset] = iter.second.offset;
  1127. // Copy tensor desc to json.
  1128. Json tensor_desc_json;
  1129. auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json);
  1130. if (ret != SUCCESS) {
  1131. GELOGW("Fail to trans tensor desc to json.");
  1132. return INTERNAL_ERROR;
  1133. }
  1134. var_addr_json[kTensorDesc] = move(tensor_desc_json);
  1135. json.emplace_back(move(var_addr_json));
  1136. }
  1137. } catch (const std::exception &e) {
  1138. GELOGW("Fail to trans VarAddrMgrMap to json. Error message: %s", e.what());
  1139. return INTERNAL_ERROR;
  1140. }
  1141. return SUCCESS;
  1142. }
  1143. Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const {
  1144. if (!(json.is_null() || json.is_array())) {
  1145. GELOGW("Input param json type should be null or array.");
  1146. return PARAM_INVALID;
  1147. }
  1148. try {
  1149. for (const auto &name : var_names_) {
  1150. Json cur_tensor_desc_json;
  1151. GeTensorDesc tensor_desc;
  1152. auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc);
  1153. if (ret != SUCCESS) {
  1154. GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str());
  1155. continue;
  1156. }
  1157. cur_tensor_desc_json[kName] = name;
  1158. Json tensor_desc_json;
  1159. ret = TensorDescToJson(tensor_desc, tensor_desc_json);
  1160. if (ret != SUCCESS) {
  1161. GELOGW("Fail to trans tensor desc to json.");
  1162. return INTERNAL_ERROR;
  1163. }
  1164. cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json);
  1165. json.emplace_back(move(cur_tensor_desc_json));
  1166. }
  1167. } catch (const std::exception &e) {
  1168. GELOGW("Fail to trans CurVarTensorDescMap to json. Error message: %s", e.what());
  1169. return INTERNAL_ERROR;
  1170. }
  1171. return SUCCESS;
  1172. }
  1173. Status ModelCacheHelper::GetTransRoadsJson(Json &json) const {
  1174. if (!(json.is_null() || json.is_array())) {
  1175. GELOGW("Input param json type should be null or array.");
  1176. return PARAM_INVALID;
  1177. }
  1178. try {
  1179. for (const auto &name : var_names_) {
  1180. auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name);
  1181. if (trans_road == nullptr) {
  1182. continue;
  1183. }
  1184. // Json object, variable name and trans road
  1185. Json trans_road_map_json;
  1186. trans_road_map_json[kName] = name;
  1187. Json trans_road_json;
  1188. Status ret;
  1189. // Add nodes' info to json
  1190. for (const auto &trans_node_info : *trans_road) {
  1191. Json trans_node_info_json;
  1192. trans_node_info_json[kNodeType] = trans_node_info.node_type;
  1193. Json input_tensor_desc_json;
  1194. ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json);
  1195. if (ret != SUCCESS) {
  1196. GELOGW("Fail to trans tensor desc to json.");
  1197. return INTERNAL_ERROR;
  1198. }
  1199. trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json);
  1200. Json output_tensor_desc_json;
  1201. ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json);
  1202. if (ret != SUCCESS) {
  1203. GELOGW("Fail to trans tensor desc to json.");
  1204. return INTERNAL_ERROR;
  1205. }
  1206. trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json);
  1207. trans_road_json.emplace_back(move(trans_node_info_json));
  1208. }
  1209. trans_road_map_json[kTransRoad] = move(trans_road_json);
  1210. json.emplace_back(move(trans_road_map_json));
  1211. }
  1212. } catch (const std::exception &e) {
  1213. GELOGW("Fail to trans VarToTransRoad to json. Error message: %s", e.what());
  1214. return INTERNAL_ERROR;
  1215. }
  1216. return SUCCESS;
  1217. }
  1218. Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const {
  1219. if (!(json.is_null() || json.is_array())) {
  1220. GELOGW("Input param json type should be null or array.");
  1221. return PARAM_INVALID;
  1222. }
  1223. for (const auto &name : var_names_) {
  1224. uint32_t changed_graph_id = 0;
  1225. Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id);
  1226. if (ret != SUCCESS) {
  1227. continue;
  1228. }
  1229. Json name_and_changed_graph_id;
  1230. try {
  1231. name_and_changed_graph_id[kName] = name;
  1232. name_and_changed_graph_id[kGraphId] = changed_graph_id;
  1233. json.emplace_back(move(name_and_changed_graph_id));
  1234. } catch (const std::exception &e) {
  1235. GELOGW("Fail to trans ChangedGraphId to json. Error message: %s", e.what());
  1236. return INTERNAL_ERROR;
  1237. }
  1238. }
  1239. return SUCCESS;
  1240. }
  1241. Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const {
  1242. if (!(json.is_null() || json.is_array())) {
  1243. GELOGW("Input param json type should be null or array.");
  1244. return PARAM_INVALID;
  1245. }
  1246. for (const auto &name : var_names_) {
  1247. uint32_t allocated_graph_id = 0;
  1248. Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id);
  1249. if (ret != SUCCESS) {
  1250. continue;
  1251. }
  1252. Json name_and_allocated_graph_id;
  1253. try {
  1254. name_and_allocated_graph_id[kName] = name;
  1255. name_and_allocated_graph_id[kGraphId] = allocated_graph_id;
  1256. json.emplace_back(move(name_and_allocated_graph_id));
  1257. } catch (const std::exception &e) {
  1258. GELOGW("Fail to trans AllocatedGraphId to json. Error message: %s", e.what());
  1259. return INTERNAL_ERROR;
  1260. }
  1261. }
  1262. return SUCCESS;
  1263. }
  1264. Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const {
  1265. if (!(json.is_null() || json.is_array())) {
  1266. GELOGW("Input param json type should be null or array.");
  1267. return PARAM_INVALID;
  1268. }
  1269. for (const auto &name : var_names_) {
  1270. VarBroadCastInfo var_broadcast_info;
  1271. Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info);
  1272. if (ret != SUCCESS) {
  1273. continue;
  1274. }
  1275. Json var_broadcast_info_json;
  1276. try {
  1277. var_broadcast_info_json[kName] = name;
  1278. var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name;
  1279. var_broadcast_info_json[kIdx] = var_broadcast_info.idx;
  1280. var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset;
  1281. var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size;
  1282. var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset;
  1283. var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size;
  1284. json.emplace_back(move(var_broadcast_info_json));
  1285. } catch (const std::exception &e) {
  1286. GELOGW("Fail to trans VarBroadcastInfo to json. Error message: %s", e.what());
  1287. return INTERNAL_ERROR;
  1288. }
  1289. }
  1290. return SUCCESS;
  1291. }
  1292. Status ModelCacheHelper::GetVarResourceJson(Json &json) const {
  1293. if (!(json.is_null() || json.is_object())) {
  1294. GELOGW("Input param json type should be null or object.");
  1295. return PARAM_INVALID;
  1296. }
  1297. Json var_addr_mgr_map_json;
  1298. Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json);
  1299. if (ret != SUCCESS) {
  1300. GELOGW("GetVarAddrMgrMapJson failed.");
  1301. return INTERNAL_ERROR;
  1302. }
  1303. Json cur_var_tensor_desc_map_json;
  1304. ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json);
  1305. if (ret != SUCCESS) {
  1306. GELOGW("GetCurVarTensorDescMapJson failed.");
  1307. return INTERNAL_ERROR;
  1308. }
  1309. Json trans_roads_json;
  1310. ret = GetTransRoadsJson(trans_roads_json);
  1311. if (ret != SUCCESS) {
  1312. GELOGW("GetTransRoadsJson failed.");
  1313. return INTERNAL_ERROR;
  1314. }
  1315. Json changed_graph_id_json;
  1316. ret = GetChangedGraphIdJson(changed_graph_id_json);
  1317. if (ret != SUCCESS) {
  1318. GELOGW("GetChangedGraphIdJson failed.");
  1319. return INTERNAL_ERROR;
  1320. }
  1321. Json allocated_graph_id_json;
  1322. ret = GetAllocatedGraphIdJson(allocated_graph_id_json);
  1323. if (ret != SUCCESS) {
  1324. GELOGW("GetAllocatedGraphIdJson failed.");
  1325. return INTERNAL_ERROR;
  1326. }
  1327. Json var_broadcast_info_json;
  1328. ret = GetBroadcastInfoJson(var_broadcast_info_json);
  1329. if (ret != SUCCESS) {
  1330. GELOGW("GetBroadcastInfoJson failed.");
  1331. return INTERNAL_ERROR;
  1332. }
  1333. try {
  1334. json[kVarAddrMgrMap] = move(var_addr_mgr_map_json);
  1335. json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json);
  1336. json[kTransRoads] = move(trans_roads_json);
  1337. json[kChangedGraphId] = move(changed_graph_id_json);
  1338. json[kAllocatedGraphId] = move(allocated_graph_id_json);
  1339. json[kVarBroadcastInfo] = move(var_broadcast_info_json);
  1340. } catch (const exception &e) {
  1341. GELOGW("Fail to generate VarResource json. Error message: %s", e.what());
  1342. return INTERNAL_ERROR;
  1343. }
  1344. return SUCCESS;
  1345. }
  1346. Status ModelCacheHelper::GetVarManagerJson(Json &json) const {
  1347. if (!(json.is_null() || json.is_object())) {
  1348. GELOGW("Input param json type should be null or object.");
  1349. return PARAM_INVALID;
  1350. }
  1351. Json mem_resource_map_json;
  1352. auto ret = GetMemResourceMap(mem_resource_map_json);
  1353. if (ret != SUCCESS) {
  1354. GELOGW("GetMemResourceMap failed.");
  1355. return INTERNAL_ERROR;
  1356. }
  1357. Json var_resource_json;
  1358. ret = GetVarResourceJson(var_resource_json);
  1359. if (ret != SUCCESS) {
  1360. GELOGW("GetVarResourceJson failed.");
  1361. return INTERNAL_ERROR;
  1362. }
  1363. try {
  1364. json[kSessionId] = session_id_;
  1365. json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId();
  1366. json[kJobId] = VarManager::Instance(session_id_)->JobId();
  1367. json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize();
  1368. json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize();
  1369. json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase();
  1370. json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize();
  1371. json[kMemResourceMap] = move(mem_resource_map_json);
  1372. json[kVarResource] = move(var_resource_json);
  1373. } catch (const exception &e) {
  1374. GELOGW("Fail to generate VarManager json. Error message: %s", e.what());
  1375. return INTERNAL_ERROR;
  1376. }
  1377. return SUCCESS;
  1378. }
  1379. Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const {
  1380. if (!is_cache_path_valid_for_output) {
  1381. GELOGW("Invalid cache path.");
  1382. return FAILED;
  1383. }
  1384. Json var_manager_json;
  1385. auto ret = GetVarManagerJson(var_manager_json);
  1386. if (ret != SUCCESS) {
  1387. GELOGW("Fail to generate VarManager json.");
  1388. return FAILED;
  1389. }
  1390. string var_manager_path = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) +
  1391. (before_build ? kBeforeVarManagerSuffix : kAfterVarManagerSuffix);
  1392. ret = SaveJsonToFile(var_manager_path, var_manager_json);
  1393. if (ret != SUCCESS) {
  1394. GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str());
  1395. return ret;
  1396. }
  1397. return SUCCESS;
  1398. }
  1399. Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const {
  1400. if (!is_cache_path_valid_for_output) {
  1401. GELOGW("Invalid cache path.");
  1402. return FAILED;
  1403. }
  1404. string om_path = RealPath(cache_path_.c_str());
  1405. if (om_path.empty()) {
  1406. GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str());
  1407. return FAILED;
  1408. }
  1409. string cache_om_path = cache_path_;
  1410. cache_om_path += (to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix);
  1411. GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str());
  1412. ModelHelper model_helper;
  1413. SaveParam save_param;
  1414. ModelBufferData model;
  1415. Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model);
  1416. if (ret != SUCCESS) {
  1417. GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret);
  1418. return ret;
  1419. }
  1420. return SUCCESS;
  1421. }
  1422. Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource) {
  1423. if (!(json.is_array() || json.is_null())) {
  1424. GELOGW("Input param json type should be null or array.");
  1425. return PARAM_INVALID;
  1426. }
  1427. mem_resource.clear();
  1428. for (const Json &mem_resource_json : json) {
  1429. try {
  1430. rtMemType_t mem_type = mem_resource_json[kMemType].get<rtMemType_t>();
  1431. uint64_t var_mem_size = mem_resource_json[kVarMemSize].get<int64_t>();
  1432. mem_resource[mem_type] = var_mem_size;
  1433. } catch (const exception &e) {
  1434. GELOGW("Fail to trans Json to MemResource. Error message: %s", e.what());
  1435. return INTERNAL_ERROR;
  1436. }
  1437. }
  1438. return SUCCESS;
  1439. }
  1440. Status ModelCacheHelper::ParseVarAddrMgrMapFromJson(
  1441. const Json &json, std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector,
  1442. std::unordered_set<uint64_t> &var_offset_set) {
  1443. if (!(json.is_array() || json.is_null())) {
  1444. GELOGW("Input param json type should be null or array.");
  1445. return PARAM_INVALID;
  1446. }
  1447. var_addr_mgr_vector.clear();
  1448. var_offset_set.clear();
  1449. for (const Json &var_addr_json : json) {
  1450. VarAddrMgr var_addr_mgr;
  1451. try {
  1452. auto logic_address = var_addr_json[kAddress].get<uint64_t>();
  1453. auto address = reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_address));
  1454. var_addr_mgr.address = address;
  1455. var_addr_mgr.offset = var_addr_json[kOffset].get<uint64_t>();
  1456. var_addr_mgr.memory_type = var_addr_json[kMemoryType].get<rtMemType_t>();
  1457. auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc);
  1458. if (ret != SUCCESS) {
  1459. GELOGW("Fail to trans json to tensor desc.");
  1460. return ret;
  1461. }
  1462. var_addr_mgr_vector.emplace_back(var_addr_json[kName].get<string>(), move(var_addr_mgr));
  1463. var_offset_set.insert(logic_address);
  1464. } catch (const exception &e) {
  1465. GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what());
  1466. return INTERNAL_ERROR;
  1467. }
  1468. }
  1469. return SUCCESS;
  1470. }
  1471. Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson(
  1472. const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map) {
  1473. if (!(json.is_array() || json.is_null())) {
  1474. GELOGW("Input param json type should be null or array.");
  1475. return PARAM_INVALID;
  1476. }
  1477. cur_var_tensor_desc_map.clear();
  1478. for (const Json &tensor_desc_json : json) {
  1479. GeTensorDesc tensor_desc;
  1480. try {
  1481. auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc);
  1482. if (ret != SUCCESS) {
  1483. GELOGW("Fail to trans json to tensor desc.");
  1484. return ret;
  1485. }
  1486. cur_var_tensor_desc_map[tensor_desc_json[kName].get<string>()] = move(tensor_desc);
  1487. } catch (const exception &e) {
  1488. GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what());
  1489. return INTERNAL_ERROR;
  1490. }
  1491. }
  1492. return SUCCESS;
  1493. }
  1494. Status ModelCacheHelper::ParseTransRoadsFromJson(
  1495. const Json &json, std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads) {
  1496. if (!(json.is_array() || json.is_null())) {
  1497. GELOGW("Input param json type should be null or array.");
  1498. return PARAM_INVALID;
  1499. }
  1500. trans_roads.clear();
  1501. try {
  1502. for (const Json &name_trans_road_json : json) {
  1503. const Json &trans_road_json = name_trans_road_json[kTransRoad];
  1504. if (!(trans_road_json.is_array() || trans_road_json.is_null())) {
  1505. GELOGW("%s json type should be null or object.", kTransRoad);
  1506. return PARAM_INVALID;
  1507. }
  1508. vector<TransNodeInfo> trans_road;
  1509. for (const Json &trans_node_json : trans_road_json) {
  1510. TransNodeInfo trans_node_info;
  1511. trans_node_info.node_type = trans_node_json[kNodeType];
  1512. GeTensorDesc input_tensor_desc;
  1513. auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc);
  1514. if (ret != SUCCESS) {
  1515. GELOGW("Fail to trans json to tensor desc.");
  1516. return ret;
  1517. }
  1518. trans_node_info.input = move(input_tensor_desc);
  1519. GeTensorDesc output_tensor_desc;
  1520. ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc);
  1521. if (ret != SUCCESS) {
  1522. GELOGW("Fail to trans json to tensor desc.");
  1523. return ret;
  1524. }
  1525. trans_node_info.output = move(output_tensor_desc);
  1526. trans_road.emplace_back(move(trans_node_info));
  1527. }
  1528. trans_roads[name_trans_road_json[kName].get<string>()] = move(trans_road);
  1529. }
  1530. } catch (const exception &e) {
  1531. GELOGW("Fail to trans Json to TransRoads. Error message: %s", e.what());
  1532. return INTERNAL_ERROR;
  1533. }
  1534. return SUCCESS;
  1535. }
  1536. Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json,
  1537. std::unordered_map<std::string, uint32_t> &changed_graph_id) {
  1538. if (!(json.is_array() || json.is_null())) {
  1539. GELOGW("Input param json type should be null or array.");
  1540. return PARAM_INVALID;
  1541. }
  1542. changed_graph_id.clear();
  1543. for (const Json &name_graph_id_json : json) {
  1544. try {
  1545. changed_graph_id[name_graph_id_json[kName].get<string>()] = name_graph_id_json[kGraphId].get<uint32_t>();
  1546. } catch (const exception &e) {
  1547. GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what());
  1548. return INTERNAL_ERROR;
  1549. }
  1550. }
  1551. return SUCCESS;
  1552. }
  1553. Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json,
  1554. std::unordered_map<std::string, uint32_t> &allocated_graph_id) {
  1555. if (!(json.is_array() || json.is_null())) {
  1556. GELOGW("Input param json type should be null or array.");
  1557. return PARAM_INVALID;
  1558. }
  1559. allocated_graph_id.clear();
  1560. for (const Json &name_graph_id_json : json) {
  1561. try {
  1562. allocated_graph_id[name_graph_id_json[kName].get<string>()] = name_graph_id_json[kGraphId].get<uint32_t>();
  1563. } catch (const exception &e) {
  1564. GELOGW("Fail to trans Json to allocated graph id. Error message: %s", e.what());
  1565. return INTERNAL_ERROR;
  1566. }
  1567. }
  1568. return SUCCESS;
  1569. }
  1570. Status ModelCacheHelper::ParseBroadcastInfoFromJson(
  1571. const Json &json, std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info) {
  1572. if (!(json.is_array() || json.is_null())) {
  1573. GELOGW("Input param json type should be null or array.");
  1574. return PARAM_INVALID;
  1575. }
  1576. for (const Json &broadcast_info_json : json) {
  1577. VarBroadCastInfo broadcast_info;
  1578. try {
  1579. broadcast_info.var_name = broadcast_info_json[kName].get<string>();
  1580. broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get<string>();
  1581. broadcast_info.idx = broadcast_info_json[kIdx].get<int>();
  1582. broadcast_info.input_offset = broadcast_info_json[kInputOffset].get<int64_t>();
  1583. broadcast_info.input_size = broadcast_info_json[kInputSize].get<uint64_t>();
  1584. broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get<int64_t>();
  1585. broadcast_info.output_size = broadcast_info_json[kOutputSize].get<uint64_t>();
  1586. } catch (const exception &e) {
  1587. GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what());
  1588. return INTERNAL_ERROR;
  1589. }
  1590. var_broadcast_info[broadcast_info.var_name] = broadcast_info;
  1591. }
  1592. return SUCCESS;
  1593. }
  1594. Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const {
  1595. string cache_om = cache_path_ + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix;
  1596. if (!CheckInputPathValid(cache_om)) {
  1597. GELOGW("Invalid cache path for input:%s.", cache_om.c_str());
  1598. return FAILED;
  1599. }
  1600. string om_path = RealPath(cache_om.c_str());
  1601. if (om_path.empty()) {
  1602. GELOGW("file path is invalid. please check file om: %s", om_path.c_str());
  1603. return FAILED;
  1604. }
  1605. GELOGI("load model data from file: %s", om_path.c_str());
  1606. Status ret;
  1607. string key_path;
  1608. int32_t priority = 0;
  1609. ModelData model_data;
  1610. ret = DavinciModelParser::LoadFromFile(om_path.c_str(), key_path.c_str(), priority, model_data);
  1611. if (ret != SUCCESS) {
  1612. GELOGW("LoadOmModelFromCache: Load model from file failed. ret = %u", ret);
  1613. return ret;
  1614. }
  1615. ModelHelper model_helper;
  1616. ret = model_helper.LoadModel(model_data);
  1617. if (ret != SUCCESS) {
  1618. GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret);
  1619. return ret;
  1620. }
  1621. ge_model = model_helper.GetGeModel();
  1622. ret = RecompileNodes(ge_model);
  1623. if (ret != SUCCESS) {
  1624. GELOGW("LoadOmModelFromCache: recompile nodes failed. ret = %u", ret);
  1625. return ret;
  1626. }
  1627. return SUCCESS;
  1628. }
  1629. Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc,
  1630. string &var_name) {
  1631. std::string::size_type underline_idx = var_key.rfind('_');
  1632. if (underline_idx == std::string::npos) {
  1633. GELOGW("Invalid var key: underline not found");
  1634. return FAILED;
  1635. }
  1636. std::string::size_type format_idx =
  1637. var_key.rfind(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())), underline_idx);
  1638. if (format_idx == std::string::npos) {
  1639. GELOGW("Invalid var key: format not found");
  1640. return FAILED;
  1641. }
  1642. var_name = var_key.substr(0, format_idx);
  1643. return SUCCESS;
  1644. }
  1645. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。