You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <memory>
  20. #include <string>
  21. #include <utility>
  22. #include <vector>
  23. #include "cce/taskdown_api.h"
  24. namespace ge {
  25. namespace model_runner {
  /**
   * @brief Type tag stored in every TaskInfo; each concrete TaskInfo subclass below passes
   *        one of these values to the TaskInfo constructor.
   *
   * Every enumerator is given an explicit value (identical to the implicit numbering) so the
   * numbering cannot shift silently when the list is edited.
   * NOTE: "REVSERVED" looks like a misspelling of "RESERVED"; it is kept unchanged because
   * code outside this header may refer to it.
   */
  enum TaskInfoType {
    CCE = 0,
    TBE = 1,
    AICPU = 2,
    LABEL_SET = 3,
    LABEL_SWITCH = 4,
    LABEL_GOTO = 5,
    EVENT_RECORD = 6,
    EVENT_WAIT = 7,
    FUSION_START = 8,
    FUSION_END = 9,
    HCCL = 10,
    PROFILER_TRACE = 11,
    MEMCPY_ASYNC = 12,
    STREAM_SWITCH = 13,
    STREAM_ACTIVE = 14,
    // Insert new task type here (with an explicit value)
    REVSERVED = 23
  };
  45. class TaskInfo {
  46. public:
  47. virtual ~TaskInfo() {}
  48. uint32_t stream_id() const { return stream_id_; }
  49. TaskInfoType type() const { return type_; }
  50. std::string op_name() const { return op_name_; }
  51. bool dump_flag() const { return dump_flag_; }
  52. protected:
  53. TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
  54. : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
  55. private:
  56. std::string op_name_;
  57. uint32_t stream_id_;
  58. TaskInfoType type_;
  59. bool dump_flag_;
  60. };
  61. class CceTaskInfo : public TaskInfo {
  62. public:
  63. CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
  64. uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
  65. const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
  66. const std::vector<uint8_t> &args_offset, bool is_flowtable)
  67. : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
  68. ctx_(ctx),
  69. stub_func_(stub_func),
  70. block_dim_(block_dim),
  71. args_(args),
  72. args_size_(args_size),
  73. sm_desc_(sm_desc),
  74. flow_table_(flow_table),
  75. args_offset_(args_offset),
  76. is_flowtable_(is_flowtable) {}
  77. ~CceTaskInfo() override {}
  78. cce::ccOpContext cc_context() const { return ctx_; }
  79. std::string stub_func() const { return stub_func_; }
  80. uint32_t block_dim() const { return block_dim_; }
  81. const std::vector<uint8_t> &args() const { return args_; }
  82. uint32_t args_size() const { return args_size_; }
  83. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  84. const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  85. const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  86. bool is_flowtable() const { return is_flowtable_; }
  87. private:
  88. cce::ccOpContext ctx_;
  89. std::string stub_func_;
  90. uint32_t block_dim_;
  91. std::vector<uint8_t> args_;
  92. uint32_t args_size_;
  93. std::vector<uint8_t> sm_desc_;
  94. std::vector<uint8_t> flow_table_;
  95. std::vector<uint8_t> args_offset_;
  96. bool is_flowtable_;
  97. };
  98. class TbeTaskInfo : public TaskInfo {
  99. public:
  100. TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
  101. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
  102. uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  103. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
  104. : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
  105. stub_func_(stub_func),
  106. block_dim_(block_dim),
  107. args_(args),
  108. args_size_(args_size),
  109. sm_desc_(sm_desc),
  110. binary_(binary),
  111. binary_size_(binary_size),
  112. meta_data_(meta_data),
  113. input_data_addrs_(input_data_addrs),
  114. output_data_addrs_(output_data_addrs),
  115. workspace_addrs_(workspace_addrs) {}
  116. ~TbeTaskInfo() override {}
  117. const std::string &stub_func() const { return stub_func_; }
  118. uint32_t block_dim() const { return block_dim_; }
  119. const std::vector<uint8_t> &args() const { return args_; }
  120. uint32_t args_size() const { return args_size_; }
  121. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  122. void *binary() const { return binary_; }
  123. uint32_t binary_size() const { return binary_size_; }
  124. const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  125. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  126. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  127. const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
  128. void SetBinary(void *binary, uint32_t binary_size) {
  129. binary_ = binary;
  130. binary_size_ = binary_size;
  131. }
  132. private:
  133. std::string stub_func_;
  134. uint32_t block_dim_;
  135. std::vector<uint8_t> args_;
  136. uint32_t args_size_;
  137. std::vector<uint8_t> sm_desc_;
  138. void *binary_;
  139. uint32_t binary_size_;
  140. std::vector<uint8_t> meta_data_;
  141. std::vector<void *> input_data_addrs_;
  142. std::vector<void *> output_data_addrs_;
  143. std::vector<void *> workspace_addrs_;
  144. };
  145. class AicpuTaskInfo : public TaskInfo {
  146. public:
  147. AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
  148. const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
  149. const std::vector<void *> &output_data_addrs, bool dump_flag)
  150. : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
  151. so_name_(so_name),
  152. kernel_name_(kernel_name),
  153. node_def_(node_def),
  154. ext_info_(ext_info),
  155. input_data_addrs_(input_data_addrs),
  156. output_data_addrs_(output_data_addrs) {}
  157. ~AicpuTaskInfo() override {}
  158. const std::string &so_name() const { return so_name_; }
  159. const std::string &kernel_name() const { return kernel_name_; }
  160. const std::string &node_def() const { return node_def_; }
  161. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  162. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  163. const std::string &ext_info() const { return ext_info_; }
  164. private:
  165. std::string so_name_;
  166. std::string kernel_name_;
  167. std::string node_def_;
  168. std::string ext_info_;
  169. std::vector<void *> input_data_addrs_;
  170. std::vector<void *> output_data_addrs_;
  171. };
  172. class LabelSetTaskInfo : public TaskInfo {
  173. public:
  174. LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  175. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  176. ~LabelSetTaskInfo() override {}
  177. uint32_t label_id() const { return label_id_; }
  178. private:
  179. uint32_t label_id_;
  180. };
  181. class LabelGotoTaskInfo : public TaskInfo {
  182. public:
  183. LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  184. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  185. ~LabelGotoTaskInfo() override {}
  186. uint32_t label_id() const { return label_id_; }
  187. private:
  188. uint32_t label_id_;
  189. };
  190. class LabelSwitchTaskInfo : public TaskInfo {
  191. public:
  192. LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
  193. const std::vector<uint32_t> &label_list, void *cond)
  194. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
  195. label_size_(label_size),
  196. label_list_(label_list),
  197. cond_(cond) {}
  198. ~LabelSwitchTaskInfo() override {}
  199. uint32_t label_size() const { return label_size_; }
  200. const std::vector<uint32_t> &label_list() const { return label_list_; }
  201. void *cond() const { return cond_; }
  202. private:
  203. uint32_t label_size_;
  204. std::vector<uint32_t> label_list_;
  205. void *cond_;
  206. };
  207. class EventTaskInfo : public TaskInfo {
  208. public:
  209. uint32_t event_id() const { return event_id_; }
  210. protected:
  211. EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  212. : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  213. ~EventTaskInfo() override {}
  214. uint32_t event_id_;
  215. };
  216. class EventRecordTaskInfo : public EventTaskInfo {
  217. public:
  218. EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  219. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  220. ~EventRecordTaskInfo() override {}
  221. };
  222. class EventWaitTaskInfo : public EventTaskInfo {
  223. public:
  224. EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  225. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  226. ~EventWaitTaskInfo() override {}
  227. };
  228. class FusionStartTaskInfo : public TaskInfo {
  229. public:
  230. explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
  231. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  232. ~FusionStartTaskInfo() override {}
  233. };
  234. class FusionEndTaskInfo : public TaskInfo {
  235. public:
  236. explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
  237. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  238. ~FusionEndTaskInfo() override {}
  239. };
  240. class HcclTaskInfo : public TaskInfo {
  241. public:
  242. HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
  243. void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
  244. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  245. int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
  246. : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
  247. hccl_type_(hccl_type),
  248. input_data_addr_(input_data_addr),
  249. output_data_addr_(output_data_addr),
  250. workspace_size_(workspace_size),
  251. hccl_stream_num_(hccl_stream_num),
  252. private_def_(private_def),
  253. ops_kernel_store_(ops_kernel_store),
  254. count_(count),
  255. root_id_(root_id),
  256. op_type_(op_type),
  257. data_type_(data_type),
  258. group_(group) {}
  259. ~HcclTaskInfo() override {}
  260. const std::string &hccl_type() const { return hccl_type_; }
  261. void *input_data_addr() const { return input_data_addr_; }
  262. void *output_data_addr() const { return output_data_addr_; }
  263. int64_t workspace_size() const { return workspace_size_; }
  264. int64_t hccl_stream_num() const { return hccl_stream_num_; }
  265. const std::vector<uint8_t> &private_def() const { return private_def_; }
  266. void *ops_kernel_store() const { return ops_kernel_store_; }
  267. int32_t count() const { return count_; }
  268. int64_t root_id() const { return root_id_; }
  269. int64_t op_type() const { return op_type_; }
  270. int64_t data_type() const { return data_type_; }
  271. const std::string &group() const { return group_; }
  272. private:
  273. std::string hccl_type_;
  274. void *input_data_addr_;
  275. void *output_data_addr_;
  276. int64_t workspace_size_;
  277. int64_t hccl_stream_num_;
  278. std::vector<uint8_t> private_def_;
  279. void *ops_kernel_store_;
  280. int32_t count_;
  281. int64_t root_id_;
  282. int64_t op_type_;
  283. int64_t data_type_;
  284. std::string group_;
  285. };
  286. class ProfilerTraceTaskInfo : public TaskInfo {
  287. public:
  288. ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  289. : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
  290. log_id_(log_id),
  291. notify_(notify),
  292. flat_(flat) {}
  293. ~ProfilerTraceTaskInfo() override {}
  294. uint64_t log_id() const { return log_id_; }
  295. bool notify() const { return notify_; }
  296. uint32_t flat() const { return flat_; }
  297. private:
  298. uint64_t log_id_;
  299. bool notify_;
  300. uint32_t flat_;
  301. };
  302. class MemcpyAsyncTaskInfo : public TaskInfo {
  303. public:
  304. MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
  305. uint64_t count, uint32_t kind, bool dump_flag)
  306. : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
  307. dst_(dst),
  308. dst_max_(dst_max),
  309. src_(src),
  310. count_(count),
  311. kind_(kind) {}
  312. ~MemcpyAsyncTaskInfo() override {}
  313. void *dst() const { return dst_; }
  314. uint64_t dst_max() const { return dst_max_; }
  315. void *src() const { return src_; }
  316. uint64_t count() const { return count_; }
  317. uint32_t kind() const { return kind_; }
  318. private:
  319. void *dst_;
  320. uint64_t dst_max_;
  321. void *src_;
  322. uint64_t count_;
  323. int32_t kind_;
  324. };
  325. class StreamSwitchTaskInfo : public TaskInfo {
  326. public:
  327. StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
  328. void *value_addr, int64_t cond, int64_t data_type)
  329. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
  330. true_stream_id_(true_stream_id),
  331. input_addr_(input_addr),
  332. value_addr_(value_addr),
  333. cond_(cond),
  334. data_type_(data_type) {}
  335. ~StreamSwitchTaskInfo() override {}
  336. int64_t true_stream_id() const { return true_stream_id_; }
  337. void *input_addr() const { return input_addr_; }
  338. void *value_addr() const { return value_addr_; }
  339. int64_t cond() const { return cond_; }
  340. int64_t data_type() const { return data_type_; }
  341. private:
  342. int64_t true_stream_id_;
  343. void *input_addr_;
  344. void *value_addr_;
  345. int64_t cond_;
  346. int64_t data_type_;
  347. };
  348. class StreamActiveTaskInfo : public TaskInfo {
  349. public:
  350. StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
  351. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  352. ~StreamActiveTaskInfo() override {}
  353. uint32_t active_stream_id() const { return active_stream_id_; }
  354. private:
  355. uint32_t active_stream_id_;
  356. };
  357. } // namespace model_runner
  358. } // namespace ge
  359. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：