
task_info.h

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

#include <stdint.h>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "cce/taskdown_api.h"

namespace ge {
namespace model_runner {
// Types of tasks that the runtime can dispatch for a loaded model.
enum TaskInfoType {
  CCE = 0,
  TBE,
  AICPU,
  LABEL_SET,
  LABEL_SWITCH,
  LABEL_GOTO,
  EVENT_RECORD,
  EVENT_WAIT,
  FUSION_START,
  FUSION_END,
  HCCL,
  PROFILER_TRACE,
  MEMCPY_ASYNC,
  STREAM_SWITCH,
  STREAM_ACTIVE,
  // Insert new task type here
  REVSERVED = 23
};

// Base class for all task descriptors: every task carries the id of the
// stream it runs on and its task type.
class TaskInfo {
 public:
  virtual ~TaskInfo() {}
  uint32_t stream_id() const { return stream_id_; }
  TaskInfoType type() const { return type_; }

 protected:
  TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}

 private:
  uint32_t stream_id_;
  TaskInfoType type_;
};
// Task descriptor for a CCE kernel launch.
class CceTaskInfo : public TaskInfo {
 public:
  CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
              const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
      : TaskInfo(stream_id, TaskInfoType::CCE),
        ctx_(ctx),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        flow_table_(flow_table),
        args_offset_(args_offset),
        is_flowtable_(is_flowtable) {}
  ~CceTaskInfo() override {}

  cce::ccOpContext cc_context() const { return ctx_; }
  std::string stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  bool is_flowtable() const { return is_flowtable_; }

 private:
  cce::ccOpContext ctx_;
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;
  std::vector<uint8_t> sm_desc_;
  std::vector<uint8_t> flow_table_;
  std::vector<uint8_t> args_offset_;
  bool is_flowtable_;
};
// Task descriptor for a TBE kernel launch; the kernel binary can be attached
// later via SetBinary().
class TbeTaskInfo : public TaskInfo {
 public:
  TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
              uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
              const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
      : TaskInfo(stream_id, TaskInfoType::TBE),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        binary_(binary),
        binary_size_(binary_size),
        meta_data_(meta_data),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs),
        workspace_addrs_(workspace_addrs) {}
  ~TbeTaskInfo() override {}

  const std::string &stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  void *binary() const { return binary_; }
  uint32_t binary_size() const { return binary_size_; }
  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

  void SetBinary(void *binary, uint32_t binary_size) {
    binary_ = binary;
    binary_size_ = binary_size;
  }

 private:
  std::string stub_func_;
  uint32_t block_dim_;
  std::vector<uint8_t> args_;
  uint32_t args_size_;
  std::vector<uint8_t> sm_desc_;
  void *binary_;
  uint32_t binary_size_;
  std::vector<uint8_t> meta_data_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
  std::vector<void *> workspace_addrs_;
};
// Task descriptor for an AI CPU kernel, identified by its shared-object name
// and kernel name.
class AicpuTaskInfo : public TaskInfo {
 public:
  AicpuTaskInfo(uint32_t stream_id, const std::string &so_name, const std::string &kernel_name,
                const std::string &node_def, const std::vector<void *> &input_data_addrs,
                const std::vector<void *> &output_data_addrs)
      : TaskInfo(stream_id, TaskInfoType::AICPU),
        so_name_(so_name),
        kernel_name_(kernel_name),
        node_def_(node_def),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs) {}
  ~AicpuTaskInfo() override {}

  const std::string &so_name() const { return so_name_; }
  const std::string &kernel_name() const { return kernel_name_; }
  const std::string &node_def() const { return node_def_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }

 private:
  std::string so_name_;
  std::string kernel_name_;
  std::string node_def_;
  std::vector<void *> input_data_addrs_;
  std::vector<void *> output_data_addrs_;
};
// Common base for label tasks (set/switch/goto); stores the label id.
class LabelTaskInfo : public TaskInfo {
 public:
  uint32_t label_id() const { return label_id_; }

 protected:
  LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
      : TaskInfo(stream_id, type), label_id_(label_id) {}
  ~LabelTaskInfo() override {}

  uint32_t label_id_;
};

class LabelSetTaskInfo : public LabelTaskInfo {
 public:
  LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
  ~LabelSetTaskInfo() override {}
};

class LabelSwitchTaskInfo : public LabelTaskInfo {
 public:
  LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
  ~LabelSwitchTaskInfo() override {}
};

class LabelGotoTaskInfo : public LabelTaskInfo {
 public:
  LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
  ~LabelGotoTaskInfo() override {}
};

// Common base for event tasks (record/wait); stores the event id.
class EventTaskInfo : public TaskInfo {
 public:
  uint32_t event_id() const { return event_id_; }

 protected:
  EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
      : TaskInfo(stream_id, type), event_id_(event_id) {}
  ~EventTaskInfo() override {}

  uint32_t event_id_;
};

class EventRecordTaskInfo : public EventTaskInfo {
 public:
  EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  ~EventRecordTaskInfo() override {}
};

class EventWaitTaskInfo : public EventTaskInfo {
 public:
  EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  ~EventWaitTaskInfo() override {}
};

// Markers for the start and end of a fused task sequence.
class FusionStartTaskInfo : public TaskInfo {
 public:
  explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
  ~FusionStartTaskInfo() override {}
};

class FusionEndTaskInfo : public TaskInfo {
 public:
  explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
  ~FusionEndTaskInfo() override {}
};
// Task descriptor for an HCCL collective-communication operation. Besides the
// data and workspace addresses it carries the callbacks used to bind/unbind
// the model and to distribute the task to HCCL.
class HcclTaskInfo : public TaskInfo {
 public:
  HcclTaskInfo(uint32_t stream_id, const std::string &hccl_type, void *input_data_addr, void *output_data_addr,
               void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
               const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
               int64_t op_type, int64_t data_type, const std::string &group,
               std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
               std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
      : TaskInfo(stream_id, TaskInfoType::HCCL),
        hccl_type_(hccl_type),
        input_data_addr_(input_data_addr),
        output_data_addr_(output_data_addr),
        workspace_addr_(workspace_addr),
        workspace_size_(workspace_size),
        hccl_stream_num_(hccl_stream_num),
        private_def_(private_def),
        ops_kernel_store_(ops_kernel_store),
        count_(count),
        root_id_(root_id),
        op_type_(op_type),
        data_type_(data_type),
        group_(group),
        hcom_bind_model_(hcom_bind_model),
        hcom_unbind_model_(hcom_unbind_model),
        hcom_distribute_task_(hcom_distribute_task) {}
  ~HcclTaskInfo() override {}

  const std::string &hccl_type() const { return hccl_type_; }
  void *input_data_addr() const { return input_data_addr_; }
  void *output_data_addr() const { return output_data_addr_; }
  void *workspace_addr() const { return workspace_addr_; }
  int64_t workspace_size() const { return workspace_size_; }
  int64_t hccl_stream_num() const { return hccl_stream_num_; }
  const std::vector<uint8_t> &private_def() const { return private_def_; }
  void *ops_kernel_store() const { return ops_kernel_store_; }
  int32_t count() const { return count_; }
  int64_t root_id() const { return root_id_; }
  int64_t op_type() const { return op_type_; }
  int64_t data_type() const { return data_type_; }
  const std::string &group() const { return group_; }
  std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
  std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
  std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
    return hcom_distribute_task_;
  }

 private:
  std::string hccl_type_;
  void *input_data_addr_;
  void *output_data_addr_;
  void *workspace_addr_;
  int64_t workspace_size_;
  int64_t hccl_stream_num_;
  std::vector<uint8_t> private_def_;
  void *ops_kernel_store_;
  int32_t count_;
  int64_t root_id_;
  int64_t op_type_;
  int64_t data_type_;
  std::string group_;
  std::function<bool(void *, void *)> hcom_bind_model_;
  std::function<bool(void *)> hcom_unbind_model_;
  std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
};
// Task descriptor for emitting a profiler trace point.
class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
  ~ProfilerTraceTaskInfo() override {}

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flat_;
};

// Task descriptor for an asynchronous memory copy from src to dst; kind
// encodes the copy direction.
class MemcpyAsyncTaskInfo : public TaskInfo {
 public:
  MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
      : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
        dst_(dst),
        dst_max_(dst_max),
        src_(src),
        count_(count),
        kind_(kind) {}
  ~MemcpyAsyncTaskInfo() override {}

  void *dst() const { return dst_; }
  uint64_t dst_max() const { return dst_max_; }
  void *src() const { return src_; }
  uint64_t count() const { return count_; }
  uint32_t kind() const { return kind_; }

 private:
  void *dst_;
  uint64_t dst_max_;
  void *src_;
  uint64_t count_;
  uint32_t kind_;
};
// Task descriptor for a conditional switch to another stream (true_stream_id),
// driven by comparing the values at input_addr and value_addr under cond.
class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
                       int64_t data_type)
      : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
        cond_(cond),
        data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override {}

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;
  void *input_addr_;
  void *value_addr_;
  int64_t cond_;
  int64_t data_type_;
};

// Task descriptor that activates another stream.
class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override {}

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;
};
}  // namespace model_runner
}  // namespace ge
#endif  // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
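
The header only declares plain data holders; nothing in it launches or schedules tasks. As a minimal usage sketch, the snippet below builds a small heterogeneous task list and inspects it through the common TaskInfo base. It assumes this header is reachable on the include path as "framework/ge_runtime/task_info.h" (adjust to your checkout) and that the cce headers it pulls in are available; the buffer addresses, event id, and kind value are purely illustrative.

// Sketch: build a few task descriptors and iterate over them polymorphically.
#include <iostream>
#include <memory>
#include <vector>

#include "framework/ge_runtime/task_info.h"  // path is an assumption, see above

int main() {
  using namespace ge::model_runner;

  std::vector<std::shared_ptr<TaskInfo>> tasks;

  // Record an event on stream 0.
  tasks.push_back(std::make_shared<EventRecordTaskInfo>(/*stream_id=*/0, /*event_id=*/7));

  // Asynchronously copy 256 bytes on stream 1 (kind 0 is illustrative only).
  uint8_t src[256] = {0};
  uint8_t dst[256] = {0};
  tasks.push_back(std::make_shared<MemcpyAsyncTaskInfo>(
      /*stream_id=*/1, dst, /*dst_max=*/sizeof(dst), src, /*count=*/sizeof(src), /*kind=*/0));

  // Activate stream 1 from stream 0.
  tasks.push_back(std::make_shared<StreamActiveTaskInfo>(/*stream_id=*/0, /*active_stream_id=*/1));

  // All descriptors expose their stream id and task type through the base class.
  for (const auto &task : tasks) {
    std::cout << "stream " << task->stream_id() << " type " << task->type() << '\n';
  }
  return 0;
}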

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, bridging the two. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.