You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

task_info.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  17. #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_
  18. #include <stdint.h>
  19. #include <functional>
  20. #include <memory>
  21. #include <string>
  22. #include <utility>
  23. #include <vector>
  24. #include "cce/taskdown_api.h"
  25. namespace ge {
  26. namespace model_runner {
  // Tag identifying the concrete kind of a TaskInfo.
  // Enumerators are numbered consecutively from 0; the last entry is pinned to 23.
  enum TaskInfoType {
    CCE = 0,
    TBE = 1,
    AICPU = 2,
    LABEL_SET = 3,
    LABEL_SWITCH = 4,
    LABEL_GOTO = 5,
    EVENT_RECORD = 6,
    EVENT_WAIT = 7,
    FUSION_START = 8,
    FUSION_END = 9,
    HCCL = 10,
    PROFILER_TRACE = 11,
    MEMCPY_ASYNC = 12,
    STREAM_SWITCH = 13,
    STREAM_ACTIVE = 14,
    // Insert new task type here
    REVSERVED = 23  // NOTE(review): spelling is part of the public name; kept unchanged.
  };
  46. class TaskInfo {
  47. public:
  48. virtual ~TaskInfo() {}
  49. uint32_t stream_id() const { return stream_id_; }
  50. TaskInfoType type() const { return type_; }
  51. std::string op_name() const { return op_name_; }
  52. bool dump_flag() const { return dump_flag_; }
  53. protected:
  54. TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
  55. : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
  56. private:
  57. std::string op_name_;
  58. uint32_t stream_id_;
  59. TaskInfoType type_;
  60. bool dump_flag_;
  61. };
  62. class CceTaskInfo : public TaskInfo {
  63. public:
  64. CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
  65. uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
  66. const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
  67. const std::vector<uint8_t> &args_offset, bool is_flowtable)
  68. : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
  69. ctx_(ctx),
  70. stub_func_(stub_func),
  71. block_dim_(block_dim),
  72. args_(args),
  73. args_size_(args_size),
  74. sm_desc_(sm_desc),
  75. flow_table_(flow_table),
  76. args_offset_(args_offset),
  77. is_flowtable_(is_flowtable) {}
  78. ~CceTaskInfo() override {}
  79. cce::ccOpContext cc_context() const { return ctx_; }
  80. std::string stub_func() const { return stub_func_; }
  81. uint32_t block_dim() const { return block_dim_; }
  82. const std::vector<uint8_t> &args() const { return args_; }
  83. uint32_t args_size() const { return args_size_; }
  84. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  85. const std::vector<uint8_t> &flow_table() const { return flow_table_; }
  86. const std::vector<uint8_t> &args_offset() const { return args_offset_; }
  87. bool is_flowtable() const { return is_flowtable_; }
  88. private:
  89. cce::ccOpContext ctx_;
  90. std::string stub_func_;
  91. uint32_t block_dim_;
  92. std::vector<uint8_t> args_;
  93. uint32_t args_size_;
  94. std::vector<uint8_t> sm_desc_;
  95. std::vector<uint8_t> flow_table_;
  96. std::vector<uint8_t> args_offset_;
  97. bool is_flowtable_;
  98. };
  99. class TbeTaskInfo : public TaskInfo {
  100. public:
  101. TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
  102. const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
  103. uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
  104. const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
  105. : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
  106. stub_func_(stub_func),
  107. block_dim_(block_dim),
  108. args_(args),
  109. args_size_(args_size),
  110. sm_desc_(sm_desc),
  111. binary_(binary),
  112. binary_size_(binary_size),
  113. meta_data_(meta_data),
  114. input_data_addrs_(input_data_addrs),
  115. output_data_addrs_(output_data_addrs),
  116. workspace_addrs_(workspace_addrs) {}
  117. ~TbeTaskInfo() override {}
  118. const std::string &stub_func() const { return stub_func_; }
  119. uint32_t block_dim() const { return block_dim_; }
  120. const std::vector<uint8_t> &args() const { return args_; }
  121. uint32_t args_size() const { return args_size_; }
  122. const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  123. void *binary() const { return binary_; }
  124. uint32_t binary_size() const { return binary_size_; }
  125. const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  126. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  127. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  128. const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }
  129. void SetBinary(void *binary, uint32_t binary_size) {
  130. binary_ = binary;
  131. binary_size_ = binary_size;
  132. }
  133. private:
  134. std::string stub_func_;
  135. uint32_t block_dim_;
  136. std::vector<uint8_t> args_;
  137. uint32_t args_size_;
  138. std::vector<uint8_t> sm_desc_;
  139. void *binary_;
  140. uint32_t binary_size_;
  141. std::vector<uint8_t> meta_data_;
  142. std::vector<void *> input_data_addrs_;
  143. std::vector<void *> output_data_addrs_;
  144. std::vector<void *> workspace_addrs_;
  145. };
  146. class AicpuTaskInfo : public TaskInfo {
  147. public:
  148. AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
  149. const std::string &node_def, const std::vector<void *> &input_data_addrs,
  150. const std::vector<void *> &output_data_addrs, bool dump_flag)
  151. : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
  152. so_name_(so_name),
  153. kernel_name_(kernel_name),
  154. node_def_(node_def),
  155. input_data_addrs_(input_data_addrs),
  156. output_data_addrs_(output_data_addrs) {}
  157. ~AicpuTaskInfo() override {}
  158. const std::string &so_name() const { return so_name_; }
  159. const std::string &kernel_name() const { return kernel_name_; }
  160. const std::string &node_def() const { return node_def_; }
  161. const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  162. const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  163. private:
  164. std::string so_name_;
  165. std::string kernel_name_;
  166. std::string node_def_;
  167. std::vector<void *> input_data_addrs_;
  168. std::vector<void *> output_data_addrs_;
  169. };
  170. class LabelSetTaskInfo : public TaskInfo {
  171. public:
  172. LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  173. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
  174. ~LabelSetTaskInfo() override {}
  175. uint32_t label_id() const { return label_id_; }
  176. private:
  177. uint32_t label_id_;
  178. };
  179. class LabelGotoTaskInfo : public TaskInfo {
  180. public:
  181. LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
  182. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
  183. ~LabelGotoTaskInfo() override {}
  184. uint32_t label_id() const { return label_id_; }
  185. private:
  186. uint32_t label_id_;
  187. };
  188. class LabelSwitchTaskInfo : public TaskInfo {
  189. public:
  190. LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
  191. const std::vector<uint32_t> &label_list, void *cond)
  192. : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
  193. label_size_(label_size),
  194. label_list_(label_list),
  195. cond_(cond) {}
  196. ~LabelSwitchTaskInfo() override {}
  197. uint32_t label_size() { return label_size_; };
  198. const std::vector<uint32_t> &label_list() { return label_list_; };
  199. void *cond() { return cond_; };
  200. private:
  201. uint32_t label_size_;
  202. std::vector<uint32_t> label_list_;
  203. void *cond_;
  204. };
  205. class EventTaskInfo : public TaskInfo {
  206. public:
  207. uint32_t event_id() const { return event_id_; }
  208. protected:
  209. EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
  210. : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
  211. virtual ~EventTaskInfo() override {}
  212. uint32_t event_id_;
  213. };
  214. class EventRecordTaskInfo : public EventTaskInfo {
  215. public:
  216. EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  217. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  218. ~EventRecordTaskInfo() override {}
  219. };
  220. class EventWaitTaskInfo : public EventTaskInfo {
  221. public:
  222. EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
  223. : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  224. ~EventWaitTaskInfo() override {}
  225. };
  226. class FusionStartTaskInfo : public TaskInfo {
  227. public:
  228. explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
  229. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
  230. ~FusionStartTaskInfo() override {}
  231. };
  232. class FusionEndTaskInfo : public TaskInfo {
  233. public:
  234. explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
  235. : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
  236. ~FusionEndTaskInfo() override {}
  237. };
  238. class HcclTaskInfo : public TaskInfo {
  239. public:
  240. HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
  241. void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
  242. const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
  243. int64_t op_type, int64_t data_type, const std::string &group,
  244. std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
  245. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
  246. : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
  247. hccl_type_(hccl_type),
  248. input_data_addr_(input_data_addr),
  249. output_data_addr_(output_data_addr),
  250. workspace_addr_(workspace_addr),
  251. workspace_size_(workspace_size),
  252. hccl_stream_num_(hccl_stream_num),
  253. private_def_(private_def),
  254. ops_kernel_store_(ops_kernel_store),
  255. count_(count),
  256. root_id_(root_id),
  257. op_type_(op_type),
  258. data_type_(data_type),
  259. group_(group),
  260. hcom_bind_model_(hcom_bind_model),
  261. hcom_unbind_model_(hcom_unbind_model),
  262. hcom_distribute_task_(hcom_distribute_task) {}
  263. ~HcclTaskInfo() override {}
  264. const std::string &hccl_type() const { return hccl_type_; }
  265. void *input_data_addr() const { return input_data_addr_; }
  266. void *output_data_addr() const { return output_data_addr_; }
  267. void *workspace_addr() const { return workspace_addr_; }
  268. int64_t workspace_size() const { return workspace_size_; }
  269. int64_t hccl_stream_num() const { return hccl_stream_num_; }
  270. const std::vector<uint8_t> &private_def() const { return private_def_; }
  271. void *ops_kernel_store() const { return ops_kernel_store_; }
  272. int32_t count() const { return count_; }
  273. int64_t root_id() const { return root_id_; }
  274. int64_t op_type() const { return op_type_; }
  275. int64_t data_type() const { return data_type_; }
  276. const std::string &group() const { return group_; }
  277. std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
  278. std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
  279. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
  280. return hcom_distribute_task_;
  281. }
  282. private:
  283. std::string hccl_type_;
  284. void *input_data_addr_;
  285. void *output_data_addr_;
  286. void *workspace_addr_;
  287. int64_t workspace_size_;
  288. int64_t hccl_stream_num_;
  289. std::vector<uint8_t> private_def_;
  290. void *ops_kernel_store_;
  291. int32_t count_;
  292. int64_t root_id_;
  293. int64_t op_type_;
  294. int64_t data_type_;
  295. std::string group_;
  296. std::function<bool(void *, void *)> hcom_bind_model_;
  297. std::function<bool(void *)> hcom_unbind_model_;
  298. std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
  299. };
  300. class ProfilerTraceTaskInfo : public TaskInfo {
  301. public:
  302. ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
  303. : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
  304. log_id_(log_id),
  305. notify_(notify),
  306. flat_(flat) {}
  307. ~ProfilerTraceTaskInfo() override {}
  308. uint64_t log_id() const { return log_id_; }
  309. bool notify() const { return notify_; }
  310. uint32_t flat() const { return flat_; }
  311. private:
  312. uint64_t log_id_;
  313. bool notify_;
  314. uint32_t flat_;
  315. };
  316. class MemcpyAsyncTaskInfo : public TaskInfo {
  317. public:
  318. MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
  319. uint64_t count, uint32_t kind, bool dump_flag)
  320. : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
  321. dst_(dst),
  322. dst_max_(dst_max),
  323. src_(src),
  324. count_(count),
  325. kind_(kind) {}
  326. ~MemcpyAsyncTaskInfo() override {}
  327. void *dst() const { return dst_; }
  328. uint64_t dst_max() const { return dst_max_; }
  329. void *src() const { return src_; }
  330. uint64_t count() const { return count_; }
  331. uint32_t kind() const { return kind_; }
  332. private:
  333. void *dst_;
  334. uint64_t dst_max_;
  335. void *src_;
  336. uint64_t count_;
  337. int32_t kind_;
  338. };
  339. class StreamSwitchTaskInfo : public TaskInfo {
  340. public:
  341. StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
  342. void *value_addr, int64_t cond, int64_t data_type)
  343. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
  344. true_stream_id_(true_stream_id),
  345. input_addr_(input_addr),
  346. value_addr_(value_addr),
  347. cond_(cond),
  348. data_type_(data_type) {}
  349. ~StreamSwitchTaskInfo() override {}
  350. int64_t true_stream_id() const { return true_stream_id_; }
  351. void *input_addr() const { return input_addr_; }
  352. void *value_addr() const { return value_addr_; }
  353. int64_t cond() const { return cond_; }
  354. int64_t data_type() const { return data_type_; }
  355. private:
  356. int64_t true_stream_id_;
  357. void *input_addr_;
  358. void *value_addr_;
  359. int64_t cond_;
  360. int64_t data_type_;
  361. };
  362. class StreamActiveTaskInfo : public TaskInfo {
  363. public:
  364. StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
  365. : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
  366. ~StreamActiveTaskInfo() override {}
  367. uint32_t active_stream_id() const { return active_stream_id_; }
  368. private:
  369. uint32_t active_stream_id_;
  370. };
  371. } // namespace model_runner
  372. } // namespace ge
  373. #endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示