You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

scope_fusion_pass_register.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. #ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
  17. #define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
#include <cstdint>
#include <map>
#include <memory>
#include <new>
#include <string>
#include <utility>
#include <vector>

#include "ge/ge_api_error_codes.h"
#include "graph/operator.h"
#include "register/register_error_codes.h"
#include "register/register_types.h"
  // Early-exit guard: when `cond` evaluates to false, sets the type of
  // `fusion_rlt` to ge::kScopeInvalidType (skipped if `fusion_rlt` is null) and
  // then executes a bare `return;` in the enclosing function. When `cond` is
  // true the macro does nothing.
  //
  // Because it expands to `return;` it is only usable inside functions that
  // return void. `fusion_rlt` is expanded twice, so pass an expression without
  // side effects.
  #define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \
    do { \
      if (!(cond)) { \
        if ((fusion_rlt) == nullptr) { \
          return; \
        } \
        (fusion_rlt)->SetType(ge::kScopeInvalidType); \
        return; \
      } \
    } while (0)
  // Forward declaration of domi::TensorFlowModelParser.
  // NOTE(review): the unqualified `friend class TensorFlowModelParser;`
  // declarations that appear inside namespace ge in this header designate
  // ge::TensorFlowModelParser, not this domi:: class (friend elaborated type
  // specifiers only look in the innermost enclosing namespace). Confirm which
  // class is meant to receive friendship.
  namespace domi { class TensorFlowModelParser; }  // namespace domi
  38. namespace ge {
  // NOTE(review): judging by the name, presumably a sentinel index meaning
  // "this input/output is disabled" in an index map -- confirm against the
  // code that consumes FusionScopesResult::InsertInputs/InsertOutputs.
  constexpr int32_t kFusionDisableIndex = 99999;

  // String tags shared by the scope-fusion code. kScopeInvalidType is the type
  // that CHECK_INNER_NODE_CONDITION assigns to a fusion result when its
  // condition fails; the exact meaning of the other tags is defined by their
  // users, which are not part of this header.
  constexpr const char *kScopeToMultiNodes = "ScopeToMultiNodes";
  constexpr const char *kScopeInvalidType = "ScopeInvalidType";
  constexpr const char *kInputFromFusionScope = "InputFromFusionScope";
  constexpr const char *kOutputToFusionScope = "OutputToFusionScope";
  // Forward declarations; both classes are referenced before their definitions.
  class ScopePassManager;
  class ScopePattern;

  // A list of pattern batches, each batch being a list of raw ScopePattern
  // pointers. NOTE(review): the pointers are raw, so ownership is not expressed
  // by the type; ScopeUtil::FreeScopePatterns / FreeOneBatchPattern presumably
  // release them -- confirm in the implementation.
  using ScopeFusionPatterns = std::vector<std::vector<ScopePattern *>>;
  47. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope {
  48. public:
  49. explicit Scope(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr);
  50. ~Scope();
  51. std::string Name() const;
  52. std::string SubType() const;
  53. std::map<std::string, ge::OperatorPtr> AllNodesMap() const;
  54. Scope *GetSubScope(const std::string &scope_name) const;
  55. std::string LastName() const;
  56. std::vector<Scope *> GetAllSubScopes() const;
  57. const Scope *GetFatherScope() const;
  58. private:
  59. class ScopeImpl;
  60. std::unique_ptr<ScopeImpl> impl_;
  61. friend class ScopeBasePass;
  62. friend class ScopeTree;
  63. friend class NodeOpTypeFeature;
  64. friend class NodeAttrFeature;
  65. friend class ScopeFeature;
  66. };
  67. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult {
  68. public:
  69. FusionScopesResult();
  70. ~FusionScopesResult();
  71. void SetName(const std::string &name);
  72. void SetType(const std::string &type);
  73. void SetDescription(const std::string &description);
  74. std::string Name() const;
  75. std::vector<ge::OperatorPtr> Nodes() const;
  76. void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  77. void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  78. class InnerNodeInfo {
  79. public:
  80. explicit InnerNodeInfo(const std::string &fusion_node_name);
  81. InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type);
  82. InnerNodeInfo(InnerNodeInfo &&other) noexcept;
  83. InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept;
  84. InnerNodeInfo(const InnerNodeInfo &) = delete;
  85. InnerNodeInfo &operator=(const InnerNodeInfo &) = delete;
  86. ~InnerNodeInfo();
  87. InnerNodeInfo &SetName(const std::string &name);
  88. InnerNodeInfo &SetType(const std::string &type);
  89. InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx);
  90. InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx);
  91. ge::graphStatus BuildInnerNode();
  92. ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format);
  93. ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format);
  94. ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format);
  95. ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format);
  96. ge::Operator *MutableOperator();
  97. std::string GetName() const;
  98. std::string GetType() const;
  99. std::vector<std::pair<std::string, int32_t>> GetInputs() const;
  100. std::vector<std::pair<std::string, int32_t>> GetOutputs() const;
  101. private:
  102. class InnerNodeInfoImpl;
  103. std::unique_ptr<InnerNodeInfoImpl> impl_;
  104. };
  105. InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type);
  106. InnerNodeInfo *MutableRecentInnerNode();
  107. InnerNodeInfo *MutableInnerNode(uint32_t index);
  108. ge::graphStatus CheckInnerNodesInfo();
  109. private:
  110. class FusionScopesResultImpl;
  111. std::unique_ptr<FusionScopesResultImpl> impl_;
  112. friend class ScopeGraph;
  113. friend class ScopeBasePass;
  114. friend class TensorFlowModelParser;
  115. };
  116. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree {
  117. public:
  118. ScopeTree();
  119. Status Init();
  120. ScopeTree(const ScopeTree &scopetree) = delete;
  121. ScopeTree &operator=(const ScopeTree &scopetree) = delete;
  122. ~ScopeTree();
  123. std::vector<Scope *> GetAllScopes() const;
  124. private:
  125. class ScopeTreeImpl;
  126. std::unique_ptr<ScopeTreeImpl> impl_;
  127. friend class ScopeGraph;
  128. friend class ScopeBasePass;
  129. };
  130. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph {
  131. public:
  132. ScopeGraph();
  133. Status Init();
  134. ScopeGraph(const ScopeGraph &scope_graph) = delete;
  135. ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete;
  136. ~ScopeGraph();
  137. const ScopeTree *GetScopeTree() const;
  138. std::map<std::string, ge::OperatorPtr> GetNodesMap() const;
  139. private:
  140. class ScopeGraphImpl;
  141. std::unique_ptr<ScopeGraphImpl> impl_;
  142. friend class ScopePassManager;
  143. friend class ScopeBasePass;
  144. friend class TensorFlowModelParser;
  145. };
  146. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue {
  147. public:
  148. ScopeAttrValue();
  149. ScopeAttrValue(ScopeAttrValue const &attr_value);
  150. ScopeAttrValue &operator=(ScopeAttrValue const &attr_value);
  151. ~ScopeAttrValue();
  152. void SetIntValue(int64_t value);
  153. void SetFloatValue(float value);
  154. void SetStringValue(std::string value);
  155. void SetBoolValue(bool value);
  156. private:
  157. class ScopeAttrValueImpl;
  158. std::unique_ptr<ScopeAttrValueImpl> impl_;
  159. friend class NodeAttrFeature;
  160. };
  161. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature {
  162. public:
  163. virtual bool Match(const Scope *scope) = 0;
  164. virtual ~ScopeBaseFeature(){};
  165. };
  166. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature {
  167. public:
  168. NodeOpTypeFeature(std::string nodeType, int num, int step = 0);
  169. NodeOpTypeFeature(NodeOpTypeFeature const &feature);
  170. NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature);
  171. ~NodeOpTypeFeature();
  172. bool Match(const Scope *scope) override;
  173. private:
  174. class NodeOpTypeFeatureImpl;
  175. std::unique_ptr<NodeOpTypeFeatureImpl> impl_;
  176. };
  177. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature {
  178. public:
  179. NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue attr_value);
  180. NodeAttrFeature(NodeAttrFeature const &feature);
  181. NodeAttrFeature &operator=(NodeAttrFeature const &feature);
  182. ~NodeAttrFeature();
  183. bool Match(const Scope *scope) override;
  184. private:
  185. class NodeAttrFeatureImpl;
  186. std::unique_ptr<NodeAttrFeatureImpl> impl_;
  187. };
  188. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature {
  189. public:
  190. ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "",
  191. int step = 0);
  192. ScopeFeature(ScopeFeature const &feature);
  193. ScopeFeature &operator=(ScopeFeature const &feature);
  194. ~ScopeFeature();
  195. bool Match(const Scope *scope) override;
  196. private:
  197. class ScopeFeatureImpl;
  198. std::unique_ptr<ScopeFeatureImpl> impl_;
  199. };
  200. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern {
  201. public:
  202. ScopePattern();
  203. ~ScopePattern();
  204. ScopePattern &SetSubType(const std::string &sub_type);
  205. ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature);
  206. ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature);
  207. ScopePattern &AddScopeFeature(ScopeFeature feature);
  208. private:
  209. class ScopePatternImpl;
  210. std::unique_ptr<ScopePatternImpl> impl_;
  211. friend class ScopeBasePass;
  212. };
  213. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult {
  214. public:
  215. ScopesResult();
  216. ScopesResult(ScopesResult const &result);
  217. ScopesResult &operator=(ScopesResult const &result);
  218. ~ScopesResult();
  219. void SetScopes(std::vector<Scope *> &scopes);
  220. void SetNodes(std::vector<ge::OperatorPtr> &nodes);
  221. private:
  222. class ScopesResultImpl;
  223. std::unique_ptr<ScopesResultImpl> impl_;
  224. friend class ScopeBasePass;
  225. };
  226. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass {
  227. public:
  228. ScopeBasePass();
  229. virtual ~ScopeBasePass();
  230. protected:
  231. // Subclasses implement respective fusion strategies and build the Patterns
  232. virtual std::vector<ScopeFusionPatterns> DefinePatterns() = 0;
  233. // Define the name of the scope pass
  234. virtual std::string PassName() = 0;
  235. // Subclasses implement respective multi-scope or operator fusion methods across scopes
  236. virtual Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
  237. std::vector<ScopesResult> &results) = 0;
  238. // Subclasses implement their own results and set the input and output of the final fusion operator
  239. virtual void GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) = 0;
  240. private:
  241. class ScopeBasePassImpl;
  242. std::unique_ptr<ScopeBasePassImpl> impl_;
  243. friend class ge::ScopePassManager;
  244. friend class ScopeBasePassImpl;
  245. };
  246. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry {
  247. public:
  248. using CreateFn = ScopeBasePass *(*)();
  249. ~ScopeFusionPassRegistry();
  250. static ScopeFusionPassRegistry &GetInstance() {
  251. static ScopeFusionPassRegistry instance;
  252. return instance;
  253. }
  254. void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general);
  255. private:
  256. ScopeFusionPassRegistry();
  257. class ScopeFusionPassRegistryImpl;
  258. /*lint -e148*/
  259. std::unique_ptr<ScopeFusionPassRegistryImpl> impl_;
  260. friend class TensorFlowModelParser;
  261. };
  262. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil {
  263. public:
  264. static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value);
  265. static void FreeScopePatterns(ScopeFusionPatterns &patterns);
  266. static void FreeOneBatchPattern(std::vector<ScopePattern *> &one_batch_pattern);
  267. };
  268. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar {
  269. public:
  270. ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general);
  271. ~ScopeFusionPassRegistrar() {}
  272. };
  // Innermost step: defines a static ::ge::ScopeFusionPassRegistrar named
  // register_scope_fusion_pass<ctr>, constructed from `pass_name`, a
  // capture-less lambda that creates a `scope_pass` and `is_general`. The
  // lambda uses `new (std::nothrow)`, so it yields nullptr instead of throwing
  // when allocation fails. `__attribute__((unused))` silences unused-variable
  // warnings for the registrar object.
  #define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \
    static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \
      ::ge::ScopeFusionPassRegistrar( \
        pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general)

  // Extra expansion level so that `ctr` (__COUNTER__) is fully expanded before
  // it is pasted with ## in REGISTER_SCOPE_FUSION_PASS_UNIQ.
  #define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \
    REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general)

  // Public entry point: registers class `scope_pass` under `pass_name`.
  // __COUNTER__ gives each use within a translation unit a distinct variable
  // name. Use at namespace scope and terminate the statement with a semicolon.
  #define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \
    REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general)
  281. } // namespace ge
  282. #endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,而用户对此并无感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: