You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "cce/taskdown_api.h"
  24. namespace ge {
  25. namespace model_runner {
  26. enum TaskInfoType {
  27. CCE = 0,
  28. TBE,
  29. AICPU,
  30. LABEL_SET,
  31. LABEL_SWITCH,
  32. LABEL_GOTO,
  33. EVENT_RECORD,
  34. EVENT_WAIT,
  35. FUSION_START,
  36. FUSION_END,
  37. HCCL,
  38. PROFILER_TRACE,
  39. MEMCPY_ASYNC,
  40. STREAM_SWITCH,
  41. STREAM_ACTIVE,
  42. // Insert new task type here
  43. REVSERVED = 23
  44. };
  45. class TaskInfo {
  46. public:
  47. virtual ~TaskInfo() {}
  48. uint32_t stream_id() const {
  49. return stream_id_;
  50. }
  51. TaskInfoType type() const {
  52. return type_;
  53. }
  54. std::string op_name() const {
  55. return op_name_;
  56. }
  57. bool dump_flag() const {
  58. return dump_flag_;
  59. }
  60. protected:
  61. TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
  62. : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
  63. private:
  64. std::string op_name_;
  65. uint32_t stream_id_;
  66. TaskInfoType type_;
  67. bool dump_flag_;
  68. };
  69. class CceTaskInfo : public TaskInfo {
  70. public:
  71. CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
  72. uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
  73. const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
  74. const std::vector<uint8_t> &args_offset, bool is_flowtable)
  75. : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
  76. ctx_(ctx),
  77. stub_func_(stub_func),
  78. block_dim_(block_dim),
  79. args_(args),
  80. args_size_(args_size),
  81. sm_desc_(sm_desc),
  82. flow_table_(flow_table),
  83. args_offset_(args_offset),
  84. is_flowtable_(is_flowtable) {}
  85. ~CceTaskInfo() override {}
  86. cce::ccOpContext cc_context() const {
  87. return ctx_;
  88. }
  89. std::string stub_func() const {
  90. return stub_func_;
  91. }
  92. uint32_t block_dim() const {
  93. return block_dim_;
  94. }
  95. const std::vector<uint8_t> &args() const {
  96. return args_;
  97. }
  98. uint32_t args_size() const {
  99. return args_size_;
  100. }
  101. const std::vector<uint8_t> &sm_desc() const {
  102. return sm_desc_;
  103. }
  104. const std::vector<uint8_t> &flow_table() const {
  105. return flow_table_;
  106. }
  107. const std::vector<uint8_t> &args_offset() const {
  108. return args_offset_;
  109. }
  110. bool is_flowtable() const {
  111. return is_flowtable_;
  112. }
  113. private:
  114. cce::ccOpContext ctx_;
  115. std::string stub_func_;
  116. uint32_t block_dim_;
  117. std::vector<uint8_t> args_;
  118. uint32_t args_size_;
  119. std::vector<uint8_t> sm_desc_;
  120. std::vector<uint8_t> flow_table_;
  121. std::vector<uint8_t> args_offset_;
  122. bool is_flowtable_;
  123. };
  124. class TbeTaskInfo : public TaskInfo {
  125. public:
  126. TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
  127. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
  128. uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  129. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
  130. : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
  131. stub_func_(stub_func),
  132. block_dim_(block_dim),
  133. args_(args),
  134. args_size_(args_size),
  135. sm_desc_(sm_desc),
  136. binary_(binary),
  137. binary_size_(binary_size),
  138. meta_data_(meta_data),
  139. input_data_addrs_(input_data_addrs),
  140. output_data_addrs_(output_data_addrs),
  141. workspace_addrs_(workspace_addrs) {}
  142. ~TbeTaskInfo() override {}
  143. const std::string &stub_func() const {
  144. return stub_func_;
  145. }
  146. uint32_t block_dim() const {
  147. return block_dim_;
  148. }
  149. const std::vector<uint8_t> &args() const {
  150. return args_;
  151. }
  152. uint32_t args_size() const {
  153. return args_size_;
  154. }
  155. const std::vector<uint8_t> &sm_desc() const {
  156. return sm_desc_;
  157. }
  158. void *binary() const {
  159. return binary_;
  160. }
  161. uint32_t binary_size() const {
  162. return binary_size_;
  163. }
  164. const std::vector<uint8_t> &meta_data() const {
  165. return meta_data_;
  166. }
  167. const std::vector<void *> &input_data_addrs() const {
  168. return input_data_addrs_;
  169. }
  170. const std::vector<void *> &output_data_addrs() const {
  171. return output_data_addrs_;
  172. }
  173. const std::vector<void *> &workspace_addrs() const {
  174. return workspace_addrs_;
  175. }
  176. void SetBinary(void *binary, uint32_t binary_size) {
  177. binary_ = binary;
  178. binary_size_ = binary_size;
  179. }
  180. private:
  181. std::string stub_func_;
  182. uint32_t block_dim_;
  183. std::vector<uint8_t> args_;
  184. uint32_t args_size_;
  185. std::vector<uint8_t> sm_desc_;
  186. void *binary_;
  187. uint32_t binary_size_;
  188. std::vector<uint8_t> meta_data_;
  189. std::vector<void *> input_data_addrs_;
  190. std::vector<void *> output_data_addrs_;
  191. std::vector<void *> workspace_addrs_;
  192. };
  193. class AicpuTaskInfo : public TaskInfo {
  194. public:
  195. AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
  196. const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
  197. const std::vector<void *> &output_data_addrs, bool dump_flag)
  198. : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
  199. so_name_(so_name),
  200. kernel_name_(kernel_name),
  201. node_def_(node_def),
  202. ext_info_(ext_info),
  203. input_data_addrs_(input_data_addrs),
  204. output_data_addrs_(output_data_addrs) {}
  205. ~AicpuTaskInfo() override {}
  206. const std::string &so_name() const {
  207. return so_name_;
  208. }
  209. const std::string &kernel_name() const {
  210. return kernel_name_;
  211. }
  212. const std::string &node_def() const {
  213. return node_def_;
  214. }
  215. const std::vector<void *> &input_data_addrs() const {
  216. return input_data_addrs_;
  217. }
  218. const std::vector<void *> &output_data_addrs() const {
  219. return output_data_addrs_;
  220. }
  221. const std::string &ext_info() const {
  222. return ext_info_;
  223. }
  224. private:
  225. std::string so_name_;
  226. std::string kernel_name_;
  227. std::string node_def_;
  228. std::string ext_info_;
  229. std::vector<void *> input_data_addrs_;
  230. std::vector<void *> output_data_addrs_;
  231. };
  232. class LabelSetTaskInfo : public TaskInfo {
  233. public:
  234. LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  235. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  236. ~LabelSetTaskInfo() override {}
  237. uint32_t label_id() const {
  238. return label_id_;
  239. }
  240. private:
  241. uint32_t label_id_;
  242. };
  243. class LabelGotoTaskInfo : public TaskInfo {
  244. public:
  245. LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  246. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  247. ~LabelGotoTaskInfo() override {}
  248. uint32_t label_id() const {
  249. return label_id_;
  250. }
  251. private:
  252. uint32_t label_id_;
  253. };
  254. class LabelSwitchTaskInfo : public TaskInfo {
  255. public:
  256. LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
  257. const std::vector<uint32_t> &label_list, void *cond)
  258. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
  259. label_size_(label_size),
  260. label_list_(label_list),
  261. cond_(cond) {}
  262. ~LabelSwitchTaskInfo() override {}
  263. uint32_t label_size() const {
  264. return label_size_;
  265. }
  266. const std::vector<uint32_t> &label_list() const {
  267. return label_list_;
  268. }
  269. void *cond() const {
  270. return cond_;
  271. }
  272. private:
  273. uint32_t label_size_;
  274. std::vector<uint32_t> label_list_;
  275. void *cond_;
  276. };
  277. class EventTaskInfo : public TaskInfo {
  278. public:
  279. uint32_t event_id() const {
  280. return event_id_;
  281. }
  282. protected:
  283. EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  284. : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  285. ~EventTaskInfo() override {}
  286. uint32_t event_id_;
  287. };
  288. class EventRecordTaskInfo : public EventTaskInfo {
  289. public:
  290. EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  291. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  292. ~EventRecordTaskInfo() override {}
  293. };
  294. class EventWaitTaskInfo : public EventTaskInfo {
  295. public:
  296. EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  297. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  298. ~EventWaitTaskInfo() override {}
  299. };
  300. class FusionStartTaskInfo : public TaskInfo {
  301. public:
  302. explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
  303. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  304. ~FusionStartTaskInfo() override {}
  305. };
  306. class FusionEndTaskInfo : public TaskInfo {
  307. public:
  308. explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
  309. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  310. ~FusionEndTaskInfo() override {}
  311. };
  312. class HcclTaskInfo : public TaskInfo {
  313. public:
  314. HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
  315. void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
  316. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  317. int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
  318. : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
  319. hccl_type_(hccl_type),
  320. input_data_addr_(input_data_addr),
  321. output_data_addr_(output_data_addr),
  322. workspace_size_(workspace_size),
  323. hccl_stream_num_(hccl_stream_num),
  324. private_def_(private_def),
  325. ops_kernel_store_(ops_kernel_store),
  326. count_(count),
  327. root_id_(root_id),
  328. op_type_(op_type),
  329. data_type_(data_type),
  330. group_(group) {}
  331. ~HcclTaskInfo() override {}
  332. const std::string &hccl_type() const {
  333. return hccl_type_;
  334. }
  335. void *input_data_addr() const {
  336. return input_data_addr_;
  337. }
  338. void *output_data_addr() const {
  339. return output_data_addr_;
  340. }
  341. int64_t workspace_size() const {
  342. return workspace_size_;
  343. }
  344. int64_t hccl_stream_num() const {
  345. return hccl_stream_num_;
  346. }
  347. const std::vector<uint8_t> &private_def() const {
  348. return private_def_;
  349. }
  350. void *ops_kernel_store() const {
  351. return ops_kernel_store_;
  352. }
  353. int32_t count() const {
  354. return count_;
  355. }
  356. int64_t root_id() const {
  357. return root_id_;
  358. }
  359. int64_t op_type() const {
  360. return op_type_;
  361. }
  362. int64_t data_type() const {
  363. return data_type_;
  364. }
  365. const std::string &group() const {
  366. return group_;
  367. }
  368. private:
  369. std::string hccl_type_;
  370. void *input_data_addr_;
  371. void *output_data_addr_;
  372. int64_t workspace_size_;
  373. int64_t hccl_stream_num_;
  374. std::vector<uint8_t> private_def_;
  375. void *ops_kernel_store_;
  376. int32_t count_;
  377. int64_t root_id_;
  378. int64_t op_type_;
  379. int64_t data_type_;
  380. std::string group_;
  381. };
  382. class ProfilerTraceTaskInfo : public TaskInfo {
  383. public:
  384. ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  385. : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
  386. log_id_(log_id),
  387. notify_(notify),
  388. flat_(flat) {}
  389. ~ProfilerTraceTaskInfo() override {}
  390. uint64_t log_id() const {
  391. return log_id_;
  392. }
  393. bool notify() const {
  394. return notify_;
  395. }
  396. uint32_t flat() const {
  397. return flat_;
  398. }
  399. private:
  400. uint64_t log_id_;
  401. bool notify_;
  402. uint32_t flat_;
  403. };
  404. class MemcpyAsyncTaskInfo : public TaskInfo {
  405. public:
  406. MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
  407. uint64_t count, uint32_t kind, bool dump_flag)
  408. : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
  409. dst_(dst),
  410. dst_max_(dst_max),
  411. src_(src),
  412. count_(count),
  413. kind_(kind) {}
  414. ~MemcpyAsyncTaskInfo() override {}
  415. void *dst() const {
  416. return dst_;
  417. }
  418. uint64_t dst_max() const {
  419. return dst_max_;
  420. }
  421. void *src() const {
  422. return src_;
  423. }
  424. uint64_t count() const {
  425. return count_;
  426. }
  427. uint32_t kind() const {
  428. return kind_;
  429. }
  430. private:
  431. void *dst_;
  432. uint64_t dst_max_;
  433. void *src_;
  434. uint64_t count_;
  435. int32_t kind_;
  436. };
  437. class StreamSwitchTaskInfo : public TaskInfo {
  438. public:
  439. StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
  440. void *value_addr, int64_t cond, int64_t data_type)
  441. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
  442. true_stream_id_(true_stream_id),
  443. input_addr_(input_addr),
  444. value_addr_(value_addr),
  445. cond_(cond),
  446. data_type_(data_type) {}
  447. ~StreamSwitchTaskInfo() override {}
  448. int64_t true_stream_id() const {
  449. return true_stream_id_;
  450. }
  451. void *input_addr() const {
  452. return input_addr_;
  453. }
  454. void *value_addr() const {
  455. return value_addr_;
  456. }
  457. int64_t cond() const {
  458. return cond_;
  459. }
  460. int64_t data_type() const {
  461. return data_type_;
  462. }
  463. private:
  464. int64_t true_stream_id_;
  465. void *input_addr_;
  466. void *value_addr_;
  467. int64_t cond_;
  468. int64_t data_type_;
  469. };
  470. class StreamActiveTaskInfo : public TaskInfo {
  471. public:
  472. StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
  473. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  474. ~StreamActiveTaskInfo() override {}
  475. uint32_t active_stream_id() const {
  476. return active_stream_id_;
  477. }
  478. private:
  479. uint32_t active_stream_id_;
  480. };
  481. } // namespace model_runner
  482. } // namespace ge
  483. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。