
graph_manager_utils.h 10 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
#define GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_

#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "common/blocking_queue.h"
#include "common/ge_types.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "model/ge_model.h"
#include "model/ge_root_model.h"
#include "register/register_fmk_types.h"
#include "external/ge/ge_api_types.h"

namespace ge {
// state of a graph task in its life cycle
enum GraphNodeState {
  GRAPH_NODE_INIT = 0,
  GRAPH_NODE_READY,
};

using GraphId = uint32_t;
using ConstModelPtr = std::shared_ptr<const ge::Model>;
using GeModelPtr = std::shared_ptr<ge::GeModel>;
using ConstGraphPtr = std::shared_ptr<const ge::Graph>;
using GraphPtr = std::shared_ptr<ge::Graph>;

const uint64_t INVALID_SESSION_ID = 0xffffffffffffffffULL;
const uint32_t kMaxLoadNum = 8;

struct ModelIdInfo {
  uint32_t model_id{INVALID_MODEL_ID};
};

class SubGraphInfo {
 public:
  SubGraphInfo();
  ~SubGraphInfo();

  void SetSubGraph(const ComputeGraphPtr &sub_graph_ptr) { subgraph_ptr_ = sub_graph_ptr; }
  ComputeGraphPtr GetSubGraph() const { return subgraph_ptr_; }

  void SetEngineName(const std::string &engine_name) { engine_name_ = engine_name; }
  const std::string &GetEngineName() const { return engine_name_; }

  void SetInputFlag(const std::vector<bool> &input_flag) { input_flag_ = input_flag; }
  const std::vector<bool> &GetInputFlag() const { return input_flag_; }

  void SetOutputFlag(const std::vector<bool> &output_flag) { output_flag_ = output_flag; }
  const std::vector<bool> &GetOutputFlag() const { return output_flag_; }

  void SetModelIdInfo(const ModelIdInfo &model_id_info) { model_id_info_ = model_id_info; }
  ModelIdInfo GetModelIdInfo() const { return model_id_info_; }

  void SetGeModelPtr(const GeModelPtr &ge_model_ptr) { ge_model_ptr_ = ge_model_ptr; }
  bool GeModelIsValid() const { return ge_model_ptr_ != nullptr; }

  Status FreeInOutBuffer();

  void SetOutputContext(const std::string &output) { output_names_ = output; }
  std::string GetOutputContext() const { return output_names_; }

  void SetStreamLabel(const std::string &stream_label) { stream_label_ = stream_label; }
  const std::string &GetStreamLabel() const { return stream_label_; }

  void SetEnd2PldMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &end_map) { end_to_pld_ = end_map; }
  const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetEnd2PldMap() const { return end_to_pld_; }

  void SetPld2EndMap(std::unordered_map<ge::NodePtr, ge::NodePtr> &pld_map) { pld_to_end_ = pld_map; }
  const std::unordered_map<ge::NodePtr, ge::NodePtr> &GetPld2EndMap() const { return pld_to_end_; }

 private:
  ComputeGraphPtr subgraph_ptr_;
  std::string engine_name_;
  std::vector<bool> input_flag_;
  std::vector<bool> output_flag_;
  ModelIdInfo model_id_info_;
  GeModelPtr ge_model_ptr_;
  bool malloc_flag_;
  std::vector<void *> buffer_addr_;
  std::string output_names_;
  std::vector<uint32_t> buffer_size_;
  std::string stream_label_;
  std::unordered_map<ge::NodePtr, ge::NodePtr> end_to_pld_;
  std::unordered_map<ge::NodePtr, ge::NodePtr> pld_to_end_;
};

using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>;
using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>;
using Graph2InputNodesSubGraphInfo = std::unordered_map<ComputeGraphPtr, SubGraphInfoPtr>;

// listener for running a graph asynchronously
class RunAsyncListener : public ge::ModelListener {
 public:
  RunAsyncListener() : sem_(1) {}
  ~RunAsyncListener() = default;

  void SetCallback(const RunAsyncCallback &callback);

  // callback
  Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
                       std::vector<ge::OutputTensorInfo> &outputs) override;

 private:
  RunAsyncCallback callback_;
  BlockingQueue<uint8_t> sem_;
};

// info of a single graph node
class GraphNode {
 public:
  explicit GraphNode(GraphId graph_id);
  ~GraphNode();

  GraphId GetGraphId() const { return graph_id_; }
  ConstGraphPtr GetGraph() const { return graph_; }
  void SetGraph(const GraphPtr &graph) { graph_ = graph; }
  ComputeGraphPtr GetComputeGraph() const { return compute_graph_; }
  void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }

  bool GetRunFlag() const { return run_flag_; }
  void SetRunFlag(bool flag) { run_flag_ = flag; }
  bool IsAsync() const { return async_; }
  void SetAsync(bool flag) { async_ = flag; }

  void SetSubGraph(std::vector<SubGraphInfoPtr> &subgraph_ptr_list) { subgraph_ptr_list_ = subgraph_ptr_list; }
  const std::vector<SubGraphInfoPtr> &GetAllSubGraph() const { return subgraph_ptr_list_; }

  bool GetBuildFlag() const { return build_flag_; }
  void SetBuildFlag(bool buildFlag) { build_flag_ = buildFlag; }
  bool GetLoadFlag() const { return load_flag_; }
  // allow repeatedly loading a graph that owns the same graph id
  void UpdateLoadFlag() { load_flag_ = load_count_ == 0 || load_record_ >= kMaxLoadNum; }
  void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; }

  void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; }
  GeModelPtr GetGeModel() const { return ge_model_; }
  void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; }
  GeRootModelPtr GetGeRootModel() const { return ge_root_model_; }

  const std::map<std::string, std::string> &GetOptions() const { return options_; }
  void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; }

  void Lock();
  void Unlock();
  void SetSemSize(uint32_t size) { sem_.SetMaxSize(size); }

  uint32_t GetLoadCount() const { return load_count_; }
  void SetLoadCount(uint32_t count) { load_count_ = count; }
  uint32_t GetLoadRecord() const { return load_record_; }
  void SetLoadRecord(uint32_t record) { load_record_ = record; }
  void IncreaseLoadRecord() { ++load_record_; }
  void IncreaseLoadCount();
  void DecreaseLoadCount() { --load_count_; }

  // listener for running the graph asynchronously
  std::shared_ptr<RunAsyncListener> graph_run_async_listener_;

 private:
  GraphId graph_id_;
  std::map<std::string, std::string> options_;
  bool run_flag_;
  std::vector<SubGraphInfoPtr> subgraph_ptr_list_;
  GraphPtr graph_;
  ComputeGraphPtr compute_graph_;
  bool build_flag_;
  // load_flag_ is true if more than one model has been loaded
  bool load_flag_;
  bool async_;
  GeModelPtr ge_model_;
  GeRootModelPtr ge_root_model_;
  BlockingQueue<uint8_t> sem_;
  // consistent with graph_count of the same graph_id in graph_manager
  uint32_t load_count_ = 0;
  // total number of times a graph with the same graph_id has been loaded
  uint32_t load_record_ = 0;
  std::mutex load_count_mu_;
};

using GraphNodePtr = std::shared_ptr<GraphNode>;
using ConstGraphNodePtr = std::shared_ptr<const GraphNode>;

class GraphModelListener : public ge::ModelListener {
 public:
  GraphModelListener(std::mutex &mutex, std::condition_variable &cond);
  ~GraphModelListener() = default;

  // callback
  Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result,
                       std::vector<ge::OutputTensorInfo> &outputs) override;

  Status ResetResult();

  // must be locked by the caller
  uint32_t GetResultCode() const;

  bool IsFinished() const { return is_finished_; }

 private:
  uint32_t result_code_;
  bool is_finished_;
  // not the owner
  std::mutex &mutex_;
  // not the owner
  std::condition_variable &condition_;
};

struct GraphManagerOptions {
  int32_t stream_num;
  int32_t perf_level;
  int32_t encrypt_mode;
  int32_t framework_type;
  std::string ek_file;
  std::string cert_file;
  std::string hw_key_file;
  std::string private_key_file;
  std::string calibration_conf_file;
  std::string insert_op_file;
  std::string output_node_name;
  std::string func_bin_path;
  std::string input_nodes_set_fp16;
  std::string core_type;
  bool compress_flag;
  bool run_graph_flag;
  bool train_graph_flag;
  bool local_fmk_op_flag;
  bool hcom_parallel;
  bool enable_print_op_pass;
  bool is_single_op;
  std::map<std::string, int> stream_max_parallel_num;
  std::string output_datatype;
  std::string original_model_file;
  std::string save_original_model;
  std::string build_mode;
  std::string build_step;
  std::string tuning_path;
  std::string input_shape;
  std::string dynamic_dims;
  int32_t dynamic_node_type = -1;
  GraphManagerOptions()
      : stream_num(1),
        perf_level(domi::GEN_TASK_WITHOUT_FUSION),
        encrypt_mode(-1),
        framework_type(domi::TENSORFLOW),
        ek_file(""),
        cert_file(""),
        hw_key_file(""),
        private_key_file(""),
        calibration_conf_file(""),
        insert_op_file(""),
        output_node_name(""),
        func_bin_path(""),
        core_type(""),
        compress_flag(false),
        run_graph_flag(false),
        train_graph_flag(false),
        local_fmk_op_flag(false),
        hcom_parallel(false),
        enable_print_op_pass(true),
        is_single_op(false),
        save_original_model("false"),
        build_mode(""),
        build_step(""),
        tuning_path("") {}
};
}  // namespace ge

#endif  // GE_GRAPH_MANAGER_GRAPH_MANAGER_UTILS_H_
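To make the relationships among these declarations concrete, here is a minimal usage sketch, not taken from the repository: it builds a GraphNode for a graph id, attaches a graph, and drives the load bookkeeping the way a graph manager might. The include path and the option key are assumptions for illustration.

// Illustrative sketch only; assumes the header above is reachable at this in-repo path.
#include <map>
#include <memory>
#include <string>
#include "graph/manager/graph_manager_utils.h"  // assumed include path

void RegisterGraphSketch(const ge::GraphPtr &graph) {
  // create a node for graph id 1 and attach the user graph
  ge::GraphNodePtr node = std::make_shared<ge::GraphNode>(1U);
  node->SetGraph(graph);
  node->SetOptions({{"example.key", "example.value"}});  // hypothetical option entry
  node->SetAsync(true);        // results will be delivered through RunAsyncListener
  node->IncreaseLoadRecord();  // record one more load of this graph id
  node->UpdateLoadFlag();      // recompute load_flag_ from load_count_ / load_record_
}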

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph issued by ME as input, applies a series of deep graph optimizations, and finally produces a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts: GE API and GE Core.
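As a rough sketch of the GE API call flow described above (hedged: signatures follow the external ge_api.h shipped with this code base, and the include path follows the in-repo convention used in the header above), a single compile-and-run of a graph looks roughly like this:

// Rough sketch: initialize GE, create a Session, add a graph, run it, finalize.
#include <map>
#include <string>
#include <vector>
#include "external/ge/ge_api.h"  // assumed include path for the public GE API

ge::Status RunGraphOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs,
                        std::vector<ge::Tensor> &outputs) {
  std::map<std::string, std::string> options;  // global / session options (empty = defaults)
  ge::Status ret = ge::GEInitialize(options);  // bring up GE Core
  if (ret != ge::SUCCESS) {
    return ret;
  }
  {
    ge::Session session(options);    // a session owns the graphs added to it
    ret = session.AddGraph(0, graph);  // register the graph handed over by the front end
    if (ret == ge::SUCCESS) {
      ret = session.RunGraph(0, inputs, outputs);  // optimize, compile and execute on the device
    }
  }
  (void)ge::GEFinalize();  // tear GE down regardless of the result
  return ret;
}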