You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "cce/taskdown_api.h"
  24. namespace ge {
  25. namespace model_runner {
  /// Tag identifying which concrete task descriptor a TaskInfo object is.
  ///
  /// Enumerators from CCE through STREAM_ACTIVE are numbered consecutively
  /// starting at 0. The sentinel at the end is pinned to 23.
  enum TaskInfoType {
    CCE = 0,
    TBE,
    AICPU,
    LABEL_SET,
    LABEL_SWITCH,
    LABEL_GOTO,
    EVENT_RECORD,
    EVENT_WAIT,
    FUSION_START,
    FUSION_END,
    HCCL,
    PROFILER_TRACE,
    MEMCPY_ASYNC,
    STREAM_SWITCH,
    STREAM_ACTIVE,
    // Insert new task type here
    RESERVED = 23,
    // Original (misspelled) name of the sentinel; kept as an alias with the
    // same value so existing code that refers to it keeps compiling.
    REVSERVED = RESERVED
  };
  45. class TaskInfo {
  46. public:
  47. virtual ~TaskInfo() {}
  48. uint32_t stream_id() const {
  49. return stream_id_;
  50. }
  51. TaskInfoType type() const {
  52. return type_;
  53. }
  54. std::string op_name() const {
  55. return op_name_;
  56. }
  57. bool dump_flag() const {
  58. return dump_flag_;
  59. }
  60. protected:
  61. TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
  62. : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
  63. private:
  64. std::string op_name_;
  65. uint32_t stream_id_;
  66. TaskInfoType type_;
  67. bool dump_flag_;
  68. };
  69. class CceTaskInfo : public TaskInfo {
  70. public:
  71. CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
  72. uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
  73. const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
  74. const std::vector<uint8_t> &args_offset, bool is_flowtable)
  75. : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
  76. ctx_(ctx),
  77. stub_func_(stub_func),
  78. block_dim_(block_dim),
  79. args_(args),
  80. args_size_(args_size),
  81. sm_desc_(sm_desc),
  82. flow_table_(flow_table),
  83. args_offset_(args_offset),
  84. is_flowtable_(is_flowtable) {}
  85. ~CceTaskInfo() override {}
  86. cce::ccOpContext cc_context() const {
  87. return ctx_;
  88. }
  89. std::string stub_func() const {
  90. return stub_func_;
  91. }
  92. uint32_t block_dim() const {
  93. return block_dim_;
  94. }
  95. const std::vector<uint8_t> &args() const {
  96. return args_;
  97. }
  98. uint32_t args_size() const {
  99. return args_size_;
  100. }
  101. const std::vector<uint8_t> &sm_desc() const {
  102. return sm_desc_;
  103. }
  104. const std::vector<uint8_t> &flow_table() const {
  105. return flow_table_;
  106. }
  107. const std::vector<uint8_t> &args_offset() const {
  108. return args_offset_;
  109. }
  110. bool is_flowtable() const {
  111. return is_flowtable_;
  112. }
  113. private:
  114. cce::ccOpContext ctx_;
  115. std::string stub_func_;
  116. uint32_t block_dim_;
  117. std::vector<uint8_t> args_;
  118. uint32_t args_size_;
  119. std::vector<uint8_t> sm_desc_;
  120. std::vector<uint8_t> flow_table_;
  121. std::vector<uint8_t> args_offset_;
  122. bool is_flowtable_;
  123. };
  124. class TbeTaskInfo : public TaskInfo {
  125. public:
  126. TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
  127. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
  128. uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  129. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
  130. : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
  131. stub_func_(stub_func),
  132. block_dim_(block_dim),
  133. args_(args),
  134. args_size_(args_size),
  135. sm_desc_(sm_desc),
  136. binary_(binary),
  137. binary_size_(binary_size),
  138. meta_data_(meta_data),
  139. input_data_addrs_(input_data_addrs),
  140. output_data_addrs_(output_data_addrs),
  141. workspace_addrs_(workspace_addrs) {}
  142. ~TbeTaskInfo() override {}
  143. const std::string &stub_func() const {
  144. return stub_func_;
  145. }
  146. uint32_t block_dim() const {
  147. return block_dim_;
  148. }
  149. const std::vector<uint8_t> &args() const {
  150. return args_;
  151. }
  152. uint32_t args_size() const {
  153. return args_size_;
  154. }
  155. const std::vector<uint8_t> &sm_desc() const {
  156. return sm_desc_;
  157. }
  158. void *binary() const {
  159. return binary_;
  160. }
  161. uint32_t binary_size() const {
  162. return binary_size_;
  163. }
  164. const std::vector<uint8_t> &meta_data() const {
  165. return meta_data_;
  166. }
  167. const std::vector<void *> &input_data_addrs() const {
  168. return input_data_addrs_;
  169. }
  170. const std::vector<void *> &output_data_addrs() const {
  171. return output_data_addrs_;
  172. }
  173. const std::vector<void *> &workspace_addrs() const {
  174. return workspace_addrs_;
  175. }
  176. void SetBinary(void *binary, uint32_t binary_size) {
  177. binary_ = binary;
  178. binary_size_ = binary_size;
  179. }
  180. private:
  181. std::string stub_func_;
  182. uint32_t block_dim_;
  183. std::vector<uint8_t> args_;
  184. uint32_t args_size_;
  185. std::vector<uint8_t> sm_desc_;
  186. void *binary_;
  187. uint32_t binary_size_;
  188. std::vector<uint8_t> meta_data_;
  189. std::vector<void *> input_data_addrs_;
  190. std::vector<void *> output_data_addrs_;
  191. std::vector<void *> workspace_addrs_;
  192. };
  193. class AicpuTaskInfo : public TaskInfo {
  194. public:
  195. AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
  196. const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
  197. const std::vector<void *> &output_data_addrs, bool dump_flag)
  198. : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
  199. so_name_(so_name),
  200. kernel_name_(kernel_name),
  201. node_def_(node_def),
  202. ext_info_(ext_info),
  203. input_data_addrs_(input_data_addrs),
  204. output_data_addrs_(output_data_addrs) {}
  205. ~AicpuTaskInfo() override {}
  206. const std::string &so_name() const {
  207. return so_name_;
  208. }
  209. const std::string &kernel_name() const {
  210. return kernel_name_;
  211. }
  212. const std::string &node_def() const {
  213. return node_def_;
  214. }
  215. const std::vector<void *> &input_data_addrs() const {
  216. return input_data_addrs_;
  217. }
  218. const std::vector<void *> &output_data_addrs() const {
  219. return output_data_addrs_;
  220. }
  221. const std::string &ext_info() const {
  222. return ext_info_;
  223. }
  224. private:
  225. std::string so_name_;
  226. std::string kernel_name_;
  227. std::string node_def_;
  228. std::string ext_info_;
  229. std::vector<void *> input_data_addrs_;
  230. std::vector<void *> output_data_addrs_;
  231. };
  232. class LabelSetTaskInfo : public TaskInfo {
  233. public:
  234. LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  235. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  236. ~LabelSetTaskInfo() override {}
  237. uint32_t label_id() const {
  238. return label_id_;
  239. }
  240. private:
  241. uint32_t label_id_;
  242. };
  243. class LabelGotoTaskInfo : public TaskInfo {
  244. public:
  245. LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  246. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  247. ~LabelGotoTaskInfo() override {}
  248. uint32_t label_id() const {
  249. return label_id_;
  250. }
  251. private:
  252. uint32_t label_id_;
  253. };
  254. class LabelSwitchTaskInfo : public TaskInfo {
  255. public:
  256. LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
  257. const std::vector<uint32_t> &label_list, void *cond)
  258. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
  259. label_size_(label_size),
  260. label_list_(label_list),
  261. cond_(cond) {}
  262. ~LabelSwitchTaskInfo() override {}
  263. uint32_t label_size() const {
  264. return label_size_;
  265. }
  266. const std::vector<uint32_t> &label_list() const {
  267. return label_list_;
  268. }
  269. void *cond() const {
  270. return cond_;
  271. }
  272. private:
  273. uint32_t label_size_;
  274. std::vector<uint32_t> label_list_;
  275. void *cond_;
  276. };
  277. class EventTaskInfo : public TaskInfo {
  278. public:
  279. uint32_t event_id() const {
  280. return event_id_;
  281. }
  282. protected:
  283. EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  284. : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  285. ~EventTaskInfo() override {}
  286. uint32_t event_id_;
  287. };
  288. class EventRecordTaskInfo : public EventTaskInfo {
  289. public:
  290. EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  291. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  292. ~EventRecordTaskInfo() override {}
  293. };
  294. class EventWaitTaskInfo : public EventTaskInfo {
  295. public:
  296. EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  297. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  298. ~EventWaitTaskInfo() override {}
  299. };
  300. class FusionStartTaskInfo : public TaskInfo {
  301. public:
  302. explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
  303. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  304. ~FusionStartTaskInfo() override {}
  305. };
  306. class FusionEndTaskInfo : public TaskInfo {
  307. public:
  308. explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
  309. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  310. ~FusionEndTaskInfo() override {}
  311. };
  312. class HcclTaskInfo : public TaskInfo {
  313. public:
  314. HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
  315. void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
  316. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  317. int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
  318. : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
  319. hccl_type_(hccl_type),
  320. input_data_addr_(input_data_addr),
  321. output_data_addr_(output_data_addr),
  322. workspace_addr_(workspace_addr),
  323. workspace_size_(workspace_size),
  324. hccl_stream_num_(hccl_stream_num),
  325. private_def_(private_def),
  326. ops_kernel_store_(ops_kernel_store),
  327. count_(count),
  328. root_id_(root_id),
  329. op_type_(op_type),
  330. data_type_(data_type),
  331. group_(group) {}
  332. ~HcclTaskInfo() override {}
  333. const std::string &hccl_type() const {
  334. return hccl_type_;
  335. }
  336. void *input_data_addr() const {
  337. return input_data_addr_;
  338. }
  339. void *output_data_addr() const {
  340. return output_data_addr_;
  341. }
  342. void *workspace_addr() const {
  343. return workspace_addr_;
  344. }
  345. int64_t workspace_size() const {
  346. return workspace_size_;
  347. }
  348. int64_t hccl_stream_num() const {
  349. return hccl_stream_num_;
  350. }
  351. const std::vector<uint8_t> &private_def() const {
  352. return private_def_;
  353. }
  354. void *ops_kernel_store() const {
  355. return ops_kernel_store_;
  356. }
  357. int32_t count() const {
  358. return count_;
  359. }
  360. int64_t root_id() const {
  361. return root_id_;
  362. }
  363. int64_t op_type() const {
  364. return op_type_;
  365. }
  366. int64_t data_type() const {
  367. return data_type_;
  368. }
  369. const std::string &group() const {
  370. return group_;
  371. }
  372. private:
  373. std::string hccl_type_;
  374. void *input_data_addr_;
  375. void *output_data_addr_;
  376. void *workspace_addr_;
  377. int64_t workspace_size_;
  378. int64_t hccl_stream_num_;
  379. std::vector<uint8_t> private_def_;
  380. void *ops_kernel_store_;
  381. int32_t count_;
  382. int64_t root_id_;
  383. int64_t op_type_;
  384. int64_t data_type_;
  385. std::string group_;
  386. };
  387. class ProfilerTraceTaskInfo : public TaskInfo {
  388. public:
  389. ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  390. : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
  391. log_id_(log_id),
  392. notify_(notify),
  393. flat_(flat) {}
  394. ~ProfilerTraceTaskInfo() override {}
  395. uint64_t log_id() const {
  396. return log_id_;
  397. }
  398. bool notify() const {
  399. return notify_;
  400. }
  401. uint32_t flat() const {
  402. return flat_;
  403. }
  404. private:
  405. uint64_t log_id_;
  406. bool notify_;
  407. uint32_t flat_;
  408. };
  409. class MemcpyAsyncTaskInfo : public TaskInfo {
  410. public:
  411. MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
  412. uint64_t count, uint32_t kind, bool dump_flag)
  413. : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
  414. dst_(dst),
  415. dst_max_(dst_max),
  416. src_(src),
  417. count_(count),
  418. kind_(kind) {}
  419. ~MemcpyAsyncTaskInfo() override {}
  420. void *dst() const {
  421. return dst_;
  422. }
  423. uint64_t dst_max() const {
  424. return dst_max_;
  425. }
  426. void *src() const {
  427. return src_;
  428. }
  429. uint64_t count() const {
  430. return count_;
  431. }
  432. uint32_t kind() const {
  433. return kind_;
  434. }
  435. private:
  436. void *dst_;
  437. uint64_t dst_max_;
  438. void *src_;
  439. uint64_t count_;
  440. int32_t kind_;
  441. };
  442. class StreamSwitchTaskInfo : public TaskInfo {
  443. public:
  444. StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
  445. void *value_addr, int64_t cond, int64_t data_type)
  446. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
  447. true_stream_id_(true_stream_id),
  448. input_addr_(input_addr),
  449. value_addr_(value_addr),
  450. cond_(cond),
  451. data_type_(data_type) {}
  452. ~StreamSwitchTaskInfo() override {}
  453. int64_t true_stream_id() const {
  454. return true_stream_id_;
  455. }
  456. void *input_addr() const {
  457. return input_addr_;
  458. }
  459. void *value_addr() const {
  460. return value_addr_;
  461. }
  462. int64_t cond() const {
  463. return cond_;
  464. }
  465. int64_t data_type() const {
  466. return data_type_;
  467. }
  468. private:
  469. int64_t true_stream_id_;
  470. void *input_addr_;
  471. void *value_addr_;
  472. int64_t cond_;
  473. int64_t data_type_;
  474. };
  475. class StreamActiveTaskInfo : public TaskInfo {
  476. public:
  477. StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
  478. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  479. ~StreamActiveTaskInfo() override {}
  480. uint32_t active_stream_id() const {
  481. return active_stream_id_;
  482. }
  483. private:
  484. uint32_t active_stream_id_;
  485. };
  486. } // namespace model_runner
  487. } // namespace ge
  488. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示