You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_manager_utils.h 9.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
  17. #define GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
  18. #include <condition_variable>
  19. #include <map>
  20. #include <memory>
  21. #include <mutex>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <utility>
  25. #include <vector>
  26. #include "common/blocking_queue.h"
  27. #include "common/ge_types.h"
  28. #include "common/types.h"
  29. #include "common/util.h"
  30. #include "framework/common/debug/ge_log.h"
  31. #include "framework/common/ge_inner_error_codes.h"
  32. #include "graph/compute_graph.h"
  33. #include "graph/graph.h"
  34. #include "graph/model.h"
  35. #include "model/ge_model.h"
  36. #include "model/ge_root_model.h"
  37. #include "register/register_fmk_types.h"
  38. #include "external/ge/ge_api_types.h"
  39. namespace ge {
  // State of a graph task over its life cycle.
  enum GraphNodeState {
    GRAPH_NODE_INIT = 0,
    GRAPH_NODE_READY = 1,
  };
  45. using GraphId = uint32_t;
  46. using ConstModelPtr = std::shared_ptr<const ge::Model>;
  47. using GeModelPtr = std::shared_ptr<ge::GeModel>;
  48. using ConstGraphPtr = std::shared_ptr<const ge::Graph>;
  49. using GraphPtr = std::shared_ptr<ge::Graph>;
  // Sentinel session id: every bit of the 64-bit value is set (0xffffffffffffffff).
  constexpr uint64_t INVALID_SESSION_ID = ~static_cast<uint64_t>(0);
  51. struct ModelIdInfo {
  52. uint32_t model_id{INVALID_MODEL_ID};
  53. };
  54. class SubGraphInfo {
  55. public:
  56. SubGraphInfo();
  57. ~SubGraphInfo();
  58. void SetSubGraph(const ComputeGraphPtr &sub_graph_ptr) { subgraph_ptr_ = sub_graph_ptr; }
  59. ComputeGraphPtr GetSubGraph() const { return subgraph_ptr_; }
  60. void SetEngineName(const std::string &engine_name) { engine_name_ = engine_name; }
  61. const std::string &GetEngineName() const { return engine_name_; }
  62. void SetInputFlag(const std::vector<bool> &input_flag) { input_flag_ = input_flag; }
  63. const std::vector<bool> &GetInputFlag() const { return input_flag_; }
  64. void SetOutputFlag(const std::vector<bool> &output_flag) { output_flag_ = output_flag; }
  65. const std::vector<bool> &GetOutputFlag() const { return output_flag_; }
  66. void SetModelIdInfo(const ModelIdInfo &model_id_info) { model_id_info_ = model_id_info; }
  67. ModelIdInfo GetModelIdInfo() const { return model_id_info_; }
  68. void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; }
  69. bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; }
  70. Status FreeInOutBuffer();
  71. void SetOutputContext(const std::string &output) { output_names_ = output; }
  72. std::string GetOutputContext() const { return output_names_; }
  73. void SetStreamLabel(const std::string &stream_label) { stream_label_ = stream_label; }
  74. const std::string &GetStreamLabel() const { return stream_label_; }
  75. void SetEnd2PldMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &end_map) { end_to_pld_ = end_map; }
  76. const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetEnd2PldMap() const { return end_to_pld_; }
  77. void SetPld2EndMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &pld_map) { pld_to_end_ = pld_map; }
  78. const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetPld2EndMap() const { return pld_to_end_; }
  79. private:
  80. ComputeGraphPtr subgraph_ptr_;
  81. std::string engine_name_;
  82. std::vector<bool> input_flag_;
  83. std::vector<bool> output_flag_;
  84. ModelIdInfo model_id_info_;
  85. GeModelPtr ge_model_ptr_;
  86. bool malloc_flag_;
  87. std::vector<void *> buffer_addr_;
  88. std::string output_names_;
  89. std::vector<uint32_t> buffer_size_;
  90. std::string stream_label_;
  91. std::unordered_map<ge::NodePtr, ge::NodePtr> end_to_pld_;
  92. std::unordered_map<ge::NodePtr, ge::NodePtr> pld_to_end_;
  93. };
  94. using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>;
  95. using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>;
  96. using Graph2InputNodesSubGraphInfo = std::unordered_map<ComputeGraphPtr, SubGraphInfoPtr>;
  97. // for run graph async listener
  98. class RunAsyncListener : public ge::ModelListener {
  99. public:
  100. RunAsyncListener() : sem_(1) {}
  101. ~RunAsyncListener() = default;
  102. void SetCallback(const RunAsyncCallback &callback);
  103. // callback
  104. Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
  105. std::vector<ge::OutputTensorInfo> &outputs) override;
  106. private:
  107. RunAsyncCallback callback_;
  108. BlockingQueue<uint8_t> sem_;
  109. };
  110. // single graph node info
  111. class GraphNode {
  112. public:
  113. explicit GraphNode(GraphId graph_id);
  114. ~GraphNode();
  115. GraphId GetGraphId() const { return graph_id_; }
  116. ConstGraphPtr GetGraph() const { return graph_; }
  117. void SetGraph(const GraphPtr &graph) { graph_ = graph; }
  118. ComputeGraphPtr GetComputeGraph() const { return compute_graph_; }
  119. void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }
  120. bool GetRunFlag() const { return run_flag_; }
  121. void SetRunFlag(bool flag) { run_flag_ = flag; }
  122. bool IsAsync() const { return async_; }
  123. void SetAsync(bool flag) { async_ = flag; }
  124. void SetSubGraph(std::vector<SubGraphInfoPtr> &subgraph_ptr_list) { subgraph_ptr_list_ = subgraph_ptr_list; }
  125. const std::vector<SubGraphInfoPtr> &GetAllSubGraph() const { return subgraph_ptr_list_; }
  126. bool GetBuildFlag() const { return build_flag_; }
  127. void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; }
  128. bool GetLoadFlag() const { return load_flag_; }
  129. void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; }
  130. void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; }
  131. GeModelPtr GetGeModel() const { return ge_model_; }
  132. void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; }
  133. GeRootModelPtr GetGeRootModel() const { return ge_root_model_; }
  134. const std::map<std::string, std::string>& GetOptions() const { return options_; }
  135. void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; }
  136. void Lock();
  137. void Unlock();
  138. // run graph asynchronous listener
  139. std::shared_ptr<RunAsyncListener> graph_run_async_listener_;
  140. private:
  141. GraphId graph_id_;
  142. std::map<std::string, std::string> options_;
  143. bool run_flag_;
  144. std::vector<SubGraphInfoPtr> subgraph_ptr_list_;
  145. GraphPtr graph_;
  146. ComputeGraphPtr compute_graph_;
  147. bool build_flag_;
  148. bool load_flag_;
  149. bool async_;
  150. GeModelPtr ge_model_;
  151. GeRootModelPtr ge_root_model_;
  152. BlockingQueue<uint8_t> sem_;
  153. };
  154. using GraphNodePtr = std::shared_ptr<GraphNode>;
  155. using ConstGraphNodePtr = shared_ptr<const GraphNode>;
  156. class GraphModelListener : public ge::ModelListener {
  157. public:
  158. GraphModelListener(std::mutex &mutex, std::condition_variable &cond);
  159. ~GraphModelListener() = default;
  160. // callback
  161. Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
  162. std::vector<ge::OutputTensorInfo> &outputs) override;
  163. Status ResetResult();
  164. // need lock by caller
  165. uint32_t GetResultCode() const;
  166. bool IsFinished() const { return is_finished_; }
  167. private:
  168. uint32_t result_code_;
  169. bool is_finished_;
  170. // not owner
  171. std::mutex &mutex_;
  172. // not owner
  173. std::condition_variable &condition_;
  174. };
  175. struct GraphManagerOptions {
  176. int32_t stream_num;
  177. int32_t perf_level;
  178. int32_t encrypt_mode;
  179. int32_t framework_type;
  180. std::string ek_file;
  181. std::string cert_file;
  182. std::string hw_key_file;
  183. std::string private_key_file;
  184. std::string calibration_conf_file;
  185. std::string insert_op_file;
  186. std::string output_node_name;
  187. std::string func_bin_path;
  188. std::string input_nodes_set_fp16;
  189. std::string core_type;
  190. bool compress_flag;
  191. bool run_graph_flag;
  192. bool train_graph_flag;
  193. bool local_fmk_op_flag;
  194. bool hcom_parallel;
  195. bool enable_print_op_pass;
  196. bool is_single_op;
  197. std::map<std::string, int> stream_max_parallel_num;
  198. std::string output_datatype;
  199. std::string original_model_file;
  200. std::string save_original_model;
  201. std::string build_mode;
  202. std::string build_step;
  203. std::string input_shape;
  204. std::string dynamic_dims;
  205. int32_t dynamic_node_type = -1;
  206. GraphManagerOptions()
  207. : stream_num(1),
  208. perf_level(domi::GEN_TASK_WITHOUT_FUSION),
  209. encrypt_mode(-1),
  210. framework_type(domi::TENSORFLOW),
  211. ek_file(""),
  212. cert_file(""),
  213. hw_key_file(""),
  214. private_key_file(""),
  215. calibration_conf_file(""),
  216. insert_op_file(""),
  217. output_node_name(""),
  218. func_bin_path(""),
  219. core_type(""),
  220. compress_flag(false),
  221. run_graph_flag(false),
  222. train_graph_flag(false),
  223. local_fmk_op_flag(false),
  224. hcom_parallel(false),
  225. enable_print_op_pass(true),
  226. is_single_op(false),
  227. save_original_model("false"),
  228. build_mode(""),
  229. build_step("") {}
  230. };
  231. } // namespace ge
  232. #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。