
omg_stub.cc 30 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <map>
#include <fstream>
#include <unordered_map>
#include <climits>
#include <cstdlib>
#include <string>
#include <vector>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include "mmpa/mmpa_api.h"
#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/types.h"
#include "common/util.h"
#include "common/string_util.h"
#include "common/properties_manager.h"
#include "common/model_parser/base.h"
#include "graph/model.h"
#include "cce/dnn.h"
#include "ge/ge_api_types.h"
#include "framework/common/ge_types.h"
#include "graph/utils/op_desc_utils.h"
#include "common/profiling/profiling_manager.h"
using domi::domiTensorFormat_t;
using namespace cce;
using namespace ge;

struct PROC_PARAM {
  uint8_t *model_name;

  // ISV Ek buffer
  uint8_t *model_key;
  uint32_t model_key_len;

  // ISV root certificate buffer
  uint8_t *root_cert;
  uint32_t root_cert_len;

  // ISV private key buffer
  uint8_t *pri_key;
  uint32_t pri_key_len;

  // Raw AI Module Image buffer
  uint8_t *ai_image;
  uint32_t ai_image_len;

  // ISV HW key buffer
  uint8_t *hw_key;
  uint32_t hw_key_len;
};

#ifdef __cplusplus
extern "C" {
#endif
using namespace ge;
namespace {
const char FMK_STATUS_FILE_DIR_ENV[] = "FMK_STATUS_FILE_DIR";
const char JOBSTATE_FILE_NAME[] = "jobstateupdate_framework";
const char HCOM_DETECT_FILE_NAME[] = "hcom_detection_result";
const char FILE_SEPARATE[] = "/";
}  // namespace
#ifdef __cplusplus
}
#endif
namespace ge {
struct GeModelPartition {
  ModelPartitionType type_ = MODEL_DEF;
  uint8_t *data_ = nullptr;
  size_t size_ = 0;

  GeModelPartition() = default;
  // Stub copy: deliberately leaves data_ empty so each partition frees only the buffer it allocated itself.
  GeModelPartition(const GeModelPartition &partition) {}
  GeModelPartition &operator=(const GeModelPartition &partition) = delete;
  ~GeModelPartition() {
    if (data_ != nullptr) {
      delete[] data_;
      data_ = nullptr;
    }
  }

  Status SetData(uint8_t *data, size_t size) {
    size_ = size;
    data_ = new (std::nothrow) uint8_t[size]();
    if (data_ == nullptr) {
      GELOGE(ge::FAILED, "[GeModel Partition] Failed to allocate %zu bytes for partition data.", size);
      return FAILED;
    }
    errno_t err = memcpy_s(data_, size_, data, size);
    if (err) {
      GELOGE(ge::FAILED, "[GeModel Partition] Error occur when copy GeModel Partition data.");
      return FAILED;
    }
    return SUCCESS;
  }

  Status SetType(ModelPartitionType type) {
    type_ = type;
    return SUCCESS;
  }
};
struct OmFileContext {
  vector<GeModelPartition> partition_datas_;
  vector<char> partition_table_;
  uint32_t model_data_len_;
};

class SubGraphInfo;
using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>;
using GeModelPartitionPtr = std::shared_ptr<GeModelPartition>;
using ModelPtr = std::shared_ptr<ge::Model>;

class GeModel {
 public:
  explicit GeModel(const ModelPtr &model_ptr);
  ~GeModel() = default;
  GeModel(const GeModel &other) = delete;
  GeModel &operator=(const GeModel &other) = delete;

  ModelPtr GetModelPtr() const;
  Status AddPartition(uint8_t *data, size_t size, ModelPartitionType type);
  Status GetPartition(ModelPartitionType type, GeModelPartitionPtr &partition);
  uint8_t GetPlatformType() const;
  void SetPlatformType(const uint8_t platform_type) { platform_type_ = platform_type; }

 private:
  std::map<ModelPartitionType, GeModelPartitionPtr> partitions_;
  ModelPtr model_ = nullptr;
  uint8_t platform_type_ = {0};
};
using GeModelPtr = std::shared_ptr<ge::GeModel>;

GeModel::GeModel(const ModelPtr &model_ptr) { this->model_ = model_ptr; }

ModelPtr GeModel::GetModelPtr() const { return this->model_; }

uint8_t GeModel::GetPlatformType() const { return platform_type_; }

Status GeModel::AddPartition(uint8_t *data, size_t size, ModelPartitionType type) {
  if (size == 0) {
    return FAILED;
  }
  if (data == nullptr) {
    return FAILED;
  }
  auto iter = partitions_.find(type);
  if (iter != partitions_.end()) {
    return FAILED;
  }

  GeModelPartitionPtr partition = nullptr;
  GE_MAKE_SHARED(partition = std::make_shared<ge::GeModelPartition>(), return FAILED);
  Status ret = partition->SetType(type);
  if (ret != SUCCESS) {
    return FAILED;
  }
  ret = partition->SetData(data, size);
  if (ret != SUCCESS) {
    return FAILED;
  }

  partitions_.insert(std::pair<ModelPartitionType, GeModelPartitionPtr>(type, partition));
  return SUCCESS;
}

Status GeModel::GetPartition(ModelPartitionType type, GeModelPartitionPtr &partition) {
  auto iter = partitions_.find(type);
  if (iter == partitions_.end()) {
    return FAILED;
  }
  partition = iter->second;
  return SUCCESS;
}

class OmFileSaveHelper {
 public:
  OmFileSaveHelper();
  ~OmFileSaveHelper();

  vector<GeModelPartition> &GetModelPartitions();
  ModelPartitionTable *GetPartitionTable();
  ModelFileHeader model_header_;
  ModelFileHeader &GetModelFileHeader() { return model_header_; }
  void AddPartition(GeModelPartition &partition);

 private:
  OmFileContext context_;
};

OmFileSaveHelper::OmFileSaveHelper() {}

OmFileSaveHelper::~OmFileSaveHelper() {}

vector<GeModelPartition> &OmFileSaveHelper::GetModelPartitions() {
  static std::vector<GeModelPartition> tmp;
  return tmp;
}

ModelPartitionTable *OmFileSaveHelper::GetPartitionTable() { return nullptr; }

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OmFileSaveHelper::AddPartition(GeModelPartition &partition) {
  context_.partition_datas_.push_back(partition);
  context_.model_data_len_ += partition.size_;
}

class ModelBuilder {
 public:
  ModelBuilder(ge::ComputeGraphPtr compute_graph, const std::vector<SubGraphInfoPtr> &subgraphs,
               const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, int mode);
  virtual ~ModelBuilder();

  Status BuildModel(ge::Model &model_def);
  Status SaveWeightsToModel(ge::Model &model);
  Status SaveDataToModel(ge::Model &model, ge::GeModel &ge_model);
  Status PreBuildModel();
  Status BuildModelForGetTask(ge::Model &model_def);
  ge::Buffer GetWeightBuffer() const;
  void SetModelVersion(ge::Model &model_def);

 public:
  ge::Buffer weight_buffer_;
};

ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const std::vector<SubGraphInfoPtr> &subgraphs,
                           const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, int mode) {
  weight_buffer_ = ge::Buffer(4100000);
}

ModelBuilder::~ModelBuilder() {}

Status ModelBuilder::SaveWeightsToModel(ge::Model &model) { return SUCCESS; }
Status ModelBuilder::BuildModel(ge::Model &model_def) { return SUCCESS; }
Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { return SUCCESS; }
Status ModelBuilder::PreBuildModel() { return SUCCESS; }
Status ModelBuilder::BuildModelForGetTask(ge::Model &model_def) { return SUCCESS; }
void ModelBuilder::SetModelVersion(ge::Model &model_def) { return; }
ge::Buffer ModelBuilder::GetWeightBuffer() const { return ge::Buffer(4100000); }
}  // namespace ge

using ProcParam = struct PROC_PARAM;

namespace ge {
#include <iostream>
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_N = 0;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_C = 1;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_H = 2;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_W = 3;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_N = 0;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_H = 1;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_W = 2;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_C = 3;

const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49;
const uint32_t MODEL_FILE_HEAD_LEN = 256;
const uint32_t MODEL_VERSION = 0x10000000;
const int MAX_FILE_SIZE_LIMIT = INT_MAX;
bool FC_WEIGHT_COMPRESS_FLAG = false;

bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length) {
  length = 10;
  *buffer = new (std::nothrow) char[10]();
  GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, "new an object failed.");
  return true;
}

bool ReadProtoFromText(const char *file, google::protobuf::Message *message) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((nullptr == file || nullptr == message), return false,
                                 "incorrect parameter. nullptr == file || nullptr == message");
  string real_path = RealPath(file);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "proto file path '%s' not valid", file);
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path.c_str()) == -1, return false, "file size not valid.");

  std::ifstream fs(real_path.c_str(), std::ifstream::in);
  if (!fs.is_open()) {
    GELOGE(ge::FAILED, "proto file '%s' open fail.", file);
    return false;
  }

  google::protobuf::io::IstreamInputStream input(&fs);
  bool ret = google::protobuf::TextFormat::Parse(&input, message);
  GE_IF_BOOL_EXEC(ret != true,
                  GELOGI("call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file."));
  fs.close();
  return ret;
}
uint64_t GetCurrentTimestap() { return 0; }

// get length of file
long GetFileLength(const std::string &input_file) {
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null.");
  string real_path = RealPath(input_file.c_str());
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str());

  unsigned long long file_length = 0;
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, return -1,
                                 "open file failed.");
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length <= 0), return -1, "file length <= 0, not valid.");
  GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > MAX_FILE_SIZE_LIMIT, return -1, "file size %llu is out of limit: %d.",
                                 file_length, MAX_FILE_SIZE_LIMIT);
  return file_length;
}

string RealPath(const char *path) {
  string s = path;
  if (s.size() >= PATH_MAX) {
    return "";
  }
  if (s == "." || s == "1") {
    return path;
    // for insert_aipp_op unittest
  } else if (s.substr(0, 3) == "llt") {
    return path;
  } else {
    return "22";
  }
}
bool CheckInputPathValid(const string &file_path) { return true; }

bool ReadProtoFromArray(const void *data, int size, Message *proto) { return true; }

struct ModelPartition {
  ModelPartitionType type;
  uint8_t *data = nullptr;
  uint32_t size = 0;
};

class InsertNewOpUtil {
 public:
  InsertNewOpUtil();
  ~InsertNewOpUtil();
  Status InsertNewOps(const ComputeGraphPtr &graph);
  Status InsertAippOps(ge::ComputeGraphPtr graph, std::string &aipp_config_path);
  Status Parse(const char *conf_path);
};

InsertNewOpUtil::InsertNewOpUtil() {}
Status InsertNewOpUtil::InsertNewOps(const ComputeGraphPtr &graph) { return SUCCESS; }
Status InsertNewOpUtil::InsertAippOps(ge::ComputeGraphPtr graph, std::string &aipp_config_path) { return SUCCESS; }
Status InsertNewOpUtil::Parse(const char *conf_path) { return SUCCESS; }

Status InitOME() { return SUCCESS; }

class GraphOptimizer {
 public:
  Status Optimize();
  Status OptimizeAfterCal();
  Status AdjustDataOpDesc();
  Status InsertTransOp();
  Status FusionFmkop();
  Status Optimize4Cloud();
  Status Optimize4FlowCtrl();
  Status OptimizeBeforeBuild();
};

Status GraphOptimizer::Optimize() { return SUCCESS; }

Status Init(Options options) { return SUCCESS; }

Status Shutdown(Options options) { return SUCCESS; }
class Session {
 public:
  // singleton
  static Session *Instance();
  const uint32_t &DeviceId() const;
};

const uint32_t &Session::DeviceId() const {
  // return a reference to an object with static storage duration, not to a temporary
  static const uint32_t device_id = 0;
  return device_id;
}

Session *Session::Instance() {
  static Session instance;
  return &instance;
}

struct OmgContext {
  domiTensorFormat_t format;

  // input format read from the command line
  std::unordered_map<std::string, domiTensorFormat_t> input_nodes_format_map;
  std::vector<domiTensorFormat_t> output_formats;

  // user-designated input dims
  std::vector<std::pair<std::string, std::vector<int64_t>>> user_input_dims;
  // global input dims
  std::unordered_map<std::string, std::vector<int64_t>> input_dims;

  // resolves renamed ops, e.g. Detectionoutput:SsdDetectiontOutput
  std::map<std::string, std::string> op_conf_map;
  // output nodes of the network: key is the op name, value is the list of output indices of that op
  std::map<std::string, std::vector<int32_t>> out_nodes_map;
  // user-designated output nodes (used to determine the output order)
  std::vector<std::pair<std::string, int32_t>> user_out_nodes;
  // paths of custom aicpu ops
  std::vector<std::string> aicpu_op_run_paths;
  // ddk version
  std::string ddk_version;
  // network format
  domiTensorFormat_t net_format;

  FrameworkType type;
  // RunMode run_mode;
  bool train_flag = false;
  std::string output_type;

  /// name of the network
  /// e.g. faster-rcnn: based on FirstStageProcessor, after scope_fusion the name is faster-rcnn;
  /// then the conv+reshape of FirstStageBoxPredictor/BoxEncodingPredictor is reordered
  /// and the reshape op needs to be deleted
  std::string net_name;
};
}  // namespace ge

namespace domi {
ge::OmgContext &GetContext() {
  static ge::OmgContext tmp;
  return tmp;
}
}  // namespace domi
namespace ge {
class OpUtils {
 public:
  static Status InitTensorDescriptor(const GeTensorDesc &tensor, ccTensorDescriptor_t &cc_tensor);
  static Status InitTensorDescriptor(int32_t format, int32_t data_type, const std::vector<int64_t> &dim,
                                     ccTensorDescriptor_t &cc_tensor, uint32_t real_dim_cnt);
  static void DestroyTensorDescriptor(ccTensorDescriptor_t &cc_tensor);
};

Status OpUtils::InitTensorDescriptor(const GeTensorDesc &tensor, ccTensorDescriptor_t &cc_tensor) {
  ccCreatePoolingMaskDescriptor(&cc_tensor);
  return SUCCESS;
}

Status OpUtils::InitTensorDescriptor(int32_t format, int32_t data_type, const std::vector<int64_t> &dim,
                                     ccTensorDescriptor_t &cc_tensor, uint32_t real_dim_cnt) {
  Status ret = SUCCESS;
  return ret;
}

class FileSaver {
 public:
  Status SaveToFile(const string &file_path, ModelFileHeader &model_file_header,
                    ModelPartitionTable &model_partition_table, const std::vector<ModelPartition> &partition_datas);
  Status SaveToFileWithEncrypt(const std::string file_path, const ProcParam proc_param,
                               const ModelFileHeader *model_file_header, bool check_sum);
};

Status FileSaver::SaveToFile(const string &file_path, ModelFileHeader &model_file_header,
                             ModelPartitionTable &model_partition_table,
                             const std::vector<ModelPartition> &partition_datas) {
  return SUCCESS;
}

Status FileSaver::SaveToFileWithEncrypt(const std::string file_path, const ProcParam proc_param,
                                        const ModelFileHeader *model_file_header, bool check_sum) {
  return SUCCESS;
}

class ModelSaver : public FileSaver {};

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OpUtils::DestroyTensorDescriptor(
    ccTensorDescriptor_t &cc_tensor) {
  if (nullptr != cc_tensor) {
    ccStatus_t ret = ccDestroyTensorDescriptor(&cc_tensor);
    GE_LOGE_IF(CC_STATUS_SUCCESS != ret, "ccDestroyTensorDescriptor failed. ret = %d", ret);
    cc_tensor = nullptr;
  }
}
}  // namespace ge

namespace domi {
class OpRegistrationData {};

class OpRegistry {
 public:
  static OpRegistry *Instance();
  std::vector<OpRegistrationData> registration_datas;
  ImplyType GetImplyType(const std::string &op_type);
  void GetOpTypeByImplyType(std::vector<std::string> &vec_op_type, const ImplyType &imply_type);
};

OpRegistry *OpRegistry::Instance() {
  static OpRegistry instance;
  return &instance;
}

void OpRegistry::GetOpTypeByImplyType(std::vector<std::string> &vec_op_type, const ImplyType &imply_type) {
  if (imply_type == ImplyType::AI_CPU) {
    vec_op_type.push_back("square");
  }
}

class OpRegistrationTbe {
 public:
  static OpRegistrationTbe *Instance();
  bool Finalize(OpRegistrationData &reg_data, bool is_train);
};

OpRegistrationTbe *OpRegistrationTbe::Instance() {
  static OpRegistrationTbe instance;
  return &instance;
}

bool OpRegistrationTbe::Finalize(OpRegistrationData &reg_data, bool is_train) { return true; }
}  // namespace domi

namespace ge {
class GraphPrepare {
 private:
  Status OptimizeForPreprocess(ge::ComputeGraphPtr &compute_graph);
};

Status GraphPrepare::OptimizeForPreprocess(ge::ComputeGraphPtr &compute_graph) { return SUCCESS; }
}  // namespace ge

namespace ge {
Status GetOriginalType(const ge::NodePtr &node, string &type) {
  type = node->GetType();
  GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  ge::AttrUtils::GetStr(node->GetOpDesc(), "original_type", type);
  return SUCCESS;
}

Status SetCycleEvent(const ge::NodePtr &node) { return SUCCESS; }

Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = AttrUtils::CloneOpDesc(node->GetOpDesc());
  GE_CHECK_NOTNULL(tmp_desc);
  if (!AttrUtils::SetStr(tmp_desc, "_stream_label", label)) {
    GELOGE(ge::FAILED, "Op :%s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetActiveLabelList(const ge::NodePtr &node, const std::vector<std::string> &label) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  // add list of active_label
  if (!AttrUtils::SetListStr(tmp_desc, "_active_label", label)) {
    GELOGE(ge::FAILED, "Op: %s set ATTR_NAME_ACTIVE_LABEL_LIST failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetSwitchBranchNodeLabel(const ge::NodePtr &node, const std::string &branch_label) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  // add branch_label of switch
  if (!AttrUtils::SetStr(tmp_desc, "_switch_branch_node_label", branch_label)) {
    GELOGE(ge::FAILED, "Op :%s set ATTR_NAME_SWITCH_BRANCH_NODE_LABEL failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetSwitchTrueBranchFlag(const ge::NodePtr &node, bool value) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  // add switch_true_branch_flag
  if (!AttrUtils::SetBool(tmp_desc, "_switch_true_branch_flag", value)) {
    GELOGE(ge::FAILED, "Op :%s set ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  // record original_node_name
  if (!AttrUtils::SetStr(tmp_desc, "_original_node_name", orig_name)) {
    GELOGE(ge::FAILED, "Op :%s set ATTR_NAME_ORIG_NODE_NAME failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetCyclicDependenceFlag(const ge::NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  // add cyclic_dependence_flag
  if (!AttrUtils::SetBool(tmp_desc, "_cyclic_dependence_flag", true)) {
    GELOGE(ge::FAILED, "Op :%s set ATTR_NAME_CYCLIC_DEPENDENCE_FLAG failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
  GE_CHECK_NOTNULL(node);
  OpDescPtr tmp_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(tmp_desc);
  if (!AttrUtils::SetStr(tmp_desc, "_next_iteration_node", next)) {
    GELOGE(ge::FAILED, "Op: %s set ATTR_NAME_NEXT_ITERATION failed", node->GetName().c_str());
    return FAILED;
  }
  return SUCCESS;
}
}  // namespace ge

namespace cce {
bool ccGetFuncState(ccFuncParamType_t type) { return true; }
}  // namespace cce

namespace ge {
Status UnloadModel(uint32_t model_id) { return SUCCESS; }

Status GetInputOutputDescInfo(uint32_t model_id, vector<InputOutputDescInfo> &input_desc,
                              vector<InputOutputDescInfo> &output_desc) {
  return SUCCESS;
}

Status DataInput(const InputData *input_data, OutputData *output_data) { return SUCCESS; }

/*
class ModelManager {
 public:
  static std::shared_ptr<ModelManager> GetInstance();
  static void FinalizeForPtr(ModelManager *) {}
  Status DataInputTensor(uint32_t model_id, const std::vector<ge::TensorInfo> &inputs,
                         std::vector<ge::TensorInfo> &outputs);
  Status DataInput(const InputData &input_data, OutputData &output_data);
  Status GetInputOutputDescInfo(const uint32_t model_id, std::vector<InputOutputDescInfo> &input_desc,
                                std::vector<InputOutputDescInfo> &output_desc);
  Status GetInputOutputDescInfo(const uint32_t model_id, std::vector<InputOutputDescInfo> &input_desc,
                                std::vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &input_formats,
                                std::vector<uint32_t> &output_formats);
  Status GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, std::vector<InputOutputDescInfo> &input_desc,
                                           std::vector<InputOutputDescInfo> &output_desc,
                                           std::vector<uint32_t> &input_formats, std::vector<uint32_t> &output_formats);
  Status Stop(uint32_t model_id);
  Status Unload(uint32_t model_id);
  Status LoadModelOnline(uint32_t &model_id, std::shared_ptr<ge::Model> &model,
                         std::shared_ptr<ModelListener> listener);
  Status Start(uint32_t model_id);
  Status GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_size);
  Status LoadModelOffline(uint32_t &model_id, const ModelData &model, std::shared_ptr<ModelListener> listener = nullptr,
                          void *dev_ptr = nullptr, size_t mem_size = 0, void *weight_ptr = nullptr,
                          size_t weight_size = 0);
  Status LoadModelWithQ(uint32_t &model_id, const ModelData &model_data, const std::vector<uint32_t> &input_queue_ids,
                        const std::vector<uint32_t> &output_queue_ids);
  Status HandleCommand(const Command &command);
  Status ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data,
                      OutputData &output_data);
  void DestroyAicpuSession(uint64_t session_id);
};

void ModelManager::DestroyAicpuSession(uint64_t session_id) {}

std::shared_ptr<ModelManager> ModelManager::GetInstance() {
  static std::shared_ptr<ModelManager> instance_ptr =
      shared_ptr<ModelManager>(new ModelManager(), ModelManager::FinalizeForPtr);
  return instance_ptr;
}

Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<ge::TensorInfo> &inputs,
                                     std::vector<ge::TensorInfo> &outputs) {
  return SUCCESS;
}

Status ModelManager::DataInput(const InputData &input_data, OutputData &output_data) { return SUCCESS; }

Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, std::vector<InputOutputDescInfo> &input_desc,
                                            std::vector<InputOutputDescInfo> &output_desc,
                                            std::vector<uint32_t> &input_formats,
                                            std::vector<uint32_t> &output_formats) {
  return SUCCESS;
}

Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, std::vector<InputOutputDescInfo> &input_desc,
                                            std::vector<InputOutputDescInfo> &output_desc) {
  return SUCCESS;
}

Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id,
                                                       std::vector<InputOutputDescInfo> &input_desc,
                                                       std::vector<InputOutputDescInfo> &output_desc,
                                                       std::vector<uint32_t> &input_formats,
                                                       std::vector<uint32_t> &output_formats) {
  return SUCCESS;
}

Status ModelManager::Stop(uint32_t model_id) { return SUCCESS; }

Status ModelManager::Unload(uint32_t model_id) { return SUCCESS; }

Status ModelManager::LoadModelOnline(uint32_t &model_id, std::shared_ptr<ge::Model> &model,
                                     std::shared_ptr<ModelListener> listener) {
  return SUCCESS;
}

Status ModelManager::Start(uint32_t model_id) { return SUCCESS; }

Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_size) { return SUCCESS; }

Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr<ModelListener> listener,
                                      void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) {
  return SUCCESS;
}

Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_data,
                                    const std::vector<uint32_t> &input_queue_ids,
                                    const std::vector<uint32_t> &output_queue_ids) {
  return SUCCESS;
}

Status ModelManager::HandleCommand(const Command &command) { return SUCCESS; }

Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data,
                                  OutputData &output_data) {
  return SUCCESS;
}
*/
}  // namespace ge

namespace ge {
enum JobState {
  JOBSTATE_WAITING = 1,
  JOBSTATE_RUNNING,
  JOBSTATE_KILLING,
  JOBSTATE_SUCCEED,
  JOBSTATE_FAILED,
  JOBSTATE_KILLED,
  JOBSTATE_UNKOWN
};

enum JobSubState {
  JOBSUBSTATE_ENV_INIT = 201,
  JOBSUBSTATE_ENV_FIN,
  JOBSUBSTATE_RESOUCE_ALLOC,
  JOBSUBSTATE_MODEL_COMPILE,
  JOBSUBSTATE_GRAPH_PREPARE,
  JOBSUBSTATE_GRAPH_SPLIT,
  JOBSUBSTATE_GRAPH_OPTIMIZE,
  JOBSUBSTATE_GRAPH_BUILD,
  JOBSUBSTATE_GRAPH_LOAD,
  JOBSUBSTATE_GRAPH_EXEC,
  JOBSUBSTATE_GRAPH_UNLOAD,
  JOBSUBSTATE_OTHER
};

enum ErrorModule {
  ERROR_MODULE_DRIVER = 0x01,
  ERROR_MODULE_RUNTIME = 0x04,
  ERROR_MODULE_CCE = 0x06,
  ERROR_MODULE_FMK = 0x08,
  ERROR_MODULE_HCCL = 0x12
};
class CsaInteract {
 public:
  // singleton accessor
  static CsaInteract &GetInstance();
  void WriteErrorCode(uint32_t module_ret_errcode, ErrorModule error_module, JobSubState job_sub_state);
  void Init(int32_t dev_index, int64_t job_id);
  Status WriteJobState(JobState job_state, JobSubState job_sub_state = JOBSUBSTATE_OTHER,
                       uint32_t module_ret_errcode = SUCCESS, ErrorModule error_module = ERROR_MODULE_FMK);

  // device index
  int32_t dev_index_;
  // job id
  int64_t job_id_;
  // is initialization complete
  bool is_init_;
  // current job state
  JobState curr_state_;
  // job state file
  std::string job_state_file_;
  // network connectivity detect file
  std::string hcom_detect_file_;
  // identification of internal errors that occurred during the training
  bool is_have_internal_error_;
};

CsaInteract &CsaInteract::GetInstance() {
  static CsaInteract instance;
  return instance;
}

void CsaInteract::Init(int32_t dev_index, int64_t job_id) {
  if (!is_init_) {
    dev_index_ = dev_index;
    job_id_ = job_id;
    string csa_path_prefix;
    if (std::getenv(FMK_STATUS_FILE_DIR_ENV) != nullptr) {
      csa_path_prefix = std::getenv(FMK_STATUS_FILE_DIR_ENV);
    }
    if (!csa_path_prefix.empty()) {
      std::string job_state_file = csa_path_prefix + std::to_string(dev_index_) + FILE_SEPARATE + JOBSTATE_FILE_NAME;
      std::string hcom_detect_file =
          csa_path_prefix + std::to_string(dev_index_) + FILE_SEPARATE + HCOM_DETECT_FILE_NAME;
      job_state_file_ = RealPath(job_state_file.c_str());
      hcom_detect_file_ = RealPath(hcom_detect_file.c_str());
    }
    is_init_ = true;
  }
}

void CsaInteract::WriteErrorCode(uint32_t module_ret_errcode, ErrorModule error_module, JobSubState job_sub_state) {}
}  // namespace ge
Status ModelParserBase::LoadFromFile(const char *model_path, const char *key, int32_t priority,
                                     ge::ModelData &model_data) {
  return SUCCESS;
}

Status CsaInteract::WriteJobState(JobState job_state, JobSubState job_sub_state, uint32_t module_ret_errcode,
                                  ErrorModule error_module) {
  return SUCCESS;
}

namespace ge {
static std::map<ge::DataType, uint32_t> data_type_to_length = {
    {DT_BOOL, sizeof(bool)},     {DT_INT64, sizeof(int64_t)},  {DT_UINT64, sizeof(int64_t)},
    {DT_FLOAT, sizeof(float)},   {DT_INT32, sizeof(int32_t)},  {DT_UINT32, sizeof(int32_t)},
    {DT_INT8, sizeof(char)},     {DT_UINT8, sizeof(char)},     {DT_INT16, sizeof(int16_t)},
    {DT_UINT16, sizeof(int16_t)}, {DT_FLOAT16, sizeof(int16_t)}, {DT_DOUBLE, sizeof(double)},
};

class TypeUtils {
 public:
  static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length);
  static bool CheckUint64MulOverflow(uint64_t a, uint32_t b);
};

bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) {
  auto it = data_type_to_length.find(data_type);
  if (it != data_type_to_length.end()) {
    length = it->second;
    return true;
  } else {
    return false;
  }
}

bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) {
  // Not overflow
  if (a == 0) {
    return false;
  }
  if ((ULLONG_MAX / a) >= b) {
    return false;
  }
  return true;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module (ME) and the underlying hardware, acting as the bridge between the two. GE takes the graph issued by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE is tuned to the hardware architecture of the Ascend AI processor so that the processor's compute power can be fully exploited. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.
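The file shown above, omg_stub.cc, is one of the unit-test stubs behind this module: it fakes GE internals such as GeModel, OmFileSaveHelper and CsaInteract so tests can link without the full implementation. As a rough illustration of how such a stub gets exercised, the sketch below drives the GeModel/GeModelPartition round trip defined in the file. It is a hypothetical test fragment, not part of the stub itself; it assumes it is built against declarations matching the stub above, and the ge::MODEL_DEF and ge::SUCCESS qualifications are likewise assumptions rather than something this page documents.

// Hypothetical test fragment: exercises the stubbed GeModel/GeModelPartition
// defined in omg_stub.cc, assuming matching declarations are visible here.
#include <memory>
#include <vector>

void ExerciseGeModelStub() {
  auto model = std::make_shared<ge::Model>();
  ge::GeModel ge_model(model);

  std::vector<uint8_t> blob(16, 0x5A);  // fake partition payload
  // AddPartition rejects empty or null input, copies the buffer into a new
  // GeModelPartition, and stores it keyed by partition type.
  if (ge_model.AddPartition(blob.data(), blob.size(), ge::MODEL_DEF) != ge::SUCCESS) {
    return;
  }
  // A second partition of the same type would be rejected by the stub.

  ge::GeModelPartitionPtr partition = nullptr;
  if (ge_model.GetPartition(ge::MODEL_DEF, partition) == ge::SUCCESS && partition != nullptr) {
    // SetData made its own copy, so partition->size_ equals blob.size().
  }
}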