You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_manager_utils.h 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
  17. #define GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
  18. #include <condition_variable>
  19. #include <map>
  20. #include <memory>
  21. #include <mutex>
  22. #include <string>
  23. #include <unordered_map>
  24. #include <utility>
  25. #include <vector>
  26. #include "common/blocking_queue.h"
  27. #include "framework/common/ge_types.h"
  28. #include "framework/common/types.h"
  29. #include "framework/common/util.h"
  30. #include "framework/common/debug/ge_log.h"
  31. #include "framework/common/ge_inner_error_codes.h"
  32. #include "graph/compute_graph.h"
  33. #include "common/local_context.h"
  34. #include "external/graph/graph.h"
  35. #include "graph/model.h"
  36. #include "common/model/ge_model.h"
  37. #include "common/model/ge_root_model.h"
  38. #include "external/register/register_fmk_types.h"
  39. #include "external/ge/ge_api_types.h"
  40. namespace ge {
// State of a graph task in its life cycle.
enum GraphNodeState {
  GRAPH_NODE_INIT = 0,
  GRAPH_NODE_READY,
};

// Convenience aliases used throughout the graph manager.
using GraphId = uint32_t;
using ConstModelPtr = std::shared_ptr<const ge::Model>;
using GeModelPtr = std::shared_ptr<ge::GeModel>;
using ConstGraphPtr = std::shared_ptr<const ge::Graph>;
using GraphPtr = std::shared_ptr<ge::Graph>;

// All-ones sentinel marking an invalid/unset session id.
const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL;
// Upper bound on repeated loads of one graph id; compared against
// load_record_ in GraphNode::UpdateLoadFlag() below.
const uint32_t kMaxLoadNum = 8;

// Wrapper for a loaded model's id; defaults to INVALID_MODEL_ID until assigned.
struct ModelIdInfo {
  uint32_t model_id{INVALID_MODEL_ID};
};
// Holds one partitioned subgraph together with the engine it is assigned to,
// the model compiled for it, and its I/O bookkeeping.
class SubGraphInfo {
 public:
  SubGraphInfo();
  ~SubGraphInfo();

  void SetSubGraph(const ComputeGraphPtr &sub_graph_ptr) { subgraph_ptr_ = sub_graph_ptr; }
  ComputeGraphPtr GetSubGraph() const { return subgraph_ptr_; }

  // Name of the engine this subgraph is assigned to.
  void SetEngineName(const std::string &engine_name) { engine_name_ = engine_name; }
  const std::string &GetEngineName() const { return engine_name_; }

  // Per-input / per-output boolean flags; semantics are defined by the callers
  // that populate them (not visible in this header).
  void SetInputFlag(const std::vector<bool> &input_flag) { input_flag_ = input_flag; }
  const std::vector<bool> &GetInputFlag() const { return input_flag_; }

  void SetOutputFlag(const std::vector<bool> &output_flag) { output_flag_ = output_flag; }
  const std::vector<bool> &GetOutputFlag() const { return output_flag_; }

  void SetModelIdInfo(const ModelIdInfo &model_id_info) { model_id_info_ = model_id_info; }
  ModelIdInfo GetModelIdInfo() const { return model_id_info_; }

  void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; }
  // True once a GeModel has been attached via SetGeModelPtr().
  bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; }

  // Releases the in/out buffers tracked below; implemented in the .cc file.
  Status FreeInOutBuffer();

  void SetOutputContext(const std::string &output) { output_names_ = output; }
  std::string GetOutputContext() const { return output_names_; }

  void SetStreamLabel(const std::string &stream_label) { stream_label_ = stream_label; }
  const std::string &GetStreamLabel() const { return stream_label_; }

  // End <-> PlaceHolder node mappings; presumably produced by graph
  // partitioning (End/Pld pairs) — confirm against the partitioner.
  void SetEnd2PldMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &end_map) { end_to_pld_ = end_map; }
  const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetEnd2PldMap() const { return end_to_pld_; }

  void SetPld2EndMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &pld_map) { pld_to_end_ = pld_map; }
  const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetPld2EndMap() const { return pld_to_end_; }

 private:
  ComputeGraphPtr subgraph_ptr_;
  std::string engine_name_;
  std::vector<bool> input_flag_;
  std::vector<bool> output_flag_;
  ModelIdInfo model_id_info_;
  GeModelPtr ge_model_ptr_;
  bool malloc_flag_;                 // buffer ownership flag, managed in the .cc file
  std::vector<void *> buffer_addr_;  // buffers released by FreeInOutBuffer()
  std::string output_names_;
  std::vector<uint32_t> buffer_size_;
  std::string stream_label_;
  std::unordered_map<ge::NodePtr, ge::NodePtr> end_to_pld_;
  std::unordered_map<ge::NodePtr, ge::NodePtr> pld_to_end_;
};
using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>;
// Maps a root compute graph to the list of subgraphs partitioned from it.
using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>;
using Graph2InputNodesSubGraphInfo = std::unordered_map<ComputeGraphPtr, SubGraphInfoPtr>;

