You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

task_generator.h 6.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_BUILD_TASK_GENERATOR_H_
  17. #define GE_GRAPH_BUILD_TASK_GENERATOR_H_
  #include <cstdint>
  #include <map>
  #include <memory>
  #include <set>
  #include <string>
  #include <unordered_set>
  #include <vector>
  #include "common/ge_inner_error_codes.h"
  #include "common/opskernel/ops_kernel_info_types.h"
  #include "framework/common/types.h"
  #include "graph/compute_graph.h"
  #include "graph/model.h"
  #include "proto/task.pb.h"
  #include "runtime/rt.h"
  29. namespace ge {
  30. class GELib;
  31. class OpsKernelManager;
  /// Index bookkeeping for profiling points located in a graph.
  /// NOTE(review): "fp"/"bp" presumably stand for forward/backward propagation,
  /// and the indices presumably refer to node (op) indices - confirm against
  /// the implementation file.
  struct ProfilingPoint {
    /// Index of the "fp" profiling point; defaults to 0.
    uint32_t fp_index{0};
    /// Index of the "bp" profiling point; defaults to 0.
    uint32_t bp_index{0};
    /// Unique, ordered collection of "end" indices; empty by default.
    std::set<uint32_t> end_index{};
  };
  37. // Describes infos needed by generate task for fusion node
  38. struct FusionTaskInfo {
  39. RunContext &run_context;
  40. ComputeGraphPtr &graph;
  41. NodePtr &node;
  42. OpDescPtr &fusion_op_desc;
  43. uint32_t &node_index;
  44. std::shared_ptr<GELib> &ge_lib;
  45. const OpsKernelManager &ops_kernel_manager;
  46. std::vector<domi::TaskDef> &task_def_list;
  47. std::map<uint32_t, string> &op_name_map;
  48. ProfilingPoint &profiling_point;
  49. vector<uint32_t> all_reduce_nodes;
  50. uint64_t all_reduce_node_idx;
  51. };
  52. class TaskGenerator {
  53. public:
  54. TaskGenerator() = default;
  55. TaskGenerator(const TaskGenerator &) = delete;
  56. TaskGenerator &operator=(const TaskGenerator &) = delete;
  57. virtual ~TaskGenerator();
  58. TaskGenerator(uint8_t *var_mem_base, uint64_t var_mem_size);
  59. ///
  60. /// get task info.
  61. /// @param model model
  62. /// @param graph compute graph
  63. /// @param buffer weights buffer
  64. /// @param session_id session id
  65. /// @return SUCCESS: success
  66. /// other:failed
  67. ///
  68. Status GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t session_id, RunContext &run_context);
  69. Status FindProfilingNodeIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  70. std::vector<uint32_t> &all_reduce_nodes);
  71. private:
  72. Status UpdateAnchorStatus(const NodePtr &node);
  73. Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id);
  74. ///
  75. /// call engine to generate known shape task.
  76. /// @param run_context run context
  77. /// @param graph compute graph
  78. /// @param task_def_list task def list generate by engine
  79. /// @param op_name_map relation of task index and op
  80. /// @return SUCCESS:seccess
  81. /// Other: failed
  82. ///
  83. Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector<domi::TaskDef> &task_def_list,
  84. std::map<uint32_t, string> &op_name_map);
  85. ///
  86. /// AddModelTaskToModel
  87. /// @param model_task_def model task
  88. /// @param model_def model
  89. /// @return SUCCESS:seccess
  90. /// Other: failed
  91. ///
  92. Status AddModelTaskToModel(const domi::ModelTaskDef &model_task_def, uint64_t session_id, Model &model_def,
  93. RunContext &run_context);
  94. Status MarkNodeAndSetIndex(ComputeGraphPtr &graph);
  95. // Mark first and last op according to the same stream and engine
  96. Status MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_single_stream) const;
  97. // profiling interface
  98. Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const;
  99. Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  100. vector<uint32_t> &all_reduce_nodes) const;
  101. uint32_t FindLastBpFromBpNode(const ComputeGraphPtr &graph, const NodePtr &bp_node) const;
  102. Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str,
  103. ProfilingPoint &profiling_point) const;
  104. Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point,
  105. vector<uint32_t> &all_reduce_nodes) const;
  106. Status GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector<uint32_t> &all_reduce_nodes,
  107. std::string& fp_point_str, std::string& bp_point_str) const;
  108. Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point,
  109. std::vector<uint32_t> &all_reduce_nodes) const;
  110. Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
  111. std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
  112. std::vector<domi::TaskDef> &task_def_list);
  113. Status InsertProfilingArTaskBefore(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
  114. uint32_t node_index, std::vector<domi::TaskDef> &task_def_listy,
  115. bool is_insert_bp_profiling_task);
  116. Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point,
  117. std::vector<uint32_t> &all_reduce_nodes, uint32_t node_index,
  118. std::vector<domi::TaskDef> &task_def_list);
  119. Status InsertProfilingArTaskAfter(const OpDescPtr &op_desc, std::vector<uint32_t> &all_reduce_nodes,
  120. uint32_t node_index, std::vector<domi::TaskDef> &task_def_list,
  121. bool is_insert_bp_profiling_task);
  122. static bool IsProfPoint(const OpDescPtr &op, const std::string &name);
  123. /// call engine to generate task for fusion node.
  124. /// @param FusionTaskInfo
  125. /// @param fusion_nodes: nodes in graph with groud_id attr which means fusion node
  126. /// @param fusion_nodes_seen: fusion node has been called generate task
  127. /// @return SUCCESS:seccess
  128. /// Other: failed
  129. ///
  130. Status GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info,
  131. std::map<int64_t, std::vector<NodePtr>> &fusion_nodes,
  132. std::unordered_set<Node *> &fusion_nodes_seen);
  133. Status SaveFusionNodes(map<int64_t, std::vector<NodePtr>> &fusion_nodes, ComputeGraphPtr &graph);
  134. Status SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  135. Status DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream);
  136. Status SetKnownShapeStream(RunContext &run_context, int64_t stream_id);
  137. bool IsSubGraphOfDynamicGraph(const ComputeGraphPtr &graph) const;
  138. uint8_t *var_mem_base_ = nullptr;
  139. uint64_t var_mem_size_ = 0;
  140. };
  141. } // namespace ge
  142. #endif // GE_GRAPH_BUILD_TASK_GENERATOR_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。