You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_cache_helper.cc 65 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/helper/model_cache_helper.h"
  17. #include <cstdio>
  18. #include <fstream>
  19. #include <functional>
  20. #include "common/model_parser/model_parser.h"
  21. #include "framework/common/helper/model_helper.h"
  22. #include "graph/detail/model_serialize_imp.h"
  23. #include "graph/utils/graph_utils.h"
  24. #include "graph/utils/tensor_utils.h"
  25. #include "init/gelib.h"
  26. #include "proto/ge_ir.pb.h"
  27. using namespace std;
  namespace {
  // Name passed to the ops kernel manager when looking up the kernel info store used for recompiling.
  constexpr const char *kTbeKernelInfoStoreName = "AIcoreEngine";
  // Placeholder name given to the compute graph while its hash is being calculated.
  constexpr const char *kGraphName = "temp_name";

  // Keys of json
  constexpr const char *kNodeNum = "nodeNum";
  constexpr const char *kEdgeNum = "edgeNum";
  constexpr const char *kGraphHash = "graphHash";
  constexpr const char *kNodeHash = "nodeHash";
  constexpr const char *kHash = "hash";
  constexpr const char *kSessionId = "sessionId";
  constexpr const char *kDeviceId = "deviceId";
  constexpr const char *kJobId = "jobId";
  constexpr const char *kGraphMemMaxSize = "graphMemMaxSize";
  constexpr const char *kVarMemMaxSize = "varMemMaxSize";
  constexpr const char *kVarMemLogicBase = "varMemLogicBase";
  constexpr const char *kUseMaxMemSize = "useMaxMemSize";
  constexpr const char *kMemResourceMap = "memResourceMap";
  constexpr const char *kMemType = "memType";
  constexpr const char *kTotalSize = "totalSize";
  constexpr const char *kVarMemSize = "varMemSize";
  constexpr const char *kVarResource = "varResource";
  constexpr const char *kVarAddrMgrMap = "varAddrMgrMap";
  constexpr const char *kName = "name";
  constexpr const char *kAddress = "address";
  constexpr const char *kOffset = "offset";
  constexpr const char *kMemoryType = "memoryType";
  constexpr const char *kTensorDesc = "tensorDesc";
  constexpr const char *kDataType = "dataType";
  constexpr const char *kShape = "shape";
  constexpr const char *kLayout = "layout";
  constexpr const char *kOriginDataType = "originDataType";
  constexpr const char *kOriginShape = "originShape";
  constexpr const char *kOriginLayout = "originLayout";
  constexpr const char *kRealDimCnt = "realDimCnt";
  constexpr const char *kCurVarTensorDescMap = "curVarTensorDescMap";
  constexpr const char *kTransRoads = "transRoads";
  constexpr const char *kTransRoad = "transRoad";
  constexpr const char *kNodeType = "nodeType";
  constexpr const char *kInputTensorDesc = "inputTensorDesc";
  constexpr const char *kOutputTensorDesc = "outputTensorDesc";
  constexpr const char *kChangedGraphId = "changedGraphId";
  constexpr const char *kAllocatedGraphId = "allocatedGraphId";
  constexpr const char *kGraphId = "graphId";
  constexpr const char *kVarBroadcastInfo = "varBroadcastInfo";
  constexpr const char *kBroadcastName = "broadcastName";
  constexpr const char *kIdx = "idx";
  constexpr const char *kInputOffset = "inputOffset";
  constexpr const char *kInputSize = "inputSize";
  constexpr const char *kOutputOffset = "outputOffset";
  constexpr const char *kOutputSize = "outputSize";

  // Suffix of cache files
  constexpr const char *kBeforeVarManagerSuffix = "_before_build_var_manager.json";
  constexpr const char *kAfterVarManagerSuffix = "_after_build_var_manager.json";
  constexpr const char *kManifestSuffix = ".manifest";
  constexpr const char *kOmSuffix = ".om";
  }  // namespace
  84. namespace ge {
  85. map<uint32_t, uint32_t> ModelCacheHelper::graph_id_run_times_;
  86. ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph)
  87. : session_id_(session_id),
  88. graph_id_(graph_id),
  89. compute_graph_(compute_graph),
  90. is_cache_path_valid_for_output(false) {
  91. if (graph_id_run_times_.count(graph_id) == 0) {
  92. graph_id_run_times_[graph_id] = 1;
  93. } else {
  94. graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1;
  95. }
  96. for (const auto &node : compute_graph_->GetDirectNode()) {
  97. bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) ||
  98. (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP);
  99. if (!is_variable) {
  100. continue;
  101. }
  102. var_names_.insert(node->GetName());
  103. }
  104. std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
  105. if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) {
  106. std::string cache_path = instance_ptr->GetIncreBuildCachePath();
  107. GELOGD("Incre build path conf: %s", cache_path.c_str());
  108. string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix;
  109. if (CheckOutputPathValid(fake_file_path)) {
  110. is_cache_path_valid_for_output = true;
  111. } else {
  112. GELOGW("Invalid cache path for output.");
  113. }
  114. std::string real_cache_path = RealPath(cache_path.c_str());
  115. if (real_cache_path.empty()) {
  116. GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str());
  117. return;
  118. }
  119. cache_path_ = real_cache_path + '/';
  120. GELOGD("Try to use incre build cache path: %s", cache_path_.c_str());
  121. }
  122. }
  123. ModelCacheHelper::~ModelCacheHelper() { var_names_.clear(); }
  124. bool ModelCacheHelper::IsModelCacheHit() const {
  125. CacheInfo cache_info;
  126. if (GetCacheInfo(cache_info) != SUCCESS) {
  127. GELOGI("Get cache info of graph id[%u] failed.", graph_id_);
  128. return false;
  129. }
  130. // Check number of nodes and edges first.
  131. if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) {
  132. GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_);
  133. return false;
  134. }
  135. size_t edge_num = 0;
  136. for (const auto &node : compute_graph_->GetDirectNode()) {
  137. for (const auto &anchor : node->GetAllInAnchors()) {
  138. edge_num += anchor->GetPeerAnchors().size();
  139. }
  140. }
  141. if (cache_info.edge_num != edge_num) {
  142. GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_);
  143. return false;
  144. }
  145. size_t compute_graph_hash;
  146. auto ret = GetComputeGraphHash(compute_graph_hash);
  147. if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) {
  148. GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_);
  149. return false;
  150. }
  151. if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) {
  152. GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_);
  153. return false;
  154. }
  155. string var_manager_cache =
  156. to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix;
  157. Json var_manager_json;
  158. if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) {
  159. GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str());
  160. return false;
  161. }
  162. if (!IsVarManagerSameAsCache(var_manager_json)) {
  163. GELOGI("Graph id[%u] cache miss: the VarManager does not match the cache info.", graph_id_);
  164. return false;
  165. }
  166. GELOGI("Graph id[%u] cache hit.", graph_id_);
  167. return true;
  168. }
  169. Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) {
  170. if (compute_graph->IsValid()) {
  171. compute_graph_ = compute_graph;
  172. var_names_.clear();
  173. for (const auto &node : compute_graph_->GetDirectNode()) {
  174. bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) ||
  175. (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP);
  176. if (!is_variable) {
  177. continue;
  178. }
  179. var_names_.insert(node->GetName());
  180. }
  181. return SUCCESS;
  182. } else {
  183. GELOGW("Invalid compute graph.");
  184. return FAILED;
  185. }
  186. }
  187. Status ModelCacheHelper::ClearCache(uint32_t graph_id) const {
  188. if (!is_cache_path_valid_for_output) {
  189. GELOGW("Invalid cache path.");
  190. return SUCCESS;
  191. }
  192. string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  193. string manifest_file_path = RealPath(manifest_file.c_str());
  194. int ret;
  195. if (!manifest_file_path.empty()) {
  196. ret = remove(manifest_file_path.c_str());
  197. // If remove file failed, print the warning log
  198. if (ret != 0) {
  199. GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str());
  200. }
  201. }
  202. string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  203. string before_var_manager_file_path = RealPath(before_var_manager_file.c_str());
  204. if (!before_var_manager_file_path.empty()) {
  205. ret = remove(before_var_manager_file_path.c_str());
  206. if (ret != 0) {
  207. GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str());
  208. }
  209. }
  210. string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  211. string after_var_manager_file_path = RealPath(after_var_manager_file.c_str());
  212. if (!after_var_manager_file_path.empty()) {
  213. ret = remove(after_var_manager_file_path.c_str());
  214. if (ret != 0) {
  215. GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str());
  216. }
  217. }
  218. string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix;
  219. string om_file_path = RealPath(om_file.c_str());
  220. if (!om_file_path.empty()) {
  221. ret = remove(om_file_path.c_str());
  222. if (ret != 0) {
  223. GELOGW("Clear cache [%s] failed.", om_file_path.c_str());
  224. }
  225. }
  226. return SUCCESS;
  227. }
  228. Status ModelCacheHelper::RecoverVarManagerFromCache() const {
  229. string var_manager_cache =
  230. to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix;
  231. Json var_manager_json;
  232. if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) {
  233. GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str());
  234. return FAILED;
  235. }
  236. Json mem_resource_json = move(var_manager_json[kMemResourceMap]);
  237. auto ret = RecoverMemResource(mem_resource_json);
  238. if (ret != SUCCESS) {
  239. GELOGW("Recover VarManager from cache failed.[MemResource]");
  240. return FAILED;
  241. }
  242. Json var_resource_json = move(var_manager_json[kVarResource]);
  243. ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]);
  244. if (ret != SUCCESS) {
  245. GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]");
  246. return FAILED;
  247. }
  248. ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]);
  249. if (ret != SUCCESS) {
  250. GELOGW("Recover VarManager from cache failed.[ChangedGraphId]");
  251. return FAILED;
  252. }
  253. ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]);
  254. if (ret != SUCCESS) {
  255. GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]");
  256. return FAILED;
  257. }
  258. ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]);
  259. if (ret != SUCCESS) {
  260. GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]");
  261. return FAILED;
  262. }
  263. ret = RecoverTransRoads(var_resource_json[kTransRoads]);
  264. if (ret != SUCCESS) {
  265. GELOGW("Recover VarManager from cache failed.[TransRoads]");
  266. return FAILED;
  267. }
  268. GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str());
  269. return SUCCESS;
  270. }
  271. Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector<NodePtr> &nodes) {
  272. std::shared_ptr<GELib> instance = ge::GELib::GetInstance();
  273. if (instance == nullptr || !instance->InitFlag()) {
  274. GELOGW("RecompileNodes failed.");
  275. return ge::GE_CLI_GE_NOT_INITIALIZED;
  276. }
  277. // Collect aicore ops for recompile
  278. for (auto &node : graph->GetDirectNode()) {
  279. if (node == nullptr) {
  280. continue;
  281. }
  282. auto op_desc = node->GetOpDesc();
  283. if (op_desc == nullptr) {
  284. continue;
  285. }
  286. // Get op kernel lib name
  287. string kernel_lib_name = op_desc->GetOpKernelLibName();
  288. if (kernel_lib_name.empty()) {
  289. // reset op kernel lib
  290. (void)instance->DNNEngineManagerObj().GetDNNEngineName(node);
  291. kernel_lib_name = op_desc->GetOpKernelLibName();
  292. if (kernel_lib_name.empty()) {
  293. GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str());
  294. continue;
  295. }
  296. }
  297. }
  298. return SUCCESS;
  299. }
  300. Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) {
  301. std::shared_ptr<GELib> instance = ge::GELib::GetInstance();
  302. if (instance == nullptr || !instance->InitFlag()) {
  303. GELOGW("RecompileNodes failed.");
  304. return ge::GE_CLI_GE_NOT_INITIALIZED;
  305. }
  306. // Get aicore ops kernel info store.
  307. OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kTbeKernelInfoStoreName);
  308. if (kernel_info == nullptr) {
  309. GELOGW("Get %s ops kernel info store failed", kTbeKernelInfoStoreName);
  310. return INTERNAL_ERROR;
  311. }
  312. auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph());
  313. vector<NodePtr> node_vec;
  314. auto ret = GetNodesNeedRecompile(compute_graph, node_vec);
  315. GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Get nodes need recompiling failed");
  316. // Recompile aicore ops
  317. ret = kernel_info->CompileOp(node_vec);
  318. GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Recompile op failed");
  319. const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore();
  320. TBEKernelStore tbe_kernel_store;
  321. for (const ge::NodePtr &n : compute_graph->GetDirectNode()) {
  322. auto node_op_desc = n->GetOpDesc();
  323. GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue);
  324. TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
  325. if (tbe_kernel == nullptr) {
  326. // Load tbe kernel from tbe_kernel_store to op if op was not recompiled
  327. auto op_desc = n->GetOpDesc();
  328. tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc);
  329. GELOGD("LoadOmModelFromCache: Load tbe kernel bin to op desc[%s].", op_desc->GetName().c_str());
  330. }
  331. tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr());
  332. GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue);
  333. // Refresh tbe kernel in tbe_kernel_store
  334. tbe_kernel_store.AddTBEKernel(tbe_kernel);
  335. GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str());
  336. }
  337. GE_CHK_BOOL_EXEC_WARN(tbe_kernel_store.Build(), return FAILED, "TBE Kernels store build failed!");
  338. ge_model->SetTBEKernelStore(tbe_kernel_store);
  339. return SUCCESS;
  340. }
  341. Status ModelCacheHelper::GetNodesHash(map<std::string, size_t> &hash_map) const {
  342. vector<NodePtr> nodes;
  343. GraphUtils::TopologicalSortingByName(compute_graph_, nodes);
  344. ModelSerializeImp model_serialize_imp;
  345. std::hash<string> node_hash;
  346. for (const auto &node : nodes) {
  347. if (node == nullptr) {
  348. continue;
  349. }
  350. proto::OpDef op_def;
  351. bool is_framework_op = (node->GetType() == FRAMEWORKOP);
  352. int32_t framework_type = 0;
  353. if (is_framework_op) {
  354. AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type);
  355. AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, 0);
  356. }
  357. bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op);
  358. op_def.set_id(0); // Id of op is not stable because of parallel parsing
  359. // Clear weights attr in constant.
  360. auto attr = op_def.mutable_attr();
  361. if (op_def.type() == CONSTANT || op_def.type() == CONSTANTOP) {
  362. attr->erase(ATTR_NAME_WEIGHTS);
  363. }
  364. if (is_framework_op) {
  365. AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type);
  366. }
  367. if (!ret) {
  368. GELOGW("Fail to serialize node[%s].", node->GetName().c_str());
  369. return INTERNAL_ERROR;
  370. }
  371. string prototxt;
  372. ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt);
  373. if (!ret) {
  374. GELOGW("Print OpDef to string failed.");
  375. hash_map.clear();
  376. return INTERNAL_ERROR;
  377. }
  378. size_t hash_code = node_hash(prototxt);
  379. hash_map[node->GetName()] = hash_code;
  380. }
  381. return SUCCESS;
  382. }
  383. Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const {
  384. proto::GraphDef graph_proto;
  385. ModelSerializeImp model_serialize_imp;
  386. // The name of compute graph may be generated randomly, so replace it temporarily.
  387. const string origin_name = compute_graph_->GetName();
  388. compute_graph_->SetName(kGraphName);
  389. bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto);
  390. graph_proto.clear_op();
  391. if (!serialize_ret) {
  392. GELOGW("Serialize graph failed.");
  393. hash = 0;
  394. return INTERNAL_ERROR;
  395. }
  396. compute_graph_->SetName(origin_name);
  397. // Generate proto text of GraphDef
  398. string prototxt;
  399. bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt);
  400. if (!print_ret) {
  401. GELOGW("Print GraphDef to string failed.");
  402. hash = 0;
  403. return INTERNAL_ERROR;
  404. }
  405. // Get the hash code of proto text
  406. std::hash<string> graph_hash;
  407. hash = graph_hash(prototxt);
  408. return SUCCESS;
  409. }
  410. Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const {
  411. if (!is_cache_path_valid_for_output) {
  412. GELOGW("Invalid cache path.");
  413. return PARAM_INVALID;
  414. }
  415. // Check whether the manifest exists, if not, create it.
  416. string real_path = RealPath(cache_path_.c_str());
  417. if (real_path.empty()) {
  418. GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str());
  419. return FAILED;
  420. }
  421. const string path = cache_path_ + file_name;
  422. const int FILE_AUTHORITY = 0600;
  423. int fd = mmOpen2(path.c_str(), M_WRONLY | M_CREAT | O_TRUNC, FILE_AUTHORITY);
  424. if (fd < 0) {
  425. GELOGW("Fail to open the file:%s. errmsg:%s", path.c_str(), strerror(errno));
  426. return INTERNAL_ERROR;
  427. }
  428. if (mmClose(fd) != 0) {
  429. GELOGW("Fail to close the file:%s. errmsg:%s", path.c_str(), strerror(errno));
  430. return INTERNAL_ERROR;
  431. }
  432. // Write json into cache file
  433. ofstream ofs;
  434. ofs.open(path);
  435. if (!ofs.is_open()) {
  436. GELOGW("Fail to open the file: %s.", path.c_str());
  437. return INTERNAL_ERROR;
  438. }
  439. ofs << json << std::endl;
  440. ofs.close();
  441. return SUCCESS;
  442. }
  443. Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const {
  444. if (!json.is_null()) {
  445. GELOGW("Input param json type should be null.");
  446. return PARAM_INVALID;
  447. }
  448. string real_path = RealPath(cache_path_.c_str());
  449. if (real_path.empty()) {
  450. GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str());
  451. return FAILED;
  452. }
  453. const string path = cache_path_ + file_name;
  454. if (!CheckInputPathValid(path)) {
  455. GELOGW("Invalid cache path for input:%s.", path.c_str());
  456. return FAILED;
  457. }
  458. string cache_real_path = RealPath(path.c_str());
  459. if (cache_real_path.empty()) {
  460. GELOGI("File[%s] is not found.", path.c_str());
  461. return FAILED;
  462. }
  463. // Read json from cache file
  464. ifstream ifs;
  465. ifs.open(path);
  466. if (!ifs.is_open()) {
  467. GELOGW("Fail to open the file: %s.", path.c_str());
  468. return INTERNAL_ERROR;
  469. }
  470. try {
  471. ifs >> json;
  472. } catch (nlohmann::detail::parse_error e) {
  473. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  474. return INTERNAL_ERROR;
  475. } catch (nlohmann::detail::invalid_iterator e) {
  476. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  477. return INTERNAL_ERROR;
  478. } catch (nlohmann::detail::type_error e) {
  479. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  480. return INTERNAL_ERROR;
  481. } catch (nlohmann::detail::out_of_range e) {
  482. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  483. return INTERNAL_ERROR;
  484. } catch (nlohmann::detail::other_error e) {
  485. GELOGW("Fail to load json from file, json throw an error:%s.", e.what());
  486. return INTERNAL_ERROR;
  487. }
  488. if (!json.is_object()) {
  489. GELOGW("Fail to load the json file: %s.", path.c_str());
  490. return INTERNAL_ERROR;
  491. }
  492. return SUCCESS;
  493. }
  494. Status ModelCacheHelper::SaveCacheInfoToCache() const {
  495. // Generate cache json
  496. // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356}
  497. Json cache_json;
  498. try {
  499. cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize();
  500. size_t edge_num = 0;
  501. for (const auto &node : compute_graph_->GetDirectNode()) {
  502. for (const auto &anchor : node->GetAllInAnchors()) {
  503. edge_num += anchor->GetPeerAnchors().size();
  504. }
  505. }
  506. cache_json[kEdgeNum] = edge_num;
  507. size_t hash = 0;
  508. auto ret = GetComputeGraphHash(hash);
  509. if (ret != SUCCESS) {
  510. GELOGW("Error occur when generate graph hash code.");
  511. return ret;
  512. }
  513. cache_json[kGraphHash] = hash;
  514. Json nodes_hash_json;
  515. ret = GetNodesHashMapJson(nodes_hash_json);
  516. if (ret != SUCCESS) {
  517. GELOGW("Error occur when generate nodes hash code.");
  518. return ret;
  519. }
  520. cache_json[kNodeHash] = nodes_hash_json;
  521. } catch (const std::exception &e) {
  522. GELOGW("Fail to generate cache info json. Error message: %s", e.what());
  523. return INTERNAL_ERROR;
  524. }
  525. string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix;
  526. auto ret = SaveJsonToFile(cache_manifest, cache_json);
  527. if (ret != SUCCESS) {
  528. GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str());
  529. return ret;
  530. }
  531. return SUCCESS;
  532. }
  533. Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const {
  534. string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix;
  535. Json cache_json;
  536. if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) {
  537. GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str());
  538. return INTERNAL_ERROR;
  539. }
  540. if (!cache_json.is_object()) {
  541. GELOGW("Manifest should be a json object");
  542. return INTERNAL_ERROR;
  543. }
  544. try {
  545. cache_info.node_num = cache_json[kNodeNum];
  546. cache_info.edge_num = cache_json[kEdgeNum];
  547. cache_info.graph_hash = cache_json[kGraphHash];
  548. Json nodes_hash_json = cache_json[kNodeHash];
  549. if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) {
  550. GELOGW("Nodes hash in cache should be null or array.");
  551. return FAILED;
  552. }
  553. for (const auto &iter : nodes_hash_json) {
  554. cache_info.nodes_hash[iter[kName].get<std::string>()] = iter[kHash].get<size_t>();
  555. }
  556. } catch (const std::exception &e) {
  557. GELOGW("Fail to get info from json file. Error message: %s", e.what());
  558. return INTERNAL_ERROR;
  559. }
  560. return SUCCESS;
  561. }
  562. bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const {
  563. if (!(json.is_null() || json.is_array())) {
  564. GELOGW("Input param json type should be null or array.");
  565. return false;
  566. }
  567. // Compare allocated graph id info between json and VarManager
  568. std::map<std::string, uint32_t> allocated_graph_id;
  569. auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id);
  570. if (ret != SUCCESS) {
  571. GELOGW("Fail to parse AllocatedGraphId from Json.");
  572. return false;
  573. }
  574. for (const auto &iter : allocated_graph_id) {
  575. uint32_t graph_id = 0;
  576. ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id);
  577. if (ret != SUCCESS) {
  578. GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str());
  579. return false;
  580. }
  581. if (graph_id != iter.second) {
  582. GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  583. return false;
  584. }
  585. }
  586. return true;
  587. }
  588. bool ModelCacheHelper::IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const {
  589. map<std::string, size_t> cur_hash_map;
  590. GetNodesHash(cur_hash_map);
  591. if (hash_map.size() != cur_hash_map.size()) {
  592. GELOGI("The number of hash code is different from cache info.");
  593. return false;
  594. }
  595. for (const auto &iter : cur_hash_map) {
  596. if (hash_map.count(iter.first) == 0) {
  597. GELOGI("Node[%s] is not found in cache info.", iter.first.c_str());
  598. return false;
  599. }
  600. if (hash_map.at(iter.first) != iter.second) {
  601. GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str());
  602. return false;
  603. }
  604. }
  605. return true;
  606. }
  607. bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const {
  608. if (!(json.is_null() || json.is_array())) {
  609. GELOGW("Input param json type should be null or array.");
  610. return false;
  611. }
  612. // Compare var mem size info between json and VarManager
  613. std::map<rtMemType_t, int64_t> var_mem_size;
  614. auto ret = ParseMemResourceFromJson(json, var_mem_size);
  615. if (ret != SUCCESS) {
  616. GELOGW("Fail to parse MemResource from Json.");
  617. return false;
  618. }
  619. for (const auto &iter : var_mem_size) {
  620. int64_t mem_size = VarManager::Instance(session_id_)->GetVarMemSize(iter.first);
  621. if (mem_size != iter.second) {
  622. GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first);
  623. return false;
  624. }
  625. }
  626. return true;
  627. }
  628. bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const {
  629. if (!(json.is_null() || json.is_array())) {
  630. GELOGW("Input param json type should be null or array.");
  631. return false;
  632. }
  633. // Compare variable changed graph id info between json and VarManager
  634. std::map<std::string, uint32_t> changed_graph_id;
  635. auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id);
  636. if (ret != SUCCESS) {
  637. GELOGW("Fail to parse ChangedGraphId from Json.");
  638. return false;
  639. }
  640. for (const auto &iter : changed_graph_id) {
  641. uint32_t graph_id = 0;
  642. ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id);
  643. if (ret != SUCCESS) {
  644. GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str());
  645. return false;
  646. }
  647. if (graph_id != iter.second) {
  648. GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  649. return false;
  650. }
  651. }
  652. return true;
  653. }
  654. bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const {
  655. if (!(json.is_null() || json.is_array())) {
  656. GELOGW("Input param json type should be null or array.");
  657. return false;
  658. }
  659. // Compare variable tensor desc info between json and VarManager
  660. std::unordered_map<std::string, ge::GeTensorDesc> cur_var_tensor_desc;
  661. auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc);
  662. if (ret != SUCCESS) {
  663. GELOGW("Fail to parse CurVarTensorDesc from Json.");
  664. return false;
  665. }
  666. for (const auto &iter : cur_var_tensor_desc) {
  667. GeTensorDesc tensor_desc;
  668. ret = VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc);
  669. if (ret != SUCCESS) {
  670. GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str());
  671. return false;
  672. }
  673. uint32_t l_real_dim_cnt = 0;
  674. uint32_t r_real_dim_cnt = 0;
  675. TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt);
  676. TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt);
  677. if ((tensor_desc.GetDataType() != iter.second.GetDataType()) ||
  678. (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) ||
  679. (tensor_desc.GetFormat() != iter.second.GetFormat()) ||
  680. (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) ||
  681. (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) ||
  682. (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) ||
  683. (l_real_dim_cnt != r_real_dim_cnt)) {
  684. GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  685. return false;
  686. }
  687. }
  688. return true;
  689. }
  690. bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const {
  691. if (!(json.is_null() || json.is_array())) {
  692. GELOGW("Input param json type should be null or array.");
  693. return false;
  694. }
  695. // Compare variable address info between json and VarManager
  696. std::vector<std::pair<std::string, VarAddrMgr>> var_addr_mgr_vector;
  697. std::set<uint64_t> var_offset_set;
  698. auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set);
  699. if (ret != SUCCESS) {
  700. GELOGW("Fail to parse VarAddrMgrMap from Json.");
  701. return false;
  702. }
  703. for (const auto &iter : var_addr_mgr_vector) {
  704. uint8_t *dev_ptr = nullptr;
  705. rtMemType_t memory_type;
  706. ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type);
  707. if (ret != SUCCESS) {
  708. GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str());
  709. return false;
  710. }
  711. // Compare memory type and logic address
  712. if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) {
  713. GELOGW("The VarAddrMgr of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  714. return false;
  715. }
  716. }
  717. return true;
  718. }
  719. bool ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const {
  720. if (!(json.is_null() || json.is_array())) {
  721. GELOGW("Input param json type should be null or array.");
  722. return false;
  723. }
  724. // Compare broadcast info between json and VarManager
  725. std::unordered_map<std::string, VarBroadCastInfo> var_broadcast_info;
  726. auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info);
  727. if (ret != SUCCESS) {
  728. GELOGW("Fail to parse BroadcastInfo from Json.");
  729. return false;
  730. }
  731. for (const auto &iter : var_broadcast_info) {
  732. VarBroadCastInfo broadcast_info;
  733. if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) {
  734. GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str());
  735. return false;
  736. }
  737. if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx ||
  738. iter.second.input_size != broadcast_info.input_size ||
  739. iter.second.input_offset != broadcast_info.input_offset ||
  740. iter.second.output_size != broadcast_info.output_size ||
  741. iter.second.output_offset != broadcast_info.output_offset) {
  742. GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  743. return false;
  744. }
  745. }
  746. return true;
  747. }
  748. bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const {
  749. if (!(json.is_null() || json.is_array())) {
  750. GELOGW("Input param json type should be null or array.");
  751. return false;
  752. }
  753. // Compare trans road between json and VarManager
  754. std::unordered_map<std::string, std::vector<TransNodeInfo>> trans_roads;
  755. auto ret = ParseTransRoadsFromJson(json, trans_roads);
  756. if (ret != SUCCESS) {
  757. GELOGW("Fail to parse TransRoads from Json.");
  758. return false;
  759. }
  760. for (const auto &iter : trans_roads) {
  761. VarTransRoad *trans_road;
  762. trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first);
  763. if (trans_road == nullptr) {
  764. GELOGW("Fail to find trans road of var[%s].", iter.first.c_str());
  765. return false;
  766. }
  767. if (trans_road->size() != iter.second.size()) {
  768. GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  769. return false;
  770. }
  771. // Compare every trans node in trans road.
  772. for (size_t idx = 0; idx < trans_road->size(); idx += 1) {
  773. if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type &&
  774. trans_road->at(idx).input == iter.second.at(idx).input &&
  775. trans_road->at(idx).output == iter.second.at(idx).output)) {
  776. GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str());
  777. return false;
  778. }
  779. }
  780. }
  781. return true;
  782. }
  783. bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const {
  784. if (!json.is_object()) {
  785. GELOGW("Input param json type should be object.");
  786. return false;
  787. }
  788. try {
  789. if (json[kSessionId].get<uint64_t>() != session_id_) {
  790. GELOGW("Check VarManager cache failed.[sessionId]");
  791. return false;
  792. }
  793. if (json[kDeviceId].get<uint32_t>() != VarManager::Instance(session_id_)->DeviceId()) {
  794. GELOGW("Check VarManager cache failed.[deviceId]");
  795. return false;
  796. }
  797. if (json[kJobId].get<uint64_t>() != VarManager::Instance(session_id_)->JobId()) {
  798. GELOGW("Check VarManager cache failed.[jobId]");
  799. return false;
  800. }
  801. if (json[kGraphMemMaxSize].get<size_t>() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) {
  802. GELOGW("Check VarManager cache failed.[graphMemMaxSize]");
  803. return false;
  804. }
  805. if (json[kVarMemMaxSize].get<size_t>() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) {
  806. GELOGW("Check VarManager cache failed.[varMemMaxSize]");
  807. return false;
  808. }
  809. if (json[kVarMemLogicBase].get<size_t>() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) {
  810. GELOGW("Check VarManager cache failed.[varMemLogicBase]");
  811. return false;
  812. }
  813. if (json[kUseMaxMemSize].get<size_t>() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) {
  814. GELOGW("Check VarManager cache failed.[useMaxMemSize]");
  815. return false;
  816. }
  817. } catch (const std::exception &e) {
  818. GELOGW("Fail to check VarManager json. Error message: %s", e.what());
  819. return false;
  820. }
  821. return true;
  822. }
  823. bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const {
  824. if (!json.is_object()) {
  825. GELOGW("Input param json type should be object.");
  826. return false;
  827. }
  828. try {
  829. if (!IsVarManagerParamSameAsCache(json)) {
  830. GELOGW("Check VarManager cache failed.[Param]");
  831. return false;
  832. }
  833. Json mem_resource_json = move(json[kMemResourceMap]);
  834. auto ret = IsMemResourceSameAsCache(mem_resource_json);
  835. if (!ret) {
  836. GELOGW("Check VarManager cache failed.[MemResource]");
  837. return false;
  838. }
  839. Json var_resource_json = move(json[kVarResource]);
  840. ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]);
  841. if (!ret) {
  842. GELOGW("Check VarManager cache failed.[AllocatedGraphId]");
  843. return false;
  844. }
  845. ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]);
  846. if (!ret) {
  847. GELOGW("Check VarManager cache failed.[ChangedGraphId]");
  848. return false;
  849. }
  850. ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]);
  851. if (!ret) {
  852. GELOGW("Check VarManager cache failed.[VarBroadcastInfo]");
  853. return false;
  854. }
  855. ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]);
  856. if (!ret) {
  857. GELOGW("Check VarManager cache failed.[CurVarTensorDesc]");
  858. return false;
  859. }
  860. ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]);
  861. if (!ret) {
  862. GELOGW("Check VarManager cache failed.[VarAddrMgrMap]");
  863. return false;
  864. }
  865. ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]);
  866. if (!ret) {
  867. GELOGW("Check VarManager cache failed.[TransRoads]");
  868. return false;
  869. }
  870. } catch (const std::exception &e) {
  871. GELOGW("Fail to check VarManager json. Error message: %s", e.what());
  872. return false;
  873. }
  874. return true;
  875. }
  876. Status ModelCacheHelper::RecoverMemResource(const Json &json) const {
  877. if (!(json.is_null() || json.is_array())) {
  878. GELOGW("Input param json type should be null or array.");
  879. return PARAM_INVALID;
  880. }
  881. std::map<rtMemType_t, int64_t> var_mem_size;
  882. auto ret = ParseMemResourceFromJson(json, var_mem_size);
  883. if (ret != SUCCESS) {
  884. GELOGW("Fail to parse MemResource from Json.");
  885. return ret;
  886. }
  887. for (const auto &iter : var_mem_size) {
  888. ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second);
  889. if (ret != SUCCESS) {
  890. GELOGW("Fail to recover var mem size.");
  891. return ret;
  892. }
  893. }
  894. return SUCCESS;
  895. }
  896. Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const {
  897. if (!(json.is_null() || json.is_array())) {
  898. GELOGW("Input param json type should be null or array.");
  899. return PARAM_INVALID;
  900. }
  901. std::map<std::string, uint32_t> allocated_graph_id;
  902. auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id);
  903. if (ret != SUCCESS) {
  904. GELOGW("Fail to parse AllocatedGraphId from Json.");
  905. return ret;
  906. }
  907. for (const auto &iter : allocated_graph_id) {
  908. ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second);
  909. if (ret != SUCCESS) {
  910. GELOGW("Fail to recover allocated graph id.");
  911. return ret;
  912. }
  913. }
  914. return SUCCESS;
  915. }
  916. Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const {
  917. if (!(json.is_null() || json.is_array())) {
  918. GELOGW("Input param json type should be null or array.");
  919. return PARAM_INVALID;
  920. }
  921. std::map<std::string, uint32_t> changed_graph_id;
  922. auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id);
  923. if (ret != SUCCESS) {
  924. GELOGW("Fail to parse AllocatedGraphId from Json.");
  925. return ret;
  926. }
  927. for (const auto &iter : changed_graph_id) {
  928. ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second);
  929. if (ret != SUCCESS) {
  930. GELOGW("Fail to recover changed graph id.");
  931. return ret;
  932. }
  933. }
  934. return SUCCESS;
  935. }
  936. Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const {
  937. if (!(json.is_null() || json.is_array())) {
  938. GELOGW("Input param json type should be null or array.");
  939. return PARAM_INVALID;
  940. }
  941. std::vector<std::pair<std::string, VarAddrMgr>> var_addr_mgr_vector;
  942. std::set<uint64_t> var_offset_set;
  943. auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set);
  944. if (ret != SUCCESS) {
  945. GELOGW("Fail to parse VarAddrMgrMap from Json.");
  946. return ret;
  947. }
  948. for (const auto &iter : var_addr_mgr_vector) {
  949. const VarAddrMgr &tensor_addr_mgr = iter.second;
  950. const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc);
  951. // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager
  952. if (!var_exist) {
  953. auto logic_address = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tensor_addr_mgr.address));
  954. auto offset = (tensor_addr_mgr.offset);
  955. // Check logic address and offset
  956. if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) {
  957. GELOGW("Check logic_address[%lu] and offset [%lu] of %s failed, var mem logic base is %lu, abandon",
  958. logic_address, offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase());
  959. return PARAM_INVALID;
  960. }
  961. // Offset is needed by SaveVarVddr instead of logic address
  962. ret = VarManager::Instance(session_id_)->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc,
  963. reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(offset)),
  964. tensor_addr_mgr.memory_type);
  965. if (ret != SUCCESS) {
  966. GELOGW("Fail to recover VarAddr or TensorDesc of var[%s].", iter.first.c_str());
  967. return ret;
  968. }
  969. }
  970. // SetVarAddr to update cur_var_tensor_desc_map_
  971. ret = VarManager::Instance(session_id_)
  972. ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type);
  973. if (ret != SUCCESS) {
  974. GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str());
  975. return ret;
  976. }
  977. }
  978. return SUCCESS;
  979. }
  980. Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const {
  981. if (!(json.is_null() || json.is_array())) {
  982. GELOGW("Input param json type should be null or array.");
  983. return PARAM_INVALID;
  984. }
  985. std::unordered_map<std::string, VarBroadCastInfo> var_broadcast_info;
  986. auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info);
  987. if (ret != SUCCESS) {
  988. GELOGW("Fail to parse BroadcastInfo from Json.");
  989. return ret;
  990. }
  991. for (const auto &iter : var_broadcast_info) {
  992. VarBroadCastInfo broadcast_info;
  993. ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second);
  994. if (ret != SUCCESS) {
  995. GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str());
  996. return ret;
  997. }
  998. }
  999. return SUCCESS;
  1000. }
  1001. Status ModelCacheHelper::RecoverTransRoads(const Json &json) const {
  1002. if (!(json.is_null() || json.is_array())) {
  1003. GELOGW("Input param json type should be null or array.");
  1004. return PARAM_INVALID;
  1005. }
  1006. std::unordered_map<std::string, std::vector<TransNodeInfo>> trans_roads;
  1007. auto ret = ParseTransRoadsFromJson(json, trans_roads);
  1008. if (ret != SUCCESS) {
  1009. GELOGW("Fail to parse TransRoads from Json.");
  1010. return ret;
  1011. }
  1012. for (const auto &iter : trans_roads) {
  1013. ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second);
  1014. if (ret != SUCCESS) {
  1015. GELOGW("Fail to find trans road of var[%s].", iter.first.c_str());
  1016. return ret;
  1017. }
  1018. }
  1019. return SUCCESS;
  1020. }
  1021. Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) {
  1022. if (!(json.is_null() || json.is_object())) {
  1023. GELOGW("Input param json type should be null or object.");
  1024. return PARAM_INVALID;
  1025. }
  1026. try {
  1027. json[kDataType] = static_cast<int>(ge_tensor_desc.GetDataType());
  1028. json[kOriginDataType] = static_cast<int>(ge_tensor_desc.GetOriginDataType());
  1029. json[kLayout] = static_cast<int>(ge_tensor_desc.GetFormat());
  1030. json[kOriginLayout] = static_cast<int>(ge_tensor_desc.GetOriginFormat());
  1031. json[kShape] = ge_tensor_desc.GetShape().GetDims();
  1032. json[kOriginShape] = ge_tensor_desc.GetOriginShape().GetDims();
  1033. uint32_t real_dim_cnt = 0;
  1034. (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value]
  1035. json[kRealDimCnt] = real_dim_cnt;
  1036. } catch (const std::exception &e) {
  1037. GELOGW("Fail to trans GeTensorDesc to json. Error message: %s", e.what());
  1038. return INTERNAL_ERROR;
  1039. }
  1040. return SUCCESS;
  1041. }
  1042. Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) {
  1043. if (!json.is_object()) {
  1044. GELOGW("Input param json type should be object.");
  1045. return PARAM_INVALID;
  1046. }
  1047. try {
  1048. ge_tensor_desc.SetDataType(static_cast<DataType>(json[kDataType].get<int>()));
  1049. ge_tensor_desc.SetOriginDataType(static_cast<DataType>(json[kOriginDataType].get<int>()));
  1050. ge_tensor_desc.SetFormat(static_cast<Format>(json[kLayout].get<int>()));
  1051. ge_tensor_desc.SetOriginFormat(static_cast<Format>(json[kOriginLayout].get<int>()));
  1052. GeShape shape(json[kShape].get<std::vector<int64_t>>());
  1053. ge_tensor_desc.SetShape(shape);
  1054. GeShape origin_shape(json[kOriginShape].get<std::vector<int64_t>>());
  1055. ge_tensor_desc.SetOriginShape(origin_shape);
  1056. auto real_dim_cnt = json[kRealDimCnt].get<uint32_t>();
  1057. (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value]
  1058. } catch (const std::exception &e) {
  1059. GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what());
  1060. return INTERNAL_ERROR;
  1061. }
  1062. return SUCCESS;
  1063. }
  1064. Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const {
  1065. if (!(json.is_null() || json.is_array())) {
  1066. GELOGW("Input param json type should be null or array.");
  1067. return PARAM_INVALID;
  1068. }
  1069. map<std::string, size_t> hash_map;
  1070. GetNodesHash(hash_map);
  1071. for (const auto &iter : hash_map) {
  1072. Json node_hash_json;
  1073. try {
  1074. node_hash_json[kName] = iter.first;
  1075. node_hash_json[kHash] = iter.second;
  1076. json.emplace_back(move(node_hash_json));
  1077. } catch (const std::exception &e) {
  1078. GELOGW("Fail to trans node cache to json. Error message: %s", e.what());
  1079. return INTERNAL_ERROR;
  1080. }
  1081. }
  1082. return SUCCESS;
  1083. }
  1084. Status ModelCacheHelper::GetMemResourceMap(Json &json) const {
  1085. if (!(json.is_null() || json.is_array())) {
  1086. GELOGW("Input param json type should be null or array.");
  1087. return PARAM_INVALID;
  1088. }
  1089. const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize();
  1090. const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM);
  1091. Json mem_resource_json;
  1092. try {
  1093. mem_resource_json[kMemType] = RT_MEMORY_HBM;
  1094. mem_resource_json[kTotalSize] = total_size;
  1095. mem_resource_json[kVarMemSize] = var_mem_size;
  1096. json.emplace_back(move(mem_resource_json));
  1097. } catch (const std::exception &e) {
  1098. GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what());
  1099. return INTERNAL_ERROR;
  1100. }
  1101. return SUCCESS;
  1102. }
  1103. Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const {
  1104. if (!(json.is_null() || json.is_array())) {
  1105. GELOGW("Input param json type should be null or array.");
  1106. return PARAM_INVALID;
  1107. }
  1108. std::unordered_map<std::string, VarAddrMgr> var_addr_mgr_map;
  1109. VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map);
  1110. try {
  1111. for (const auto &iter : var_addr_mgr_map) {
  1112. Json var_addr_json;
  1113. string name;
  1114. GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name);
  1115. var_addr_json[kName] = name;
  1116. var_addr_json[kAddress] = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(iter.second.address));
  1117. var_addr_json[kMemoryType] = iter.second.memory_type;
  1118. var_addr_json[kOffset] = iter.second.offset;
  1119. // Copy tensor desc to json.
  1120. Json tensor_desc_json;
  1121. auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json);
  1122. if (ret != SUCCESS) {
  1123. GELOGW("Fail to trans tensor desc to json.");
  1124. return INTERNAL_ERROR;
  1125. }
  1126. var_addr_json[kTensorDesc] = move(tensor_desc_json);
  1127. json.emplace_back(move(var_addr_json));
  1128. }
  1129. } catch (const std::exception &e) {
  1130. GELOGW("Fail to trans VarAddrMgrMap to json. Error message: %s", e.what());
  1131. return INTERNAL_ERROR;
  1132. }
  1133. return SUCCESS;
  1134. }
  1135. Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const {
  1136. if (!(json.is_null() || json.is_array())) {
  1137. GELOGW("Input param json type should be null or array.");
  1138. return PARAM_INVALID;
  1139. }
  1140. try {
  1141. for (const auto &name : var_names_) {
  1142. Json cur_tensor_desc_json;
  1143. GeTensorDesc tensor_desc;
  1144. auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc);
  1145. if (ret != SUCCESS) {
  1146. GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str());
  1147. continue;
  1148. }
  1149. cur_tensor_desc_json[kName] = name;
  1150. Json tensor_desc_json;
  1151. ret = TensorDescToJson(tensor_desc, tensor_desc_json);
  1152. if (ret != SUCCESS) {
  1153. GELOGW("Fail to trans tensor desc to json.");
  1154. return INTERNAL_ERROR;
  1155. }
  1156. cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json);
  1157. json.emplace_back(move(cur_tensor_desc_json));
  1158. }
  1159. } catch (const std::exception &e) {
  1160. GELOGW("Fail to trans CurVarTensorDescMap to json. Error message: %s", e.what());
  1161. return INTERNAL_ERROR;
  1162. }
  1163. return SUCCESS;
  1164. }
  1165. Status ModelCacheHelper::GetTransRoadsJson(Json &json) const {
  1166. if (!(json.is_null() || json.is_array())) {
  1167. GELOGW("Input param json type should be null or array.");
  1168. return PARAM_INVALID;
  1169. }
  1170. try {
  1171. for (const auto &name : var_names_) {
  1172. auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name);
  1173. if (trans_road == nullptr) {
  1174. continue;
  1175. }
  1176. // Json object, variable name and trans road
  1177. Json trans_road_map_json;
  1178. trans_road_map_json[kName] = name;
  1179. Json trans_road_json;
  1180. Status ret;
  1181. // Add nodes' info to json
  1182. for (const auto &trans_node_info : *trans_road) {
  1183. Json trans_node_info_json;
  1184. trans_node_info_json[kNodeType] = trans_node_info.node_type;
  1185. Json input_tensor_desc_json;
  1186. ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json);
  1187. if (ret != SUCCESS) {
  1188. GELOGW("Fail to trans tensor desc to json.");
  1189. return INTERNAL_ERROR;
  1190. }
  1191. trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json);
  1192. Json output_tensor_desc_json;
  1193. ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json);
  1194. if (ret != SUCCESS) {
  1195. GELOGW("Fail to trans tensor desc to json.");
  1196. return INTERNAL_ERROR;
  1197. }
  1198. trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json);
  1199. trans_road_json.emplace_back(move(trans_node_info_json));
  1200. }
  1201. trans_road_map_json[kTransRoad] = move(trans_road_json);
  1202. json.emplace_back(move(trans_road_map_json));
  1203. }
  1204. } catch (const std::exception &e) {
  1205. GELOGW("Fail to trans VarToTransRoad to json. Error message: %s", e.what());
  1206. return INTERNAL_ERROR;
  1207. }
  1208. return SUCCESS;
  1209. }
  1210. Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const {
  1211. if (!(json.is_null() || json.is_array())) {
  1212. GELOGW("Input param json type should be null or array.");
  1213. return PARAM_INVALID;
  1214. }
  1215. for (const auto &name : var_names_) {
  1216. uint32_t changed_graph_id = 0;
  1217. Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id);
  1218. if (ret != SUCCESS) {
  1219. continue;
  1220. }
  1221. Json name_and_changed_graph_id;
  1222. try {
  1223. name_and_changed_graph_id[kName] = name;
  1224. name_and_changed_graph_id[kGraphId] = changed_graph_id;
  1225. json.emplace_back(move(name_and_changed_graph_id));
  1226. } catch (const std::exception &e) {
  1227. GELOGW("Fail to trans ChangedGraphId to json. Error message: %s", e.what());
  1228. return INTERNAL_ERROR;
  1229. }
  1230. }
  1231. return SUCCESS;
  1232. }
  1233. Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const {
  1234. if (!(json.is_null() || json.is_array())) {
  1235. GELOGW("Input param json type should be null or array.");
  1236. return PARAM_INVALID;
  1237. }
  1238. for (const auto &name : var_names_) {
  1239. uint32_t allocated_graph_id = 0;
  1240. Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id);
  1241. if (ret != SUCCESS) {
  1242. continue;
  1243. }
  1244. Json name_and_allocated_graph_id;
  1245. try {
  1246. name_and_allocated_graph_id[kName] = name;
  1247. name_and_allocated_graph_id[kGraphId] = allocated_graph_id;
  1248. json.emplace_back(move(name_and_allocated_graph_id));
  1249. } catch (const std::exception &e) {
  1250. GELOGW("Fail to trans AllocatedGraphId to json. Error message: %s", e.what());
  1251. return INTERNAL_ERROR;
  1252. }
  1253. }
  1254. return SUCCESS;
  1255. }
  1256. Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const {
  1257. if (!(json.is_null() || json.is_array())) {
  1258. GELOGW("Input param json type should be null or array.");
  1259. return PARAM_INVALID;
  1260. }
  1261. for (const auto &name : var_names_) {
  1262. VarBroadCastInfo var_broadcast_info;
  1263. Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info);
  1264. if (ret != SUCCESS) {
  1265. continue;
  1266. }
  1267. Json var_broadcast_info_json;
  1268. try {
  1269. var_broadcast_info_json[kName] = name;
  1270. var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name;
  1271. var_broadcast_info_json[kIdx] = var_broadcast_info.idx;
  1272. var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset;
  1273. var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size;
  1274. var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset;
  1275. var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size;
  1276. json.emplace_back(move(var_broadcast_info_json));
  1277. } catch (const std::exception &e) {
  1278. GELOGW("Fail to trans VarBroadcastInfo to json. Error message: %s", e.what());
  1279. return INTERNAL_ERROR;
  1280. }
  1281. }
  1282. return SUCCESS;
  1283. }
  1284. Status ModelCacheHelper::GetVarResourceJson(Json &json) const {
  1285. if (!(json.is_null() || json.is_object())) {
  1286. GELOGW("Input param json type should be null or object.");
  1287. return PARAM_INVALID;
  1288. }
  1289. Json var_addr_mgr_map_json;
  1290. Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json);
  1291. if (ret != SUCCESS) {
  1292. GELOGW("GetVarAddrMgrMapJson failed.");
  1293. return INTERNAL_ERROR;
  1294. }
  1295. Json cur_var_tensor_desc_map_json;
  1296. ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json);
  1297. if (ret != SUCCESS) {
  1298. GELOGW("GetCurVarTensorDescMapJson failed.");
  1299. return INTERNAL_ERROR;
  1300. }
  1301. Json trans_roads_json;
  1302. ret = GetTransRoadsJson(trans_roads_json);
  1303. if (ret != SUCCESS) {
  1304. GELOGW("GetTransRoadsJson failed.");
  1305. return INTERNAL_ERROR;
  1306. }
  1307. Json changed_graph_id_json;
  1308. ret = GetChangedGraphIdJson(changed_graph_id_json);
  1309. if (ret != SUCCESS) {
  1310. GELOGW("GetChangedGraphIdJson failed.");
  1311. return INTERNAL_ERROR;
  1312. }
  1313. Json allocated_graph_id_json;
  1314. ret = GetAllocatedGraphIdJson(allocated_graph_id_json);
  1315. if (ret != SUCCESS) {
  1316. GELOGW("GetAllocatedGraphIdJson failed.");
  1317. return INTERNAL_ERROR;
  1318. }
  1319. Json var_broadcast_info_json;
  1320. ret = GetBroadcastInfoJson(var_broadcast_info_json);
  1321. if (ret != SUCCESS) {
  1322. GELOGW("GetBroadcastInfoJson failed.");
  1323. return INTERNAL_ERROR;
  1324. }
  1325. try {
  1326. json[kVarAddrMgrMap] = move(var_addr_mgr_map_json);
  1327. json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json);
  1328. json[kTransRoads] = move(trans_roads_json);
  1329. json[kChangedGraphId] = move(changed_graph_id_json);
  1330. json[kAllocatedGraphId] = move(allocated_graph_id_json);
  1331. json[kVarBroadcastInfo] = move(var_broadcast_info_json);
  1332. } catch (const exception &e) {
  1333. GELOGW("Fail to generate VarResource json. Error message: %s", e.what());
  1334. return INTERNAL_ERROR;
  1335. }
  1336. return SUCCESS;
  1337. }
  1338. Status ModelCacheHelper::GetVarManagerJson(Json &json) const {
  1339. if (!(json.is_null() || json.is_object())) {
  1340. GELOGW("Input param json type should be null or object.");
  1341. return PARAM_INVALID;
  1342. }
  1343. Json mem_resource_map_json;
  1344. auto ret = GetMemResourceMap(mem_resource_map_json);
  1345. if (ret != SUCCESS) {
  1346. GELOGW("GetMemResourceMap failed.");
  1347. return INTERNAL_ERROR;
  1348. }
  1349. Json var_resource_json;
  1350. ret = GetVarResourceJson(var_resource_json);
  1351. if (ret != SUCCESS) {
  1352. GELOGW("GetVarResourceJson failed.");
  1353. return INTERNAL_ERROR;
  1354. }
  1355. try {
  1356. json[kSessionId] = session_id_;
  1357. json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId();
  1358. json[kJobId] = VarManager::Instance(session_id_)->JobId();
  1359. json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize();
  1360. json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize();
  1361. json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase();
  1362. json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize();
  1363. json[kMemResourceMap] = move(mem_resource_map_json);
  1364. json[kVarResource] = move(var_resource_json);
  1365. } catch (const exception &e) {
  1366. GELOGW("Fail to generate VarManager json. Error message: %s", e.what());
  1367. return INTERNAL_ERROR;
  1368. }
  1369. return SUCCESS;
  1370. }
  1371. Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const {
  1372. if (!is_cache_path_valid_for_output) {
  1373. GELOGW("Invalid cache path.");
  1374. return FAILED;
  1375. }
  1376. Json var_manager_json;
  1377. auto ret = GetVarManagerJson(var_manager_json);
  1378. if (ret != SUCCESS) {
  1379. GELOGW("Fail to generate VarManager json.");
  1380. return FAILED;
  1381. }
  1382. string var_manager_path = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) +
  1383. (before_build ? kBeforeVarManagerSuffix : kAfterVarManagerSuffix);
  1384. ret = SaveJsonToFile(var_manager_path, var_manager_json);
  1385. if (ret != SUCCESS) {
  1386. GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str());
  1387. return ret;
  1388. }
  1389. return SUCCESS;
  1390. }
  1391. Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const {
  1392. if (!is_cache_path_valid_for_output) {
  1393. GELOGW("Invalid cache path.");
  1394. return FAILED;
  1395. }
  1396. string om_path = RealPath(cache_path_.c_str());
  1397. if (om_path.empty()) {
  1398. GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str());
  1399. return FAILED;
  1400. }
  1401. string cache_om_path = cache_path_;
  1402. cache_om_path += (to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix);
  1403. GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str());
  1404. ModelHelper model_helper;
  1405. SaveParam save_param;
  1406. ModelBufferData model;
  1407. Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model);
  1408. if (ret != SUCCESS) {
  1409. GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret);
  1410. return ret;
  1411. }
  1412. return SUCCESS;
  1413. }
  1414. Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource) {
  1415. if (!(json.is_array() || json.is_null())) {
  1416. GELOGW("Input param json type should be null or array.");
  1417. return PARAM_INVALID;
  1418. }
  1419. mem_resource.clear();
  1420. for (const Json &mem_resource_json : json) {
  1421. try {
  1422. rtMemType_t mem_type = mem_resource_json[kMemType].get<rtMemType_t>();
  1423. uint64_t var_mem_size = mem_resource_json[kVarMemSize].get<int64_t>();
  1424. mem_resource[mem_type] = var_mem_size;
  1425. } catch (const exception &e) {
  1426. GELOGW("Fail to trans Json to MemResource. Error message: %s", e.what());
  1427. return INTERNAL_ERROR;
  1428. }
  1429. }
  1430. return SUCCESS;
  1431. }
  1432. Status ModelCacheHelper::ParseVarAddrMgrMapFromJson(
  1433. const Json &json, std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector,
  1434. std::set<uint64_t> &var_offset_set) {
  1435. if (!(json.is_array() || json.is_null())) {
  1436. GELOGW("Input param json type should be null or array.");
  1437. return PARAM_INVALID;
  1438. }
  1439. var_addr_mgr_vector.clear();
  1440. var_offset_set.clear();
  1441. for (const Json &var_addr_json : json) {
  1442. VarAddrMgr var_addr_mgr;
  1443. try {
  1444. auto logic_address = var_addr_json[kAddress].get<uint64_t>();
  1445. auto address = reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_address));
  1446. var_addr_mgr.address = address;
  1447. var_addr_mgr.offset = var_addr_json[kOffset].get<uint64_t>();
  1448. var_addr_mgr.memory_type = var_addr_json[kMemoryType].get<rtMemType_t>();
  1449. auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc);
  1450. if (ret != SUCCESS) {
  1451. GELOGW("Fail to trans json to tensor desc.");
  1452. return ret;
  1453. }
  1454. var_addr_mgr_vector.emplace_back(var_addr_json[kName].get<string>(), move(var_addr_mgr));
  1455. var_offset_set.insert(logic_address);
  1456. } catch (const exception &e) {
  1457. GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what());
  1458. return INTERNAL_ERROR;
  1459. }
  1460. }
  1461. return SUCCESS;
  1462. }
  1463. Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson(
  1464. const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map) {
  1465. if (!(json.is_array() || json.is_null())) {
  1466. GELOGW("Input param json type should be null or array.");
  1467. return PARAM_INVALID;
  1468. }
  1469. cur_var_tensor_desc_map.clear();
  1470. for (const Json &tensor_desc_json : json) {
  1471. GeTensorDesc tensor_desc;
  1472. try {
  1473. auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc);
  1474. if (ret != SUCCESS) {
  1475. GELOGW("Fail to trans json to tensor desc.");
  1476. return ret;
  1477. }
  1478. cur_var_tensor_desc_map[tensor_desc_json[kName].get<string>()] = move(tensor_desc);
  1479. } catch (const exception &e) {
  1480. GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what());
  1481. return INTERNAL_ERROR;
  1482. }
  1483. }
  1484. return SUCCESS;
  1485. }
  1486. Status ModelCacheHelper::ParseTransRoadsFromJson(
  1487. const Json &json, std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads) {
  1488. if (!(json.is_array() || json.is_null())) {
  1489. GELOGW("Input param json type should be null or array.");
  1490. return PARAM_INVALID;
  1491. }
  1492. trans_roads.clear();
  1493. try {
  1494. for (const Json &name_trans_road_json : json) {
  1495. const Json &trans_road_json = name_trans_road_json[kTransRoad];
  1496. if (!(trans_road_json.is_array() || trans_road_json.is_null())) {
  1497. GELOGW("%s json type should be null or object.", kTransRoad);
  1498. return PARAM_INVALID;
  1499. }
  1500. vector<TransNodeInfo> trans_road;
  1501. for (const Json &trans_node_json : trans_road_json) {
  1502. TransNodeInfo trans_node_info;
  1503. trans_node_info.node_type = trans_node_json[kNodeType];
  1504. GeTensorDesc input_tensor_desc;
  1505. auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc);
  1506. if (ret != SUCCESS) {
  1507. GELOGW("Fail to trans json to tensor desc.");
  1508. return ret;
  1509. }
  1510. trans_node_info.input = move(input_tensor_desc);
  1511. GeTensorDesc output_tensor_desc;
  1512. ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc);
  1513. if (ret != SUCCESS) {
  1514. GELOGW("Fail to trans json to tensor desc.");
  1515. return ret;
  1516. }
  1517. trans_node_info.output = move(output_tensor_desc);
  1518. trans_road.emplace_back(move(trans_node_info));
  1519. }
  1520. trans_roads[name_trans_road_json[kName].get<string>()] = move(trans_road);
  1521. }
  1522. } catch (const exception &e) {
  1523. GELOGW("Fail to trans Json to TransRoads. Error message: %s", e.what());
  1524. return INTERNAL_ERROR;
  1525. }
  1526. return SUCCESS;
  1527. }
  1528. Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json,
  1529. std::map<std::string, uint32_t> &changed_graph_id) {
  1530. if (!(json.is_array() || json.is_null())) {
  1531. GELOGW("Input param json type should be null or array.");
  1532. return PARAM_INVALID;
  1533. }
  1534. changed_graph_id.clear();
  1535. for (const Json &name_graph_id_json : json) {
  1536. try {
  1537. changed_graph_id[name_graph_id_json[kName].get<string>()] = name_graph_id_json[kGraphId].get<uint32_t>();
  1538. } catch (const exception &e) {
  1539. GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what());
  1540. return INTERNAL_ERROR;
  1541. }
  1542. }
  1543. return SUCCESS;
  1544. }
  1545. Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json,
  1546. std::map<std::string, uint32_t> &allocated_graph_id) {
  1547. if (!(json.is_array() || json.is_null())) {
  1548. GELOGW("Input param json type should be null or array.");
  1549. return PARAM_INVALID;
  1550. }
  1551. allocated_graph_id.clear();
  1552. for (const Json &name_graph_id_json : json) {
  1553. try {
  1554. allocated_graph_id[name_graph_id_json[kName].get<string>()] = name_graph_id_json[kGraphId].get<uint32_t>();
  1555. } catch (const exception &e) {
  1556. GELOGW("Fail to trans Json to allocated graph id. Error message: %s", e.what());
  1557. return INTERNAL_ERROR;
  1558. }
  1559. }
  1560. return SUCCESS;
  1561. }
  1562. Status ModelCacheHelper::ParseBroadcastInfoFromJson(
  1563. const Json &json, std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info) {
  1564. if (!(json.is_array() || json.is_null())) {
  1565. GELOGW("Input param json type should be null or array.");
  1566. return PARAM_INVALID;
  1567. }
  1568. for (const Json &broadcast_info_json : json) {
  1569. VarBroadCastInfo broadcast_info;
  1570. try {
  1571. broadcast_info.var_name = broadcast_info_json[kName].get<string>();
  1572. broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get<string>();
  1573. broadcast_info.idx = broadcast_info_json[kIdx].get<int>();
  1574. broadcast_info.input_offset = broadcast_info_json[kInputOffset].get<int64_t>();
  1575. broadcast_info.input_size = broadcast_info_json[kInputSize].get<uint64_t>();
  1576. broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get<int64_t>();
  1577. broadcast_info.output_size = broadcast_info_json[kOutputSize].get<uint64_t>();
  1578. } catch (const exception &e) {
  1579. GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what());
  1580. return INTERNAL_ERROR;
  1581. }
  1582. var_broadcast_info[broadcast_info.var_name] = broadcast_info;
  1583. }
  1584. return SUCCESS;
  1585. }
  1586. Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const {
  1587. string cache_om = cache_path_ + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix;
  1588. if (!CheckInputPathValid(cache_om)) {
  1589. GELOGW("Invalid cache path for input:%s.", cache_om.c_str());
  1590. return FAILED;
  1591. }
  1592. string om_path = RealPath(cache_om.c_str());
  1593. if (om_path.empty()) {
  1594. GELOGW("file path is invalid. please check file om: %s", om_path.c_str());
  1595. return FAILED;
  1596. }
  1597. GELOGI("load model data from file: %s", om_path.c_str());
  1598. Status ret;
  1599. int32_t priority = 0;
  1600. ModelData model_data;
  1601. ret = ModelParserBase::LoadFromFile(om_path.c_str(), priority, model_data);
  1602. if (ret != SUCCESS) {
  1603. GELOGW("LoadOmModelFromCache: Load model from file failed. ret = %u", ret);
  1604. return ret;
  1605. }
  1606. std::function<void()> callback = [&]() {
  1607. if (model_data.model_data != nullptr) {
  1608. delete[] reinterpret_cast<char *>(model_data.model_data);
  1609. model_data.model_data = nullptr;
  1610. }
  1611. };
  1612. GE_MAKE_GUARD(release, callback);
  1613. ModelHelper model_helper;
  1614. ret = model_helper.LoadModel(model_data);
  1615. if (ret != SUCCESS) {
  1616. GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret);
  1617. return ret;
  1618. }
  1619. ge_model = model_helper.GetGeModel();
  1620. ret = RecompileNodes(ge_model);
  1621. if (ret != SUCCESS) {
  1622. GELOGW("LoadOmModelFromCache: recompile nodes failed. ret = %u", ret);
  1623. return ret;
  1624. }
  1625. return SUCCESS;
  1626. }
  1627. Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc,
  1628. string &var_name) {
  1629. std::string::size_type underline_idx = var_key.rfind('_');
  1630. if (underline_idx == std::string::npos) {
  1631. GELOGW("Invalid var key: underline not found");
  1632. return FAILED;
  1633. }
  1634. std::string::size_type format_idx =
  1635. var_key.rfind(std::to_string(static_cast<int32_t>(tensor_desc.GetFormat())), underline_idx);
  1636. if (format_idx == std::string::npos) {
  1637. GELOGW("Invalid var key: format not found");
  1638. return FAILED;
  1639. }
  1640. var_name = var_key.substr(0, format_idx);
  1641. return SUCCESS;
  1642. }
  1643. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户无需感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。