// Listener used for asynchronous graph runs: forwards completion events to a
// user-registered RunAsyncCallback.
class RunAsyncListener : public ge::ModelListener {
 public:
  // sem_ starts with capacity 1; presumably serializes callback registration
  // against completion delivery — see the .cc implementation.
  RunAsyncListener() : sem_(1) {}
  ~RunAsyncListener() = default;

  // Registers the callback to invoke on completion (implemented in the .cc file).
  void SetCallback(const RunAsyncCallback &callback);

  // Completion callback from the model runtime.
  Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
                       std::vector<ge::Tensor> &outputs) override;

 private:
  RunAsyncCallback callback_;
  BlockingQueue<uint8_t> sem_;
};
// Bookkeeping for a single graph managed by the graph manager: the graph
// itself, its partitioned subgraphs, the built models, and build/load/run state.
class GraphNode {
 public:
  explicit GraphNode(GraphId graph_id);
  ~GraphNode();

  GraphId GetGraphId() const { return graph_id_; }

  ConstGraphPtr GetGraph() const { return graph_; }
  void SetGraph(const GraphPtr &graph) { graph_ = graph; }

  ComputeGraphPtr GetComputeGraph() const { return compute_graph_; }
  void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }

  // Whether the graph is currently running.
  bool GetRunFlag() const { return run_flag_; }
  void SetRunFlag(bool flag) { run_flag_ = flag; }

  void SetOmeContext(const OmeContext &context) { context_ = context; }
  // Non-const ref: callers may mutate the stored context in place.
  OmeContext &GetOmeContext() { return context_; }

  bool IsAsync() const { return async_; }
  void SetAsync(bool flag) { async_ = flag; }

  void SetSubGraph(std::vector<SubGraphInfoPtr> &subgraph_ptr_list) { subgraph_ptr_list_ = subgraph_ptr_list; }
  const std::vector<SubGraphInfoPtr> &GetAllSubGraph() const { return subgraph_ptr_list_; }

  bool GetBuildFlag() const { return build_flag_; }
  void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; }

  bool GetLoadFlag() const { return load_flag_; }
  // Allows repeated loading of a graph that owns the same graph id: the graph
  // needs (re)loading when nothing is currently loaded (load_count_ == 0) or
  // the total load count has reached the kMaxLoadNum cap.
  void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; }
  void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; }

  void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; }
  GeModelPtr GetGeModel() const { return ge_model_; }

  // Whether the graph runs on a caller-specified stream.
  void SetIsSpecificStream(bool specific_stream) { is_specific_stream_ = specific_stream; }
  bool IsSpecificStream() const { return is_specific_stream_; }

  void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; }
  GeRootModelPtr GetGeRootModel() const { return ge_root_model_; }

  const std::map<std::string, std::string>& GetOptions() const { return options_; }
  void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; }

  // Mutual exclusion for users of this node; implemented in the .cc file
  // (presumably on top of sem_ — confirm there).
  void Lock();
  void Unlock();

  void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); }

  uint32_t GetLoadCount() const { return load_count_; }
  void SetLoadCount(uint32_t count) { load_count_ = count; }
  uint32_t GetLoadRecord() const { return load_record_; }
  void SetLoadRecord(uint32_t record) { load_record_ = record; }
  void IncreaseLoadRecord() { ++load_record_; }
  void IncreaseLoadCount();  // locks load_count_mu_; implemented in the .cc file
  void DecreaseLoadCount() { --load_count_; }

  // Listener for asynchronous graph runs.
  std::shared_ptr<RunAsyncListener> graph_run_async_listener_;

 private:
  GraphId graph_id_;
  std::map<std::string, std::string> options_;
  bool run_flag_;
  std::vector<SubGraphInfoPtr> subgraph_ptr_list_;
  OmeContext context_;
  GraphPtr graph_;
  ComputeGraphPtr compute_graph_;
  bool build_flag_;
  // load_flag_ is true if more than 1 model were loaded
  bool load_flag_;
  bool async_;
  bool is_specific_stream_;
  GeModelPtr ge_model_;
  GeRootModelPtr ge_root_model_;
  BlockingQueue<uint8_t> sem_;
  // consistent with graph_count of the same graph_id in graph_manager
  uint32_t load_count_ = 0;
  // total times a graph with the same graph_id has been loaded
  uint32_t load_record_ = 0;
  std::mutex load_count_mu_;
};
  178. using GraphNodePtr = std::shared_ptr<GraphNode>;
  179. using ConstGraphNodePtr = shared_ptr<const GraphNode>;
