You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <functional>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "cce/taskdown_api.h"
  25. namespace ge {
  26. namespace model_runner {
  // Discriminator stored in every TaskInfo (see TaskInfo::type()); each concrete
  // task class in this header passes its own enumerator to the base constructor.
  // Values are spelled out explicitly so the numeric mapping is visible at a glance.
  enum TaskInfoType {
    CCE = 0,
    TBE = 1,
    AICPU = 2,
    LABEL_SET = 3,
    LABEL_SWITCH = 4,
    LABEL_GOTO = 5,
    EVENT_RECORD = 6,
    EVENT_WAIT = 7,
    FUSION_START = 8,
    FUSION_END = 9,
    HCCL = 10,
    PROFILER_TRACE = 11,
    MEMCPY_ASYNC = 12,
    STREAM_SWITCH = 13,
    STREAM_ACTIVE = 14,
    // Insert new task type here
    // NOTE: the identifier below is misspelled, but it is part of the public
    // interface and is therefore kept unchanged.
    REVSERVED = 23
  };
  46. class TaskInfo {
  47. public:
  48. virtual ~TaskInfo() {}
  49. uint32_t stream_id() const { return stream_id_; }
  50. TaskInfoType type() const { return type_; }
  51. std::string op_name() const { return op_name_; }
  52. bool dump_flag() const { return dump_flag_; }
  53. protected:
  54. TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
  55. : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
  56. private:
  57. std::string op_name_;
  58. uint32_t stream_id_;
  59. TaskInfoType type_;
  60. bool dump_flag_;
  61. };
  62. class CceTaskInfo : public TaskInfo {
  63. public:
  64. CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
  65. uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
  66. const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
  67. const std::vector<uint8_t> &args_offset, bool is_flowtable)
  68. : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
  69. ctx_(ctx),
  70. stub_func_(stub_func),
  71. block_dim_(block_dim),
  72. args_(args),
  73. args_size_(args_size),
  74. sm_desc_(sm_desc),
  75. flow_table_(flow_table),
  76. args_offset_(args_offset),
  77. is_flowtable_(is_flowtable) {}
  78. ~CceTaskInfo() override {}
  79. cce::ccOpContext cc_context() const { return ctx_; }
  80. std::string stub_func() const { return stub_func_; }
  81. uint32_t block_dim() const { return block_dim_; }
  82. const std::vector<uint8_t> &args() const { return args_; }
  83. uint32_t args_size() const { return args_size_; }
  84. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  85. const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  86. const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  87. bool is_flowtable() const { return is_flowtable_; }
  88. private:
  89. cce::ccOpContext ctx_;
  90. std::string stub_func_;
  91. uint32_t block_dim_;
  92. std::vector<uint8_t> args_;
  93. uint32_t args_size_;
  94. std::vector<uint8_t> sm_desc_;
  95. std::vector<uint8_t> flow_table_;
  96. std::vector<uint8_t> args_offset_;
  97. bool is_flowtable_;
  98. };
  99. class TbeTaskInfo : public TaskInfo {
  100. public:
  101. TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
  102. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
  103. uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  104. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
  105. : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
  106. stub_func_(stub_func),
  107. block_dim_(block_dim),
  108. args_(args),
  109. args_size_(args_size),
  110. sm_desc_(sm_desc),
  111. binary_(binary),
  112. binary_size_(binary_size),
  113. meta_data_(meta_data),
  114. input_data_addrs_(input_data_addrs),
  115. output_data_addrs_(output_data_addrs),
  116. workspace_addrs_(workspace_addrs) {}
  117. ~TbeTaskInfo() override {}
  118. const std::string &stub_func() const { return stub_func_; }
  119. uint32_t block_dim() const { return block_dim_; }
  120. const std::vector<uint8_t> &args() const { return args_; }
  121. uint32_t args_size() const { return args_size_; }
  122. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  123. void *binary() const { return binary_; }
  124. uint32_t binary_size() const { return binary_size_; }
  125. const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  126. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  127. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  128. const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
  129. void SetBinary(void *binary, uint32_t binary_size) {
  130. binary_ = binary;
  131. binary_size_ = binary_size;
  132. }
  133. private:
  134. std::string stub_func_;
  135. uint32_t block_dim_;
  136. std::vector<uint8_t> args_;
  137. uint32_t args_size_;
  138. std::vector<uint8_t> sm_desc_;
  139. void *binary_;
  140. uint32_t binary_size_;
  141. std::vector<uint8_t> meta_data_;
  142. std::vector<void *> input_data_addrs_;
  143. std::vector<void *> output_data_addrs_;
  144. std::vector<void *> workspace_addrs_;
  145. };
  146. class AicpuTaskInfo : public TaskInfo {
  147. public:
  148. AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
  149. const std::string &node_def, const std::string &ext_info, const std::vector<void *> &input_data_addrs,
  150. const std::vector<void *> &output_data_addrs, bool dump_flag)
  151. : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
  152. so_name_(so_name),
  153. kernel_name_(kernel_name),
  154. node_def_(node_def),
  155. ext_info_(ext_info),
  156. input_data_addrs_(input_data_addrs),
  157. output_data_addrs_(output_data_addrs) {}
  158. ~AicpuTaskInfo() override {}
  159. const std::string &so_name() const { return so_name_; }
  160. const std::string &kernel_name() const { return kernel_name_; }
  161. const std::string &node_def() const { return node_def_; }
  162. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  163. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  164. const std::string &ext_info() const { return ext_info_; }
  165. private:
  166. std::string so_name_;
  167. std::string kernel_name_;
  168. std::string node_def_;
  169. std::string ext_info_;
  170. std::vector<void *> input_data_addrs_;
  171. std::vector<void *> output_data_addrs_;
  172. };
  173. class LabelSetTaskInfo : public TaskInfo {
  174. public:
  175. LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  176. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  177. ~LabelSetTaskInfo() override {}
  178. uint32_t label_id() const { return label_id_; }
  179. private:
  180. uint32_t label_id_;
  181. };
  182. class LabelGotoTaskInfo : public TaskInfo {
  183. public:
  184. LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  185. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  186. ~LabelGotoTaskInfo() override {}
  187. uint32_t label_id() const { return label_id_; }
  188. private:
  189. uint32_t label_id_;
  190. };
  191. class LabelSwitchTaskInfo : public TaskInfo {
  192. public:
  193. LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
  194. const std::vector<uint32_t> &label_list, void *cond)
  195. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
  196. label_size_(label_size),
  197. label_list_(label_list),
  198. cond_(cond) {}
  199. ~LabelSwitchTaskInfo() override {}
  200. uint32_t label_size() { return label_size_; };
  201. const std::vector<uint32_t> &label_list() { return label_list_; };
  202. void *cond() { return cond_; };
  203. private:
  204. uint32_t label_size_;
  205. std::vector<uint32_t> label_list_;
  206. void *cond_;
  207. };
  208. class EventTaskInfo : public TaskInfo {
  209. public:
  210. uint32_t event_id() const { return event_id_; }
  211. protected:
  212. EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  213. : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  214. virtual ~EventTaskInfo() override {}
  215. uint32_t event_id_;
  216. };
  217. class EventRecordTaskInfo : public EventTaskInfo {
  218. public:
  219. EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  220. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  221. ~EventRecordTaskInfo() override {}
  222. };
  223. class EventWaitTaskInfo : public EventTaskInfo {
  224. public:
  225. EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  226. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  227. ~EventWaitTaskInfo() override {}
  228. };
  229. class FusionStartTaskInfo : public TaskInfo {
  230. public:
  231. explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
  232. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  233. ~FusionStartTaskInfo() override {}
  234. };
  235. class FusionEndTaskInfo : public TaskInfo {
  236. public:
  237. explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
  238. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  239. ~FusionEndTaskInfo() override {}
  240. };
  241. class HcclTaskInfo : public TaskInfo {
  242. public:
  243. HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
  244. void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
  245. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  246. int64_t op_type, int64_t data_type, const std::string &group,
  247. std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
  248. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
  249. : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
  250. hccl_type_(hccl_type),
  251. input_data_addr_(input_data_addr),
  252. output_data_addr_(output_data_addr),
  253. workspace_addr_(workspace_addr),
  254. workspace_size_(workspace_size),
  255. hccl_stream_num_(hccl_stream_num),
  256. private_def_(private_def),
  257. ops_kernel_store_(ops_kernel_store),
  258. count_(count),
  259. root_id_(root_id),
  260. op_type_(op_type),
  261. data_type_(data_type),
  262. group_(group),
  263. hcom_bind_model_(hcom_bind_model),
  264. hcom_unbind_model_(hcom_unbind_model),
  265. hcom_distribute_task_(hcom_distribute_task) {}
  266. ~HcclTaskInfo() override {}
  267. const std::string &hccl_type() const { return hccl_type_; }
  268. void *input_data_addr() const { return input_data_addr_; }
  269. void *output_data_addr() const { return output_data_addr_; }
  270. void *workspace_addr() const { return workspace_addr_; }
  271. int64_t workspace_size() const { return workspace_size_; }
  272. int64_t hccl_stream_num() const { return hccl_stream_num_; }
  273. const std::vector<uint8_t> &private_def() const { return private_def_; }
  274. void *ops_kernel_store() const { return ops_kernel_store_; }
  275. int32_t count() const { return count_; }
  276. int64_t root_id() const { return root_id_; }
  277. int64_t op_type() const { return op_type_; }
  278. int64_t data_type() const { return data_type_; }
  279. const std::string &group() const { return group_; }
  280. std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
  281. std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
  282. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
  283. return hcom_distribute_task_;
  284. }
  285. private:
  286. std::string hccl_type_;
  287. void *input_data_addr_;
  288. void *output_data_addr_;
  289. void *workspace_addr_;
  290. int64_t workspace_size_;
  291. int64_t hccl_stream_num_;
  292. std::vector<uint8_t> private_def_;
  293. void *ops_kernel_store_;
  294. int32_t count_;
  295. int64_t root_id_;
  296. int64_t op_type_;
  297. int64_t data_type_;
  298. std::string group_;
  299. std::function<bool(void *, void *)> hcom_bind_model_;
  300. std::function<bool(void *)> hcom_unbind_model_;
  301. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
  302. };
  303. class ProfilerTraceTaskInfo : public TaskInfo {
  304. public:
  305. ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  306. : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
  307. log_id_(log_id),
  308. notify_(notify),
  309. flat_(flat) {}
  310. ~ProfilerTraceTaskInfo() override {}
  311. uint64_t log_id() const { return log_id_; }
  312. bool notify() const { return notify_; }
  313. uint32_t flat() const { return flat_; }
  314. private:
  315. uint64_t log_id_;
  316. bool notify_;
  317. uint32_t flat_;
  318. };
  319. class MemcpyAsyncTaskInfo : public TaskInfo {
  320. public:
  321. MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
  322. uint64_t count, uint32_t kind, bool dump_flag)
  323. : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
  324. dst_(dst),
  325. dst_max_(dst_max),
  326. src_(src),
  327. count_(count),
  328. kind_(kind) {}
  329. ~MemcpyAsyncTaskInfo() override {}
  330. void *dst() const { return dst_; }
  331. uint64_t dst_max() const { return dst_max_; }
  332. void *src() const { return src_; }
  333. uint64_t count() const { return count_; }
  334. uint32_t kind() const { return kind_; }
  335. private:
  336. void *dst_;
  337. uint64_t dst_max_;
  338. void *src_;
  339. uint64_t count_;
  340. int32_t kind_;
  341. };
  342. class StreamSwitchTaskInfo : public TaskInfo {
  343. public:
  344. StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
  345. void *value_addr, int64_t cond, int64_t data_type)
  346. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
  347. true_stream_id_(true_stream_id),
  348. input_addr_(input_addr),
  349. value_addr_(value_addr),
  350. cond_(cond),
  351. data_type_(data_type) {}
  352. ~StreamSwitchTaskInfo() override {}
  353. int64_t true_stream_id() const { return true_stream_id_; }
  354. void *input_addr() const { return input_addr_; }
  355. void *value_addr() const { return value_addr_; }
  356. int64_t cond() const { return cond_; }
  357. int64_t data_type() const { return data_type_; }
  358. private:
  359. int64_t true_stream_id_;
  360. void *input_addr_;
  361. void *value_addr_;
  362. int64_t cond_;
  363. int64_t data_type_;
  364. };
  365. class StreamActiveTaskInfo : public TaskInfo {
  366. public:
  367. StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
  368. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  369. ~StreamActiveTaskInfo() override {}
  370. uint32_t active_stream_id() const { return active_stream_id_; }
  371. private:
  372. uint32_t active_stream_id_;
  373. };
  374. } // namespace model_runner
  375. } // namespace ge
  376. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示