You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <functional>
  20. #include <memory>
  21. #include <string>
  22. #include <vector>
  23. #include "cce/taskdown_api.h"
  24. namespace ge {
  25. namespace model_runner {
  // Tag identifying which concrete TaskInfo subclass a task description is.
  // Values are written out explicitly; they match the implicit numbering that
  // starts at CCE = 0 and increases by one per enumerator.
  enum TaskInfoType {
    CCE = 0,
    TBE = 1,
    AICPU = 2,
    LABEL_SET = 3,
    LABEL_SWITCH = 4,
    LABEL_GOTO = 5,
    EVENT_RECORD = 6,
    EVENT_WAIT = 7,
    FUSION_START = 8,
    FUSION_END = 9,
    HCCL = 10,
    PROFILER_TRACE = 11,
    MEMCPY_ASYNC = 12,
    STREAM_SWITCH = 13,
    STREAM_ACTIVE = 14,
    // Insert new task type here
    REVSERVED = 23  // identifier spelling (sic) is part of the existing interface
  };
  45. class TaskInfo {
  46. public:
  47. virtual ~TaskInfo() {}
  48. uint32_t stream_id() const { return stream_id_; }
  49. TaskInfoType type() const { return type_; }
  50. protected:
  51. TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
  52. private:
  53. uint32_t stream_id_;
  54. TaskInfoType type_;
  55. };
  56. class CceTaskInfo : public TaskInfo {
  57. public:
  58. CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
  59. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
  60. const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
  61. : TaskInfo(stream_id, TaskInfoType::CCE),
  62. ctx_(ctx),
  63. stub_func_(stub_func),
  64. block_dim_(block_dim),
  65. args_(args),
  66. args_size_(args_size),
  67. sm_desc_(sm_desc),
  68. flow_table_(flow_table),
  69. args_offset_(args_offset),
  70. is_flowtable_(is_flowtable) {}
  71. ~CceTaskInfo() override {}
  72. cce::ccOpContext cc_context() const { return ctx_; }
  73. std::string stub_func() const { return stub_func_; }
  74. uint32_t block_dim() const { return block_dim_; }
  75. const std::vector<uint8_t> &args() const { return args_; }
  76. uint32_t args_size() const { return args_size_; }
  77. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  78. const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  79. const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  80. bool is_flowtable() const { return is_flowtable_; }
  81. private:
  82. cce::ccOpContext ctx_;
  83. std::string stub_func_;
  84. uint32_t block_dim_;
  85. std::vector<uint8_t> args_;
  86. uint32_t args_size_;
  87. std::vector<uint8_t> sm_desc_;
  88. std::vector<uint8_t> flow_table_;
  89. std::vector<uint8_t> args_offset_;
  90. bool is_flowtable_;
  91. };
  92. class TbeTaskInfo : public TaskInfo {
  93. public:
  94. TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
  95. uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
  96. const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  97. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
  98. : TaskInfo(stream_id, TaskInfoType::TBE),
  99. stub_func_(stub_func),
  100. block_dim_(block_dim),
  101. args_(args),
  102. args_size_(args_size),
  103. sm_desc_(sm_desc),
  104. binary_(binary),
  105. binary_size_(binary_size),
  106. meta_data_(meta_data),
  107. input_data_addrs_(input_data_addrs),
  108. output_data_addrs_(output_data_addrs),
  109. workspace_addrs_(workspace_addrs) {}
  110. ~TbeTaskInfo() override {}
  111. const std::string &stub_func() const { return stub_func_; }
  112. uint32_t block_dim() const { return block_dim_; }
  113. const std::vector<uint8_t> &args() const { return args_; }
  114. uint32_t args_size() const { return args_size_; }
  115. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  116. void *binary() const { return binary_; }
  117. uint32_t binary_size() const { return binary_size_; }
  118. const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  119. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  120. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  121. const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
  122. void SetBinary(void *binary, uint32_t binary_size) {
  123. binary_ = binary;
  124. binary_size_ = binary_size;
  125. }
  126. private:
  127. std::string stub_func_;
  128. uint32_t block_dim_;
  129. std::vector<uint8_t> args_;
  130. uint32_t args_size_;
  131. std::vector<uint8_t> sm_desc_;
  132. void *binary_;
  133. uint32_t binary_size_;
  134. std::vector<uint8_t> meta_data_;
  135. std::vector<void *> input_data_addrs_;
  136. std::vector<void *> output_data_addrs_;
  137. std::vector<void *> workspace_addrs_;
  138. };
  139. class AicpuTaskInfo : public TaskInfo {
  140. public:
  141. AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
  142. const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
  143. : TaskInfo(stream_id, TaskInfoType::AICPU),
  144. so_name_(so_name),
  145. kernel_name_(kernel_name),
  146. node_def_(node_def),
  147. input_data_addrs_(input_data_addrs),
  148. output_data_addrs_(output_data_addrs) {}
  149. ~AicpuTaskInfo() override {}
  150. const std::string &so_name() const { return so_name_; }
  151. const std::string &kernel_name() const { return kernel_name_; }
  152. const std::string &node_def() const { return node_def_; }
  153. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  154. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  155. private:
  156. std::string so_name_;
  157. std::string kernel_name_;
  158. std::string node_def_;
  159. std::vector<void *> input_data_addrs_;
  160. std::vector<void *> output_data_addrs_;
  161. };
  162. class LabelSetTaskInfo : public TaskInfo {
  163. public:
  164. LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
  165. : TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {}
  166. ~LabelSetTaskInfo() override {}
  167. uint32_t label_id() const { return label_id_; }
  168. private:
  169. uint32_t label_id_;
  170. };
  171. class LabelGotoTaskInfo : public TaskInfo {
  172. public:
  173. LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
  174. : TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {}
  175. ~LabelGotoTaskInfo() override {}
  176. uint32_t label_id() const { return label_id_; }
  177. private:
  178. uint32_t label_id_;
  179. };
  180. class LabelSwitchTaskInfo : public TaskInfo {
  181. public:
  182. LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond)
  183. : TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH),
  184. label_size_(label_size),
  185. label_list_(label_list),
  186. cond_(cond) {}
  187. ~LabelSwitchTaskInfo() override {}
  188. uint32_t label_size() { return label_size_; };
  189. const std::vector<uint32_t> &label_list() { return label_list_; };
  190. void *cond() { return cond_; };
  191. private:
  192. uint32_t label_size_;
  193. std::vector<uint32_t> label_list_;
  194. void *cond_;
  195. };
  196. class EventTaskInfo : public TaskInfo {
  197. public:
  198. uint32_t event_id() const { return event_id_; }
  199. protected:
  200. EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  201. : TaskInfo(stream_id, type), event_id_(event_id) {}
  202. virtual ~EventTaskInfo() override {}
  203. uint32_t event_id_;
  204. };
  205. class EventRecordTaskInfo : public EventTaskInfo {
  206. public:
  207. EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
  208. : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  209. ~EventRecordTaskInfo() override {}
  210. };
  211. class EventWaitTaskInfo : public EventTaskInfo {
  212. public:
  213. EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
  214. : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  215. ~EventWaitTaskInfo() override {}
  216. };
  217. class FusionStartTaskInfo : public TaskInfo {
  218. public:
  219. explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
  220. ~FusionStartTaskInfo() override {}
  221. };
  222. class FusionEndTaskInfo : public TaskInfo {
  223. public:
  224. explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
  225. ~FusionEndTaskInfo() override {}
  226. };
  227. class HcclTaskInfo : public TaskInfo {
  228. public:
  229. HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
  230. void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
  231. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  232. int64_t op_type, int64_t data_type, const std::string &group,
  233. std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
  234. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
  235. : TaskInfo(stream_id, TaskInfoType::HCCL),
  236. hccl_type_(hccl_type),
  237. input_data_addr_(input_data_addr),
  238. output_data_addr_(output_data_addr),
  239. workspace_addr_(workspace_addr),
  240. workspace_size_(workspace_size),
  241. hccl_stream_num_(hccl_stream_num),
  242. private_def_(private_def),
  243. ops_kernel_store_(ops_kernel_store),
  244. count_(count),
  245. root_id_(root_id),
  246. op_type_(op_type),
  247. data_type_(data_type),
  248. group_(group),
  249. hcom_bind_model_(hcom_bind_model),
  250. hcom_unbind_model_(hcom_unbind_model),
  251. hcom_distribute_task_(hcom_distribute_task) {}
  252. ~HcclTaskInfo() override {}
  253. const std::string &hccl_type() const { return hccl_type_; }
  254. void *input_data_addr() const { return input_data_addr_; }
  255. void *output_data_addr() const { return output_data_addr_; }
  256. void *workspace_addr() const { return workspace_addr_; }
  257. int64_t workspace_size() const { return workspace_size_; }
  258. int64_t hccl_stream_num() const { return hccl_stream_num_; }
  259. const std::vector<uint8_t> &private_def() const { return private_def_; }
  260. void *ops_kernel_store() const { return ops_kernel_store_; }
  261. int32_t count() const { return count_; }
  262. int64_t root_id() const { return root_id_; }
  263. int64_t op_type() const { return op_type_; }
  264. int64_t data_type() const { return data_type_; }
  265. const std::string &group() const { return group_; }
  266. std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
  267. std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
  268. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
  269. return hcom_distribute_task_;
  270. }
  271. private:
  272. std::string hccl_type_;
  273. void *input_data_addr_;
  274. void *output_data_addr_;
  275. void *workspace_addr_;
  276. int64_t workspace_size_;
  277. int64_t hccl_stream_num_;
  278. std::vector<uint8_t> private_def_;
  279. void *ops_kernel_store_;
  280. int32_t count_;
  281. int64_t root_id_;
  282. int64_t op_type_;
  283. int64_t data_type_;
  284. std::string group_;
  285. std::function<bool(void *, void *)> hcom_bind_model_;
  286. std::function<bool(void *)> hcom_unbind_model_;
  287. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
  288. };
  289. class ProfilerTraceTaskInfo : public TaskInfo {
  290. public:
  291. ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  292. : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
  293. ~ProfilerTraceTaskInfo() override {}
  294. uint64_t log_id() const { return log_id_; }
  295. bool notify() const { return notify_; }
  296. uint32_t flat() const { return flat_; }
  297. private:
  298. uint64_t log_id_;
  299. bool notify_;
  300. uint32_t flat_;
  301. };
  302. class MemcpyAsyncTaskInfo : public TaskInfo {
  303. public:
  304. MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
  305. : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
  306. dst_(dst),
  307. dst_max_(dst_max),
  308. src_(src),
  309. count_(count),
  310. kind_(kind) {}
  311. ~MemcpyAsyncTaskInfo() override {}
  312. void *dst() const { return dst_; }
  313. uint64_t dst_max() const { return dst_max_; }
  314. void *src() const { return src_; }
  315. uint64_t count() const { return count_; }
  316. uint32_t kind() const { return kind_; }
  317. private:
  318. void *dst_;
  319. uint64_t dst_max_;
  320. void *src_;
  321. uint64_t count_;
  322. int32_t kind_;
  323. };
  324. class StreamSwitchTaskInfo : public TaskInfo {
  325. public:
  326. StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
  327. int64_t data_type)
  328. : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
  329. true_stream_id_(true_stream_id),
  330. input_addr_(input_addr),
  331. value_addr_(value_addr),
  332. cond_(cond),
  333. data_type_(data_type) {}
  334. ~StreamSwitchTaskInfo() override {}
  335. int64_t true_stream_id() const { return true_stream_id_; }
  336. void *input_addr() const { return input_addr_; }
  337. void *value_addr() const { return value_addr_; }
  338. int64_t cond() const { return cond_; }
  339. int64_t data_type() const { return data_type_; }
  340. private:
  341. int64_t true_stream_id_;
  342. void *input_addr_;
  343. void *value_addr_;
  344. int64_t cond_;
  345. int64_t data_type_;
  346. };
  347. class StreamActiveTaskInfo : public TaskInfo {
  348. public:
  349. StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
  350. : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
  351. ~StreamActiveTaskInfo() override {}
  352. uint32_t active_stream_id() const { return active_stream_id_; }
  353. private:
  354. uint32_t active_stream_id_;
  355. };
  356. } // namespace model_runner
  357. } // namespace ge
  358. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。