// Listener used for synchronous graph runs: records the result code of
// OnComputeDone and lets the waiting caller observe completion through the
// caller-supplied mutex/condition pair (callback implemented in the .cc file).
class GraphModelListener : public ge::ModelListener {
 public:
  // mutex and cond are borrowed from the caller (stored by reference below)
  // and must outlive this listener.
  GraphModelListener(std::mutex &mutex, std::condition_variable &cond);
  ~GraphModelListener() = default;

  // Completion callback from the model runtime.
  Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
                       std::vector<ge::Tensor> &outputs) override;

  // Clears the recorded result so the listener can be reused.
  Status ResetResult();

  // Needs locking by the caller (using the mutex passed to the constructor).
  uint32_t GetResultCode() const;

  bool IsFinished() const { return is_finished_; }

 private:
  uint32_t result_code_;
  bool is_finished_;
  // not owner
  std::mutex &mutex_;
  // not owner
  std::condition_variable &condition_;
};
  199. struct GraphManagerOptions {
  200. int32_t stream_num;
  201. int32_t perf_level;
  202. int32_t encrypt_mode;
  203. int32_t framework_type;
  204. std::string ek_file;
  205. std::string cert_file;
  206. std::string hw_key_file;
  207. std::string private_key_file;
  208. std::string calibration_conf_file;
  209. std::string insert_op_file;
  210. std::string output_node_name;
  211. std::string func_bin_path;
  212. std::string input_nodes_set_fp16;
  213. std::string core_type;
  214. bool compress_flag;
  215. bool run_graph_flag;
  216. bool train_graph_flag;
  217. bool local_fmk_op_flag;
  218. bool hcom_parallel;
  219. bool enable_print_op_pass;
  220. bool is_single_op;
  221. std::map<std::string, int> stream_max_parallel_num;
  222. std::string output_datatype;
  223. std::string original_model_file;
  224. std::string save_original_model;
  225. std::string build_mode;
  226. std::string build_step;
  227. std::string tuning_path;
  228. std::string input_shape;
  229. std::string dynamic_dims;
  230. int32_t dynamic_node_type = -1;
  231. GraphManagerOptions()
  232. : stream_num(1),
  233. perf_level(domi::GEN_TASK_WITHOUT_FUSION),
  234. encrypt_mode(-1),
  235. framework_type(domi::TENSORFLOW),
  236. ek_file(""),
  237. cert_file(""),
  238. hw_key_file(""),
  239. private_key_file(""),
  240. calibration_conf_file(""),
  241. insert_op_file(""),
  242. output_node_name(""),
  243. func_bin_path(""),
  244. core_type(""),
  245. compress_flag(false),
  246. run_graph_flag(false),
  247. train_graph_flag(false),
  248. local_fmk_op_flag(false),
  249. hcom_parallel(false),
  250. enable_print_op_pass(true),
  251. is_single_op(false),
  252. save_original_model("false"),
  253. build_mode(""),
  254. build_step(""),
  255. tuning_path(""){}
  256. };
  257. } // namespace ge
  258. #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。