You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

task_generator.h 7.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_BUILD_TASK_GENERATOR_H_
  17. #define GE_GRAPH_BUILD_TASK_GENERATOR_H_
  #include <cstdint>
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_set>
  #include <vector>
  #include "common/ge_inner_error_codes.h"
  #include "common/opskernel/ops_kernel_info_types.h"
  #include "framework/common/types.h"
  #include "graph/compute_graph.h"
  #include "graph/model.h"
  #include "proto/task.pb.h"
  #include "runtime/rt.h"
  29. namespace ge {
  30. class GELib;
  31. class OpsKernelManager;
  /// Plain data holder for the indices that locate profiling points.
  /// A default-constructed instance has both indices set to zero and an
  /// empty end_index set.
  struct ProfilingPoint {
    /// Presumably the index of the forward-propagation ("fp") profiling point -- TODO confirm.
    uint32_t fp_index{0};
    /// Presumably the index of the backward-propagation ("bp") profiling point -- TODO confirm.
    uint32_t bp_index{0};
    /// Set of end indices; being a std::set it holds unique values in sorted order.
    std::set<uint32_t> end_index{};
  };
  37. // Describes infos needed by generate task for fusion node
  38. struct FusionTaskInfo {
  39. RunContext &run_context;
  40. ComputeGraphPtr &graph;
  41. NodePtr &node;
  42. OpDescPtr &fusion_op_desc;
  43. uint32_t &node_index;
  44. std::shared_ptr<GELib> &ge_lib;
  45. const OpsKernelManager &ops_kernel_manager;
  46. std::vector<domi::TaskDef> &task_def_list;
  47. std::map<uint32_t, string> &op_name_map;
  48. ProfilingPoint &profiling_point;
  49. vector<uint32_t> all_reduce_nodes;
  50. uint64_t all_reduce_node_idx;
  51. };
  52. class TaskGenerator {
  53. public:
  54. TaskGenerator() = default;
  55. TaskGenerator(const TaskGenerator &) = delete;
  56. TaskGenerator &operator=(const TaskGenerator &) = delete;
  57. virtual ~TaskGenerator();
  58. TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size);
  59. ///
  60. /// get task info.
  61. /// @param model model
  62. /// @param graph compute graph
  63. /// @param buffer weights buffer
  64. /// @param session_id session id
  65. /// @return SUCCESS: success
  66. /// other:failed
  67. ///
  68. Status GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t session_id, RunContext &run_context);
  69. Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  70. std::vector<uint32_t> &all_reduce_nodes);
  71. private:
  72. Status UpdateAnchorStatusForFfts(const NodePtr &node);
  73. Status UpdateAnchorStatus(const NodePtr &node);
  74. Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);
  75. ///
  76. /// call engine to generate known shape task.
  77. /// @param run_context run context
  78. /// @param graph compute graph
  79. /// @param task_def_list task def list generate by engine
  80. /// @param op_name_map relation of task index and op
  81. /// @return SUCCESS:seccess
  82. /// Other: failed
  83. ///
  84. Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector<domi::TaskDef> &task_def_list,
  85. std::map<uint32_t, string> &op_name_map);
  86. ///
  87. /// AddModelTaskToModel
  88. /// @param model_task_def model task
  89. /// @param model_def model
  90. /// @return SUCCESS:seccess
  91. /// Other: failed
  92. ///
  93. Status AddModelTaskToModel(const domi::ModelTaskDef &model_task_def, uint64_t session_id, Model &model_def,
  94. RunContext &run_context);
  95. Status MarkNodeAndSetIndex(ComputeGraphPtr &graph);
  96. // Mark first and last op according to the same stream and engine
  97. Status MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_single_stream) const;
  98. // profiling interface
  99. Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
  100. Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  101. vector<uint32_t> &all_reduce_nodes) const;
  102. Status FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node, uint32_t &bp_index) const;
  103. Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
  104. ProfilingPoint &profiling_point) const;
  105. Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point,
  106. vector<uint32_t> &all_reduce_nodes) const;
  107. Status GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector<uint32_t> &all_reduce_nodes,
  108. std::string& fp_point_str, std::string& bp_point_str) const;
  109. Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  110. std::vector<uint32_t> &all_reduce_nodes) const;
  111. Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
  112. std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
  113. std::vector<domi::TaskDef> &task_def_list);
  114. Status InsertProfilingArTaskBefore(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
  115. uint32_t node_index, std::vector<domi::TaskDef> &task_def_listy,
  116. bool is_insert_bp_profiling_task);
  117. Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
  118. std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
  119. std::vector<domi::TaskDef> &task_def_list);
  120. Status InsertProfilingArTaskAfter(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
  121. uint32_t node_index, std::vector<domi::TaskDef> &task_def_list,
  122. bool is_insert_bp_profiling_task);
  123. static bool IsProfPoint(const OpDescPtr &op, const std::string &name);
  124. /// call engine to generate task for fusion node.
  125. /// @param FusionTaskInfo
  126. /// @param fusion_nodes: nodes in graph with groud_id attr which means fusion node
  127. /// @param fusion_nodes_seen: fusion node has been called generate task
  128. /// @return SUCCESS:seccess
  129. /// Other: failed
  130. ///
  131. Status GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info,
  132. std::map<int64_t, std::vector<NodePtr>> &fusion_nodes,
  133. std::unordered_set<Node *> &fusion_nodes_seen);
  134. Status SaveFusionNodes(map<int64_t, std::vector<NodePtr>> &fusion_nodes, ComputeGraphPtr &graph);
  135. Status SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  136. Status DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  137. Status SetKnownShapeStream(RunContext &run_context, int64_t stream_id);
  138. bool IsSubGraphOfDynamicGraph(const ComputeGraphPtr &graph) const;
  139. uint8_t *var_mem_base_ = nullptr;
  140. uint64_t var_mem_size_ = 0;
  141. };
  142. } // namespace ge
  143. #endif // GE_GRAPH_BUILD_TASK_GENERATOR_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。