You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

davinci_model.h 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
  17. #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_
  18. #include <map>
  19. #include <memory>
  20. #include <set>
  21. #include <string>
  22. #include <thread>
  23. #include <vector>
  24. #include "common/ge_types.h"
  25. #include "common/helper/model_helper.h"
  26. #include "common/helper/om_file_helper.h"
  27. #include "common/opskernel/ge_task_info.h"
  28. #include "common/properties_manager.h"
  29. #include "common/types.h"
  30. #include "framework/common/util.h"
  31. #include "graph/debug/ge_attr_define.h"
  32. #include "graph/load/new_model_manager/aipp_utils.h"
  33. #include "graph/load/new_model_manager/data_dumper.h"
  34. #include "graph/load/new_model_manager/data_inputer.h"
  35. #include "graph/load/new_model_manager/model_utils.h"
  36. #include "graph/load/new_model_manager/zero_copy_offset.h"
  37. #include "graph/load/new_model_manager/zero_copy_task.h"
  38. #include "graph/model.h"
  39. #include "graph/node.h"
  40. #include "graph/op_desc.h"
  41. #include "graph/operator.h"
  42. #include "graph/utils/attr_utils.h"
  43. #include "graph/utils/tensor_utils.h"
  44. #include "mmpa/mmpa_api.h"
  45. #include "proto/task.pb.h"
  46. #include "task_info/task_info.h"
  47. #include "graph/common/local_context.h"
  48. namespace ge {
  49. // op debug need 2048 bits buffer
  50. const size_t kOpDebugMemorySize = 2048UL;
  51. const size_t kDebugP2pSize = 8UL;
// Stages of a model's processing pipeline. A stage value is passed to
// SetProfileTime() to record a profiling timestamp for each phase
// (load, pre-process, inference, post-process).
// MODEL_PROC_INVALID is the sentinel for an unrecognized stage.
typedef enum tagModelProcStage {
  MODEL_LOAD_START = 1,    // model loading begins
  MODEL_LOAD_END,          // model loading finished
  MODEL_PRE_PROC_START,    // input pre-processing begins
  MODEL_PRE_PROC_END,      // input pre-processing finished
  MODEL_INFER_START,       // inference begins
  MODEL_INFER_END,         // inference finished
  MODEL_AFTER_PROC_START,  // output post-processing begins
  MODEL_AFTER_PROC_END,    // output post-processing finished
  MODEL_PROC_INVALID,      // sentinel: not a valid stage
} ModelProcStage;
// Per-run timing record for one model execution, filled in during profiling
// (consumed by SinkTimeProfile). The clock/unit of the int64_t timestamps is
// not specified in this header — presumably the same clock SetProfileTime
// uses; confirm at the call sites.
struct timeInfo {
  uint32_t modelId;            // id of the model these timestamps belong to
  int64_t processBeginTime;    // whole-request processing begin timestamp
  int64_t processEndTime;      // whole-request processing end timestamp
  int64_t inferenceBeginTime;  // inference phase begin timestamp
  int64_t inferenceEndTime;    // inference phase end timestamp
  int64_t dumpBeginTime;       // data-dump phase begin timestamp
  int64_t dumpEndTime;         // data-dump phase end timestamp
};
// Execution mode of a loaded model.
// NOTE(review): the names suggest INITIALIZATION = not yet executed,
// SYNCHRONIZATION = blocking execution, ASYNCHRONIZATION = non-blocking
// execution (cf. the async_mode flag of NnExecute) — confirm against users
// of this enum before relying on these semantics.
enum ExecuteMode {
  INITIALIZATION,
  SYNCHRONIZATION,
  ASYNCHRONIZATION,
};
// DavinciModel: loads, initializes and executes an offline model on the
// DaVinci device (streams/events/labels, queue binding, zero-copy I/O,
// profiling and dump support).
  78. class DavinciModel {
  79. public:
  80. ///
  81. /// @ingroup ge
  82. /// @brief DavinciModel constructor
  83. /// @author
  84. ///
  85. DavinciModel(int32_t priority, const std::shared_ptr<ModelListener> &listener);
  86. ///
  87. /// @ingroup ge
  88. /// @brief DavinciModel desctructor, free Parse and Init resources
  89. /// @author
  90. ///
  91. ~DavinciModel();
  92. ///
  93. /// @ingroup ge
  94. /// @brief apply model to model_def_
  95. ///
  96. Status Assign(const GeModelPtr &ge_model);
  97. ///
  98. /// @ingroup ge
  99. /// @brief DavinciModel initialization, including Stream, ccHandle, Event, DataInputer, etc
  100. /// @return execute result
  101. /// @author
  102. ///
  103. Status Init(void *dev_ptr = nullptr, size_t memsize = 0, void *weight_ptr = nullptr, size_t weightsize = 0);
  104. ///
  105. /// @ingroup ge
  106. /// @brief ACL case, Load task list with queue.
  107. /// @param [in] input_que_ids: input queue ids from user, nums equal Data Op.
  108. /// @param [in] output_que_ids: input queue ids from user, nums equal NetOutput Op.
  109. /// @return: 0 for success / others for fail
  110. ///
  111. Status SetQueIds(const std::vector<uint32_t> &input_queue_ids, const std::vector<uint32_t> &output_queue_ids);
  112. ///
  113. /// @ingroup ge
  114. /// @brief Get DataInputer
  115. /// @return model ID
  116. ///
  117. uint32_t Id() const { return model_id_; }
  118. ///
  119. /// @ingroup ge
  120. /// @brief Get DataInputer
  121. /// @return model ID
  122. ///
  123. void SetId(uint32_t model_id) { model_id_ = model_id; }
  124. static void *Run(DavinciModel *model_pointer);
  125. ///
  126. /// @ingroup ge
  127. /// @brief NnExecute
  128. /// @param [in] stream execute stream
  129. /// @param [in] async_mode is asynchronize mode.
  130. /// @param [in] input_data model input data
  131. /// @param [out] output_data model output data
  132. ///
  133. Status NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data);
  134. ///
  135. /// @ingroup ge
  136. /// @brief lock mutex run flag
  137. /// @author
  138. ///
  139. void LockRunFlg() { mux_run_flg_.lock(); }
  140. ///
  141. /// @ingroup ge
  142. /// @brief unlock mutex run flag
  143. /// @author
  144. ///
  145. void UnlockRunFlg() { mux_run_flg_.unlock(); }
  146. ///
  147. /// @ingroup ge
  148. /// @brief get DataInputer
  149. /// @return DataInputer pointer
  150. ///
  151. DataInputer *const GetDataInputer() const { return data_inputer_; }
  152. // get Stream number
  153. uint32_t StreamNum() const { return runtime_param_.stream_num; }
  154. // get Event number
  155. uint32_t EventNum() const { return runtime_param_.event_num; }
  156. // get Lable number
  157. uint32_t LabelNum() const { return runtime_param_.label_num; }
  158. // get batch number
  159. uint32_t BatchNum() const { return runtime_param_.batch_num; }
  160. // get session id
  161. uint64_t SessionId() const { return runtime_param_.session_id; }
  162. // get model priority
  163. int32_t Priority() const { return priority_; }
  164. // get total mem size
  165. size_t TotalMemSize() const { return runtime_param_.mem_size; }
  166. const std::map<uint32_t, MemInfo> &P2PMemInfos() const {return runtime_param_.memory_infos;}
  167. // model name
  168. string Name() const { return name_; }
  169. // om_name
  170. string OmName() const { return om_name_; }
  171. // version
  172. uint32_t Version() const { return version_; }
  173. // get total weights mem size
  174. size_t TotalWeightsMemSize() const { return runtime_param_.weight_size; }
  175. size_t TotalVarMemSize() const { return runtime_param_.var_size; }
  176. // get base memory address
  177. uint8_t *MemBase() { return mem_base_; }
  178. // get weight base memory address
  179. uint8_t *WeightsMemBase() { return weights_mem_base_; }
  180. uint8_t *VarMemBase() { return var_mem_base_; }
  181. // get Event list
  182. const vector<rtEvent_t> &GetEventList() const { return event_list_; }
  183. const vector<rtStream_t> &GetStreamList() const { return stream_list_; }
  184. const vector<rtLabel_t> &GetLabelList() const { return label_list_; }
  185. Status DestroyThread();
  186. // Get Data Op.
  187. const vector<OpDescPtr> &GetDataList() const { return data_op_list_; }
  188. // get Op
  189. const map<uint32_t, OpDescPtr> &GetOpList() const { return op_list_; }
  190. OpDescPtr GetOpByIndex(uint32_t index) const {
  191. if (op_list_.find(index) == op_list_.end()) {
  192. return nullptr;
  193. }
  194. return op_list_.at(index);
  195. }
  196. OpDescPtr GetVariableOp(const string &name) {
  197. for (auto op_desc : variable_op_list_) {
  198. if (op_desc != nullptr && op_desc->GetName() == name) {
  199. return op_desc;
  200. }
  201. }
  202. return nullptr;
  203. }
  204. // get task info for profiling
  205. const std::vector<TaskDescInfo> &GetTaskDescInfo() const { return task_desc_info_; }
  206. // get updated task info list
  207. std::vector<TaskInfoPtr> GetTaskList() { return task_list_; }
  208. ///
  209. /// @ingroup ge
  210. /// @brief get model input and output format
  211. /// @return ccTensorFormat_t current model input and output format
  212. ///
  213. Format GetFormat();
  214. rtModel_t GetRtModelHandle() const { return rt_model_handle_; }
  215. rtStream_t GetRtModelStream() const { return rt_model_stream_; }
  216. uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; }
  217. uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; }
  218. uint64_t GetRtVarAddr() const { return runtime_param_.logic_var_base; }
  219. uint32_t GetFlowctrlIndex(uint32_t op_index);
  220. void PushHcclStream(rtStream_t value);
  221. bool IsBroadCastOpData(const NodePtr &var_node);
  222. ///
  223. /// @ingroup ge
  224. /// @brief For TVM Op, avoid Addr Reuse.
  225. /// @return void*
  226. ///
  227. const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = "");
  228. ///
  229. /// @ingroup ge
  230. /// @brief get model input and output desc info
  231. /// @param [out] input_shape model input size
  232. /// @param [out] output_shape model output size
  233. /// @return execute result
  234. ///
  235. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc);
  236. Status GetInputOutputDescInfo(vector<InputOutputDescInfo> &input_desc, vector<InputOutputDescInfo> &output_desc,
  237. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  238. ///
  239. /// @ingroup ge
  240. /// @brief Get dynamic batch_info
  241. /// @param [out] batch_info
  242. /// @param [out] dynamic_type
  243. /// @return execute result
  244. ///
  245. Status GetDynamicBatchInfo(std::vector<std::vector<int64_t>> &batch_info, int32_t &dynamic_type) const;
  246. ///
  247. /// @ingroup ge
  248. /// @brief Get combined dynamic dims info
  249. /// @param [out] batch_info
  250. /// @return None
  251. ///
  252. void GetCombinedDynamicDims(std::vector<std::vector<int64_t>> &batch_info) const;
  253. void GetUserDesignateShapeOrder(std::vector<std::string> &user_input_shape_order) const;
  254. void GetCurShape(std::vector<int64_t> &batch_info, int32_t &dynamic_type);
  255. void GetModelAttr(std::vector<std::string> &dynamic_output_shape_info);
  256. ///
  257. /// @ingroup ge
  258. /// @brief Get AIPP input info
  259. /// @param [in] index
  260. /// @param [out] aipp_info
  261. /// @return execute result
  262. ///
  263. Status GetAIPPInfo(uint32_t index, AippConfigInfo &aipp_info);
  264. Status GetAippType(uint32_t index, InputAippType &type, size_t &aipp_index);
  265. ///
  266. /// @ingroup ge
  267. /// @brief Get model_id.
  268. /// @return model_id
  269. ///
  270. uint32_t GetModelId() const { return model_id_; }
  271. ///
  272. /// @ingroup ge
  273. /// @brief get unique identification for op when load two or more models
  274. /// @param [in] op_desc : current op.
  275. /// @param [in] string identification: unique identification for current op.
  276. /// @return None
  277. ///
  278. void GetUniqueId(const OpDescPtr &op_desc, std::string &unique_identification);
  279. ///
  280. /// @ingroup ge
  281. /// @brief get model input and output desc for zero copy
  282. /// @param [out] input_shape model input size
  283. /// @param [out] output_shape model output size
  284. /// @return execute result
  285. ///
  286. Status GetInputOutputDescInfoForZeroCopy(vector<InputOutputDescInfo> &input_desc,
  287. vector<InputOutputDescInfo> &output_desc,
  288. std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats);
  289. Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data);
  290. Status ReturnNoOutput(uint32_t data_id);
  291. Status ModelRunStart();
  292. ///
  293. /// @ingroup ge
  294. /// @brief stop run model
  295. /// @return Status
  296. ///
  297. Status ModelRunStop();
  298. ///
  299. /// @ingroup ge
  300. /// @brief model run flag
  301. /// @return Status
  302. ///
  303. bool RunFlag() const { return run_flg_; }
  304. Status GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, std::vector<uint32_t> &formats);
  305. ///
  306. /// @ingroup ge
  307. /// @brief Set Session Id
  308. /// @return void
  309. ///
  310. void SetSessionId(uint64_t session_id) { session_id_ = session_id; }
  311. ///
  312. /// @ingroup ge
  313. /// @brief Get Session Id
  314. /// @return sessionID
  315. ///
  316. uint64_t GetSessionId() const { return session_id_; }
  317. ///
  318. /// @ingroup ge
  319. /// @brief SetDeviceId
  320. /// @return void
  321. ///
  322. void SetDeviceId(uint32_t device_id) { device_id_ = device_id; }
  323. ///
  324. /// @ingroup ge
  325. /// @brief Get device Id
  326. /// @return device id
  327. ///
  328. uint32_t GetDeviceId() const { return device_id_; }
  329. bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; }
  330. Status UpdateSessionId(uint64_t session_id);
  331. const RuntimeParam &GetRuntimeParam() { return runtime_param_; }
  332. int32_t GetDataInputTid() const { return dataInputTid; }
  333. void SetDataInputTid(int32_t data_input_tid) { dataInputTid = data_input_tid; }
  334. void DisableZeroCopy(const void *addr);
  335. bool GetOpDugReg() const { return is_op_debug_reg_; }
  336. ///
  337. /// @ingroup ge
  338. /// @brief Save outside address of Data or NetOutput used info for ZeroCopy.
  339. /// @param [in] const OpDescPtr &op_desc: current op desc
  340. /// @param [in] const std::vector<void *> &outside_addrs: address of task
  341. /// @param [in] const void *args_offset: arguments address save the address.
  342. /// @return None.
  343. ///
  344. void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector<void *> &outside_addrs, const void *info, void *args,
  345. size_t size, size_t offset);
  346. void SetDynamicSize(const std::vector<uint64_t> &batch_num, int32_t dynamic_type);
  347. bool GetL1FusionEnableOption() { return is_l1_fusion_enable_; }
  348. void SetProfileTime(ModelProcStage stage, int64_t endTime = 0);
  349. int64_t GetLoadBeginTime() { return load_begin_time_; }
  350. int64_t GetLoadEndTime() { return load_end_time_; }
  351. Status SinkModelProfile();
  352. Status SinkTimeProfile(const InputData &current_data);
  353. Status ReportProfilingData(bool check_device = true);
  354. void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) {
  355. data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id);
  356. }
  357. void SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr<OpDesc> &op_desc, uintptr_t args) {
  358. data_dumper_.SaveDumpTask(task_id, stream_id, op_desc, args);
  359. }
  360. void SetEndGraphId(uint32_t task_id, uint32_t stream_id);
  361. DavinciModel &operator=(const DavinciModel &model) = delete;
  362. DavinciModel(const DavinciModel &model) = delete;
  363. const map<int64_t, std::vector<rtStream_t>> &GetHcclFolowStream() {
  364. return main_follow_stream_mapping_;
  365. }
  366. void SaveHcclFollowStream(int64_t main_stream_id, rtStream_t stream);
  367. void InitRuntimeParams();
  368. Status InitVariableMem();
  369. void UpdateMemBase(uint8_t *mem_base) {
  370. runtime_param_.mem_base = mem_base;
  371. mem_base_ = mem_base;
  372. }
  373. void SetTotalArgsSize(uint32_t args_size) { total_args_size_ += args_size; }
  374. uint32_t GetTotalArgsSize() { return total_args_size_; }
  375. void *GetCurrentArgsAddr(uint32_t offset) {
  376. void *cur_args = static_cast<char *>(args_) + offset;
  377. return cur_args;
  378. }
  379. void SetTotalIOAddrs(vector<void *> &io_addrs) {
  380. total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end());
  381. }
  382. void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size);
  383. int64_t GetFixedAddrsSize(string tensor_name);
  384. void *GetCurrentFixedAddr(int64_t offset) const {
  385. void *cur_addr = static_cast<char *>(fixed_addrs_) + offset;
  386. return cur_addr;
  387. }
  388. uint32_t GetFixedAddrOutputIndex(string tensor_name) {
  389. if (tensor_name_to_peer_output_index_.find(tensor_name) != tensor_name_to_peer_output_index_.end()) {
  390. return tensor_name_to_peer_output_index_[tensor_name];
  391. }
  392. return UINT32_MAX;
  393. }
  394. void SetKnownNode(bool known_node) { known_node_ = known_node; }
  395. bool IsKnownNode() { return known_node_; }
  396. Status MallocKnownArgs();
  397. Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs);
  398. Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs);
  399. Status UpdateKnownZeroCopyAddr();
  400. void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; }
  401. Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info);
  402. Status GetAllAippInputOutputDims(uint32_t index, std::vector<InputOutputDims> &input_dims,
  403. std::vector<InputOutputDims> &output_dims);
  404. void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; }
  405. // om file name
  406. void SetOmName(string om_name) { om_name_ = om_name; }
  407. void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); }
  408. const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); }
  409. bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const {
  410. return data_dumper_.GetOpDescInfo(stream_id, task_id, op_desc_info);
  411. }
  412. Status InitInputOutputForDynamic(const ComputeGraphPtr &compute_graph);
  413. private:
  414. // memory address of weights
  415. uint8_t *weights_mem_base_;
  416. uint8_t *var_mem_base_;
  417. // memory address of model
  418. uint8_t *mem_base_;
  419. uint8_t *p2p_mem_base_;
  420. bool is_inner_mem_base_;
  421. bool is_inner_weight_base_;
  422. bool is_inner_p2p_mem_base_;
  423. // input data manager
  424. DataInputer *data_inputer_;
  425. int64_t load_begin_time_;
  426. int64_t load_end_time_;
  427. struct timeInfo time_info_;
  428. int32_t dataInputTid;
  429. ///
  430. /// @ingroup ge
  431. /// @brief Save Batch label Info.
  432. /// @param [in] const OpDescPtr &op_desc
  433. /// @param [in] uintptr_t addr: address value in args block.
  434. /// @return None.
  435. ///
  436. void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr);
  437. ///
  438. /// @ingroup ge
  439. /// @brief Copy Check input size and model op size.
  440. /// @param [in] const int64_t &input_size: input size.
  441. /// @param [in] const int64_t &op_size: model op size.
  442. /// @param [in] is_dynamic: dynamic batch input flag.
  443. /// @return true if success
  444. ///
  445. bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic);
  446. ///
  447. /// @ingroup ge
  448. /// @brief Set copy only for No task feed NetOutput address.
  449. /// @return None.
  450. ///
  451. void SetCopyOnlyOutput();
  452. ///
  453. /// @ingroup ge
  454. /// @brief Copy Input/Output to model for direct use.
  455. /// @param [in] const InputData &input_data: user input data info.
  456. /// @param [in/out] OutputData &output_data: user output data info.
  457. /// @param [in] bool is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  458. /// @return SUCCESS handle successfully / others handle failed
  459. ///
  460. Status CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic);
  461. ///
  462. /// @ingroup ge
  463. /// @brief Copy Data addr to model for direct use.
  464. /// @param [in] data_info: model memory addr/size map { data_index, { tensor_size, tensor_addr } }.
  465. /// @param [in] is_input: input data or output data
  466. /// @param [in] blobs: user input/output data list.
  467. /// @param [in] is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input
  468. /// @param [in] batch_label: batch label for multi-batch scenes
  469. /// @return SUCCESS handle successfully / others handle failed
  470. ///
  471. Status UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> &data_info, bool is_input,
  472. const vector<DataBuffer> &blobs, bool is_dynamic, const string &batch_label);
  473. Status CopyInputData(const InputData &input_data, bool device_data = false);
  474. Status CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind);
  475. Status SyncVarData();
  476. Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize);
  477. void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input);
  478. void SetInputDimsInfo(const vector<int64_t> &model_input_dims, Format &format, InputOutputDescInfo &input);
  479. Status GetInputDescInfo(vector<InputOutputDescInfo> &input_desc, std::vector<uint32_t> &formats);
  480. Status InitTaskInfo(domi::ModelTaskDef &modelTaskInfo);
  481. void UnbindHcomStream();
  482. Status DistributeTask();
  483. uint8_t *MallocFeatureMapMem(size_t data_size);
  484. uint8_t *MallocWeightsMem(size_t weights_size);
  485. uint8_t* MallocP2PMem(size_t p2p_data_size);
  486. void FreeFeatureMapMem();
  487. void FreeWeightsMem();
  488. void FreeP2PMem();
  489. void ReleaseTask();
  490. void UnbindTaskSinkStream();
  491. bool IsAicpuKernelConnectSpecifiedLayer();
  492. ///
  493. /// @ingroup ge
  494. /// @brief Reduce memory usage after task sink.
  495. /// @return: void
  496. ///
  497. void Shrink();
  498. ///
  499. /// @ingroup ge
  500. /// @brief Travel all nodes and do some init.
  501. /// @param [in] compute_graph: ComputeGraph to load.
  502. /// @return Status
  503. ///
  504. Status InitNodes(const ComputeGraphPtr &compute_graph);
  505. ///
  506. /// @ingroup ge
  507. /// @brief Data Op Initialize.
  508. /// @param [in] NodePtr: Data Op.
  509. /// @param [in/out] data_op_index: NetOutput addr size info.
  510. /// @return Status
  511. ///
  512. Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, map<uint32_t, OpDescPtr> &data_by_index);
  513. ///
  514. /// @ingroup ge
  515. /// @brief Sort Data op list by index.
  516. /// @param [in] data_by_index: map of Data Op.
  517. /// @return
  518. ///
  519. void AdjustDataOpList(const map<uint32_t, OpDescPtr> &data_by_index);
  520. ///
  521. /// @ingroup ge
  522. /// @brief input zero copy node Initialize.
  523. /// @param [in] NodePtr: Data Op.
  524. /// @return Status
  525. ///
  526. Status InitInputZeroCopyNodes(const NodePtr &node);
  527. ///
  528. /// @ingroup ge
  529. /// @brief NetOutput Op Initialize.
  530. /// @param [in] NodePtr: NetOutput Op.
  531. /// @return Status
  532. ///
  533. Status InitNetOutput(const NodePtr &node);
  534. ///
  535. /// @ingroup ge
  536. /// @brief output zero copy node Initialize.
  537. /// @param [in] NodePtr: Data Op.
  538. /// @return Status
  539. ///
  540. Status InitOutputZeroCopyNodes(const NodePtr &node);
  541. ///
  542. /// @ingroup ge
  543. /// @brief input zero copy node Initialize for Case.
  544. /// @param [in] NodePtr: Data Op.
  545. /// @return Status
  546. ///
  547. Status InitInputBatchLabel(const NodePtr &node);
  548. ///
  549. /// @ingroup ge
  550. /// @brief output zero copy node Initialize for Case.
  551. /// @param [in] NodePtr: netoutput Op.
  552. /// @return Status
  553. ///
  554. Status InitOutputBatchLabel(const NodePtr &node);
  555. ///
  556. /// @ingroup ge
  557. /// @brief Constant Op Init.
  558. /// @return Status
  559. ///
  560. Status InitConstant(const OpDescPtr &op_desc);
  561. Status InitVariable(const OpDescPtr &op_desc);
  562. /// @ingroup ge
  563. /// @brief LabelSet Op Initialize.
  564. /// @param [in] op_desc: LabelSet Op descriptor.
  565. /// @return Status
  566. Status InitLabelSet(const OpDescPtr &op_desc);
  567. Status InitStreamSwitch(const OpDescPtr &op_desc);
  568. Status InitStreamActive(const OpDescPtr &op_desc);
  569. Status InitStreamSwitchN(const OpDescPtr &op_desc);
  570. ///
  571. /// @ingroup ge
  572. /// @brief Case Op Init.
  573. /// @return Status
  574. ///
  575. Status InitCase(const OpDescPtr &op_desc);
  576. Status SetDynamicBatchInfo(const OpDescPtr &op_desc, uint32_t batch_num);
  577. ///
  578. /// @ingroup ge
  579. /// @brief TVM Op Init.
  580. /// @return Status
  581. ///
  582. Status InitTbeHandle(const OpDescPtr &op_desc);
  583. void StoreTbeHandle(const std::string &handle_key);
  584. void CleanTbeHandle();
  585. ///
  586. /// @ingroup ge
  587. /// @brief Make active stream list and bind to model.
  588. /// @return: 0 for success / others for fail
  589. ///
  590. Status BindModelStream();
  591. ///
  592. /// @ingroup ge
  593. /// @brief Init model stream for NN model.
  594. /// @return Status
  595. ///
  596. Status InitModelStream(rtStream_t stream);
  597. ///
  598. /// @ingroup ge
  599. /// @brief ACL, Load task list with queue entrance.
  600. /// @return: 0 for success / others for fail
  601. ///
  602. Status LoadWithQueue();
  603. ///
  604. /// @ingroup ge
  605. /// @brief ACL, Bind Data Op addr to input queue.
  606. /// @return: 0 for success / others for fail
  607. ///
  608. Status BindInputQueue();
  609. Status CpuTaskModelZeroCopy(std::vector<uintptr_t> &mbuf_list, std::map<const void *, ZeroCopyOffset> &outside_addrs);
  610. ///
  611. /// @ingroup ge
  612. /// @brief ACL, Bind NetOutput Op addr to output queue.
  613. /// @return: 0 for success / others for fail
  614. ///
  615. Status BindOutputQueue();
  616. Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size);
  617. ///
  618. /// @ingroup ge
  619. /// @brief definiteness queue schedule, bind input queue to task.
  620. /// @param [in] queue_id: input queue id from user.
  621. /// @param [in] addr: Data Op output tensor address.
  622. /// @param [in] size: Data Op output tensor size.
  623. /// @return: 0 for success / others for fail
  624. ///
  625. Status CpuModelDequeue(uint32_t queue_id);
  626. ///
  627. /// @ingroup ge
  628. /// @brief definiteness queue schedule, bind output queue to task.
  629. /// @param [in] queue_id: output queue id from user.
  630. /// @param [in] addr: NetOutput Op input tensor address.
  631. /// @param [in] size: NetOutput Op input tensor size.
  632. /// @return: 0 for success / others for fail
  633. ///
  634. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t size);
  635. ///
  636. /// @ingroup ge
  637. /// @brief definiteness queue schedule, active original model stream.
  638. /// @return: 0 for success / others for fail
  639. ///
  640. Status CpuActiveStream();
  641. ///
  642. /// @ingroup ge
  643. /// @brief definiteness queue schedule, wait for end graph.
  644. /// @return: 0 for success / others for fail
  645. ///
  646. Status CpuWaitEndGraph();
  647. Status BindEnqueue();
  648. Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf);
  649. ///
  650. /// @ingroup ge
  651. /// @brief definiteness queue schedule, repeat run model.
  652. /// @return: 0 for success / others for fail
  653. ///
  654. Status CpuModelRepeat();
  655. Status InitEntryTask();
  656. Status AddHeadStream();
  657. ///
  658. /// @ingroup ge
  659. /// @brief set ts device.
  660. /// @return: 0 for success / others for fail
  661. ///
  662. Status SetTSDevice();
  663. Status OpDebugRegister();
  664. void OpDebugUnRegister();
  665. void CheckHasHcomOp();
  666. Status DoTaskSink();
  667. void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result);
  668. Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id);
  669. // get desc info of graph for profiling
  670. Status GetComputeGraphInfo(vector<ComputeGraphDescInfo> &graph_desc_info);
  671. void SetDataDumperArgs(const ComputeGraphPtr &compute_graph);
  672. Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data,
  673. std::vector<ge::OutputTensorInfo> &outputs);
  674. void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info);
  675. void SetLabelForDynamic(const NodePtr &node);
  676. void ParseDynamicOutShape(const std::vector<std::string> &str_info, std::vector<vector<int64_t>> &vec_info);
  677. bool IsGetNextSinkDynamic(const OpDescPtr &op_desc);
  678. void GetAllGearsInfo(const NodePtr &node);
  679. Status GetGetDynamicDimsNodeInfo(const NodePtr &node);
  680. Status GetGearAndRealOutSizeInfo(size_t input_count, const NodePtr &node);
  681. Status GetRealOutputSizeOfMerge(size_t input_index, const NodePtr &merge_node);
  682. Status GetGearAndRealOutShapeInfo(size_t input_count, const OpDescPtr &op_desc);
  683. bool is_model_has_inited_;
  684. uint32_t model_id_;
  685. uint32_t runtime_model_id_;
  686. string name_;
  687. // used for inference data dump
  688. string om_name_;
  689. uint32_t version_;
  690. GeModelPtr ge_model_;
  691. bool need_destroy_aicpu_kernel_{false};
  692. vector<std::string> out_node_name_;
  693. map<uint32_t, OpDescPtr> op_list_;
  694. // data op_desc
  695. vector<OpDescPtr> data_op_list_;
  696. vector<OpDescPtr> output_op_list_;
  697. vector<OpDescPtr> variable_op_list_;
  698. std::map<uint32_t, ZeroCopyOffset> new_input_data_info_;
  699. std::map<uint32_t, ZeroCopyOffset> new_output_data_info_;
  700. std::map<const void *, ZeroCopyOffset> new_input_outside_addrs_;
  701. std::map<const void *, ZeroCopyOffset> new_output_outside_addrs_;
  702. std::set<const void *> real_virtual_addrs_;
  703. // output op: save cce op actual needed memory size
  704. vector<int64_t> output_memory_size_list_;
  705. std::thread thread_id_;
  706. std::shared_ptr<ModelListener> listener_;
  707. bool run_flg_;
  708. std::mutex mux_run_flg_;
  709. int32_t priority_;
  710. vector<rtStream_t> stream_list_;
  711. std::mutex all_hccl_stream_list_mutex_;
  712. vector<rtStream_t> all_hccl_stream_list_;
  713. // for reuse hccl_follow_stream
  714. std::mutex capacity_of_stream_mutex_;
  715. std::map<int64_t, std::vector<rtStream_t>> main_follow_stream_mapping_;
  716. vector<rtEvent_t> event_list_;
  717. vector<rtLabel_t> label_list_;
  718. set<uint32_t> label_id_indication_;
  719. std::mutex outside_addrs_mutex_;
  720. std::vector<ZeroCopyTask> zero_copy_tasks_; // Task used Data or NetOutput addr.
  721. std::set<const void *> copy_only_addrs_; // Address need copy to original place.
  722. // {op_id, batch_label}
  723. std::map<int64_t, std::string> zero_copy_op_id_batch_label_;
  724. // {batch_label, addrs}
  725. std::map<std::string, std::set<uintptr_t>> zero_copy_batch_label_addrs_;
  726. std::vector<TaskInfoPtr> task_list_;
  727. // rt_moodel_handle
  728. rtModel_t rt_model_handle_;
  729. rtStream_t rt_model_stream_;
  730. bool is_inner_model_stream_;
  731. bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_.
  732. ExecuteMode last_execute_mode_;
  733. bool is_stream_list_bind_{false};
  734. bool is_pure_head_stream_{false};
  735. rtStream_t rt_head_stream_{nullptr};
  736. rtStream_t rt_entry_stream_{nullptr};
  737. rtAicpuDeployType_t deploy_type_{AICPU_DEPLOY_RESERVED};
  738. // ACL queue schedule, save queue ids for Init.
  739. std::vector<TaskInfoPtr> cpu_task_list_;
  740. std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller.
  741. std::vector<uint32_t> output_queue_ids_; // output queue ids created by caller.
  742. std::vector<uintptr_t> input_mbuf_list_; // input mbuf created by dequeue task.
  743. std::vector<uintptr_t> output_mbuf_list_; // output mbuf created by dequeue task.
  744. uint64_t session_id_;
  745. uint32_t device_id_;
  746. std::mutex flowctrl_op_index_internal_map_mutex_;
  747. std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_;
  748. std::vector<rtStream_t> active_stream_list_;
  749. std::set<uint32_t> active_stream_indication_;
  750. std::set<uint32_t> hcom_streams_;
  751. RuntimeParam runtime_param_;
  752. static std::mutex tvm_bin_mutex_;
  753. std::set<std::string> tvm_bin_kernel_;
  754. std::map<std::string, uint32_t> used_tbe_handle_map_;
  755. // for profiling task and graph info
  756. std::vector<TaskDescInfo> task_desc_info_;
  757. int64_t maxDumpOpNum_;
  758. // for data dump
  759. DataDumper data_dumper_;
  760. uint64_t iterator_count_;
  761. bool is_l1_fusion_enable_;
  762. std::map<OpDescPtr, void *> saved_task_addrs_;
  763. void *l1_fusion_addr_ = nullptr;
  764. bool known_node_ = false;
  765. uint32_t total_args_size_ = 0;
  766. void *args_ = nullptr;
  767. void *args_host_ = nullptr;
  768. void *fixed_addrs_ = nullptr;
  769. int64_t total_fixed_addr_size_ = 0;
  770. std::map<const void *, void *> knonw_input_data_info_;
  771. std::map<const void *, void *> knonw_output_data_info_;
  772. vector<void *> total_io_addrs_;
  773. vector<void *> orig_total_io_addrs_;
  774. bool base_addr_not_changed_ = false;
  775. vector<vector<int64_t>> batch_info_;
  776. std::vector<std::vector<int64_t>> combined_batch_info_;
  777. vector<string> user_designate_shape_order_;
  778. int32_t dynamic_type_ = 0;
  779. bool is_dynamic_ = false;
  780. vector<uint64_t> batch_size_;
  781. // key: input tensor name, generally rts op;
  782. // value: the fixed addr of input anchor, same as the peer output anchor addr of the peer op
  783. std::map<string, int64_t> tensor_name_to_fixed_addr_size_;
  784. // key: input tensor name, generally rts op; value: the peer output anchor of the peer op
  785. std::map<string, int64_t> tensor_name_to_peer_output_index_;
  786. // if model is first execute
  787. bool is_first_execute_;
  788. // for op debug
  789. std::mutex debug_reg_mutex_;
  790. bool is_op_debug_reg_ = false;
  791. void *op_debug_addr_ = nullptr;
  792. void *p2p_debug_addr_ = nullptr;
  793. bool is_new_model_desc_{false};
  794. bool is_online_infer_dynamic_ = false;
  795. bool is_getnext_sink_dynamic_ = false;
  796. std::vector<int64_t> cur_dynamic_dims_;
  797. void *netoutput_last_input_addr_ = nullptr;
  798. int64_t netoutput_last_input_size_ = 0;
  799. size_t shape_of_cur_dynamic_dims_ = 0;
  800. // key: input_index: input is merge node; value: each gear info and each output size
  801. std::map<size_t, std::map<vector<int64_t>, int64_t>> merge_nodes_gear_and_real_out_size_info_;
  802. // key: input_index: input is merge node; value: each gear info and each output shape
  803. std::map<size_t, std::map<vector<int64_t>, vector<int64_t>>> merge_nodes_gear_and_real_out_shape_info_;
  804. std::vector<std::vector<int64_t>> all_gears_info_;
  805. };
  806. } // namespace ge
  807. #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。