You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

scope_fusion_pass_register.h 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
  17. #define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
  #include <map>
  #include <memory>
  #include <new>
  #include <string>
  #include <unordered_map>
  #include <vector>

  #include "ge/ge_api_error_codes.h"
  #include "register/register_error_codes.h"
  #include "register/register_types.h"
  #include "graph/operator.h"
  // Early-exit guard for code that fills a fusion result, presumably
  // ScopeBasePass::GenerateFusionResult implementations (whose parameter is
  // also named `fusion_rlt`). Expands to a bare `return;`, so it can only be
  // used inside functions returning void.
  //
  // If `cond` evaluates to false:
  //   * when `fusion_rlt` is non-null, its type is set to
  //     ::ge::kScopeInvalidType;
  //   * the enclosing function returns.
  // Otherwise execution continues after the macro.
  //
  // `fusion_rlt` is evaluated at most once (it is bound to a local reference),
  // so an argument expression with side effects is safe. The namespace is
  // fully qualified so the macro also works where a nested `ge` namespace is
  // visible, matching the `::ge::` spelling of REGISTER_SCOPE_FUSION_PASS_UNIQ.
  #define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt)                  \
    do {                                                                \
      if (!(cond)) {                                                    \
        auto &&ge_check_inner_node_rlt = (fusion_rlt);                  \
        if (ge_check_inner_node_rlt != nullptr) {                       \
          ge_check_inner_node_rlt->SetType(::ge::kScopeInvalidType);    \
        }                                                               \
        return;                                                         \
      }                                                                 \
    } while (0)
  // Forward declaration of domi::TensorFlowModelParser.
  // NOTE(review): nothing in this header refers to it. Unqualified
  // `friend class TensorFlowModelParser;` declarations inside namespace ge only
  // look in the innermost enclosing namespace, so they name
  // ge::TensorFlowModelParser rather than this class. Confirm which parser is
  // meant before relying on either.
  namespace domi { class TensorFlowModelParser; }  // namespace domi
  39. namespace ge {
  // NOTE(review): the meaning of this index value is not documented in this
  // header; presumably a sentinel used by the implementation.
  constexpr int32_t kFusionDisableIndex = 99999;

  // Well-known type / marker strings used by scope fusion.
  constexpr const char *kScopeToMultiNodes = "ScopeToMultiNodes";
  // Type written into a fusion result by CHECK_INNER_NODE_CONDITION on failure.
  constexpr const char *kScopeInvalidType = "ScopeInvalidType";
  constexpr const char *kInputFromFusionScope = "InputFromFusionScope";
  constexpr const char *kOutputToFusionScope = "OutputToFusionScope";

  class ScopePattern;
  class ScopePassManager;

  // Scope patterns grouped into batches; each inner vector is one batch.
  // NOTE(review): ScopeUtil::FreeScopePatterns / FreeOneBatchPattern operate on
  // these containers, which suggests the pointers are heap-owned by whoever
  // holds the container. Confirm ownership.
  using ScopeFusionPatterns = std::vector<std::vector<ScopePattern *>>;
  48. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope {
  49. public:
  50. Scope();
  51. Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr);
  52. ~Scope();
  53. const std::string &Name() const;
  54. const std::string &SubType() const;
  55. const std::unordered_map<std::string, ge::OperatorPtr> &AllNodesMap() const;
  56. Scope *GetSubScope(const std::string &scope_name) const;
  57. const std::string LastName() const;
  58. const std::vector<Scope *> &GetAllSubScopes() const;
  59. const Scope *GetFatherScope() const;
  60. private:
  61. class ScopeImpl;
  62. std::unique_ptr<ScopeImpl> impl_;
  63. friend class ScopeBasePass;
  64. friend class ScopeTree;
  65. friend class NodeOpTypeFeature;
  66. friend class NodeAttrFeature;
  67. friend class ScopeFeature;
  68. };
  69. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult {
  70. public:
  71. FusionScopesResult();
  72. Status Init();
  73. ~FusionScopesResult();
  74. void SetName(const std::string &name);
  75. void SetType(const std::string &type);
  76. void SetDescription(const std::string &description);
  77. const std::string &Name() const;
  78. const std::vector<ge::OperatorPtr> &Nodes() const;
  79. void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  80. void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  81. class InnerNodeInfo {
  82. public:
  83. explicit InnerNodeInfo(const std::string &fusion_node_name);
  84. InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type);
  85. InnerNodeInfo(InnerNodeInfo &&other) noexcept;
  86. InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept;
  87. InnerNodeInfo(const InnerNodeInfo &) = delete;
  88. InnerNodeInfo &operator=(const InnerNodeInfo &) = delete;
  89. ~InnerNodeInfo();
  90. InnerNodeInfo &SetName(const std::string &name);
  91. InnerNodeInfo &SetType(const std::string &type);
  92. InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx);
  93. InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx);
  94. ge::graphStatus BuildInnerNode();
  95. ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format);
  96. ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format);
  97. ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format);
  98. ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format);
  99. ge::Operator *MutableOperator();
  100. std::string GetName() const;
  101. std::string GetType() const;
  102. std::vector<std::pair<std::string, int32_t>> GetInputs() const;
  103. std::vector<std::pair<std::string, int32_t>> GetOutputs() const;
  104. private:
  105. class InnerNodeInfoImpl;
  106. std::unique_ptr<InnerNodeInfoImpl> impl_;
  107. };
  108. InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type);
  109. InnerNodeInfo *MutableRecentInnerNode();
  110. InnerNodeInfo *MutableInnerNode(uint32_t index);
  111. ge::graphStatus CheckInnerNodesInfo();
  112. private:
  113. class FusionScopesResultImpl;
  114. std::unique_ptr<FusionScopesResultImpl> impl_;
  115. friend class ScopeGraph;
  116. friend class ScopeBasePass;
  117. friend class TensorFlowModelParser;
  118. };
  119. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree {
  120. public:
  121. ScopeTree();
  122. Status Init();
  123. ScopeTree(const ScopeTree &scopetree) = delete;
  124. ScopeTree &operator=(const ScopeTree &scopetree) = delete;
  125. ~ScopeTree();
  126. const std::vector<Scope *> &GetAllScopes() const;
  127. private:
  128. class ScopeTreeImpl;
  129. std::unique_ptr<ScopeTreeImpl> impl_;
  130. friend class ScopeGraph;
  131. friend class ScopeBasePass;
  132. };
  133. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph {
  134. public:
  135. ScopeGraph();
  136. Status Init();
  137. ScopeGraph(const ScopeGraph &scope_graph) = delete;
  138. ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete;
  139. ~ScopeGraph();
  140. const ScopeTree *GetScopeTree() const;
  141. const std::unordered_map<std::string, ge::OperatorPtr> &GetNodesMap() const;
  142. private:
  143. class ScopeGraphImpl;
  144. std::unique_ptr<ScopeGraphImpl> impl_;
  145. friend class ScopePassManager;
  146. friend class ScopeBasePass;
  147. friend class TensorFlowModelParser;
  148. };
  149. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue {
  150. public:
  151. ScopeAttrValue();
  152. ScopeAttrValue(ScopeAttrValue const &attr_value);
  153. ScopeAttrValue &operator=(ScopeAttrValue const &attr_value);
  154. ~ScopeAttrValue();
  155. void SetIntValue(int64_t value);
  156. void SetFloatValue(float value);
  157. void SetStringValue(std::string value);
  158. void SetBoolValue(bool value);
  159. private:
  160. class ScopeAttrValueImpl;
  161. std::unique_ptr<ScopeAttrValueImpl> impl_;
  162. friend class NodeAttrFeature;
  163. };
  164. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature {
  165. public:
  166. virtual bool Match(const Scope *scope) = 0;
  167. virtual ~ScopeBaseFeature(){};
  168. };
  169. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature {
  170. public:
  171. NodeOpTypeFeature(std::string nodeType, int num, int step = 0);
  172. NodeOpTypeFeature(NodeOpTypeFeature const &feature);
  173. NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature);
  174. ~NodeOpTypeFeature();
  175. bool Match(const Scope *scope) override;
  176. private:
  177. class NodeOpTypeFeatureImpl;
  178. std::unique_ptr<NodeOpTypeFeatureImpl> impl_;
  179. };
  180. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature {
  181. public:
  182. NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value);
  183. NodeAttrFeature(NodeAttrFeature const &feature);
  184. NodeAttrFeature &operator=(NodeAttrFeature const &feature);
  185. ~NodeAttrFeature();
  186. bool Match(const Scope *scope) override;
  187. private:
  188. class NodeAttrFeatureImpl;
  189. std::unique_ptr<NodeAttrFeatureImpl> impl_;
  190. };
  191. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature {
  192. public:
  193. ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "",
  194. int step = 0);
  195. ScopeFeature(ScopeFeature const &feature);
  196. ScopeFeature &operator=(ScopeFeature const &feature);
  197. ~ScopeFeature();
  198. bool Match(const Scope *scope) override;
  199. private:
  200. class ScopeFeatureImpl;
  201. std::unique_ptr<ScopeFeatureImpl> impl_;
  202. };
  203. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern {
  204. public:
  205. ScopePattern();
  206. ~ScopePattern();
  207. ScopePattern &SetSubType(const std::string &sub_type);
  208. ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature);
  209. ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature);
  210. ScopePattern &AddScopeFeature(ScopeFeature feature);
  211. private:
  212. class ScopePatternImpl;
  213. std::unique_ptr<ScopePatternImpl> impl_;
  214. friend class ScopeBasePass;
  215. };
  216. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult {
  217. public:
  218. ScopesResult();
  219. ScopesResult(ScopesResult const &result);
  220. ScopesResult &operator=(ScopesResult const &result);
  221. ~ScopesResult();
  222. void SetScopes(std::vector<Scope *> &scopes);
  223. void SetNodes(std::vector<ge::OperatorPtr> &nodes);
  224. private:
  225. class ScopesResultImpl;
  226. std::unique_ptr<ScopesResultImpl> impl_;
  227. friend class ScopeBasePass;
  228. };
  229. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass {
  230. public:
  231. ScopeBasePass();
  232. virtual ~ScopeBasePass();
  233. protected:
  234. // Subclasses implement respective fusion strategies and build the Patterns
  235. virtual std::vector<ScopeFusionPatterns> DefinePatterns() = 0;
  236. // Define the name of the scope pass
  237. virtual std::string PassName() = 0;
  238. // Subclasses implement respective multi-scope or operator fusion methods across scopes
  239. virtual Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
  240. std::vector<ScopesResult> &results) = 0;
  241. // Subclasses implement their own results and set the input and output of the final fusion operator
  242. virtual void GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) = 0;
  243. private:
  244. class ScopeBasePassImpl;
  245. std::unique_ptr<ScopeBasePassImpl> impl_;
  246. friend class ge::ScopePassManager;
  247. friend class ScopeBasePassImpl;
  248. };
  249. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry {
  250. public:
  251. using CreateFn = ScopeBasePass *(*)();
  252. ~ScopeFusionPassRegistry();
  253. static ScopeFusionPassRegistry &GetInstance() {
  254. static ScopeFusionPassRegistry instance;
  255. return instance;
  256. }
  257. void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general);
  258. private:
  259. ScopeFusionPassRegistry();
  260. class ScopeFusionPassRegistryImpl;
  261. std::unique_ptr<ScopeFusionPassRegistryImpl> impl_;
  262. friend class TensorFlowModelParser;
  263. };
  264. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil {
  265. public:
  266. static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value);
  267. static void FreeScopePatterns(ScopeFusionPatterns &patterns);
  268. static void FreeOneBatchPattern(std::vector<ScopePattern *> &one_batch_pattern);
  269. };
  270. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar {
  271. public:
  272. ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general);
  273. ~ScopeFusionPassRegistrar() {}
  274. };
  // Registers the ScopeBasePass subclass `scope_pass` under `pass_name`. The
  // registration happens during dynamic initialisation of the translation unit
  // that uses the macro. The expansion is a declaration without a trailing
  // semicolon, so write e.g.
  //   REGISTER_SCOPE_FUSION_PASS("MyScopePass", MyScopePass, false);
  // `scope_pass` must name a default-constructible type (it appears as
  // `new ... scope_pass()`).
  #define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \
    REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general)

  // Extra indirection: `##` suppresses expansion of its operands, so
  // __COUNTER__ must be expanded one level up before it is pasted.
  #define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \
    REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general)

  // Defines a uniquely named static registrar that hands a factory lambda to
  // ::ge::ScopeFusionPassRegistrar. The factory uses nothrow new, so it returns
  // nullptr on allocation failure rather than throwing. ::std::nothrow is
  // declared in <new>, which this header includes, so the expansion no longer
  // relies on a transitive include. `std` is fully qualified, like `ge`, for
  // macro hygiene.
  #define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general)                 \
    static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \
        ::ge::ScopeFusionPassRegistrar(                                                              \
            pass_name, []() -> ::ge::ScopeBasePass * { return new (::std::nothrow) scope_pass(); },  \
            is_general)
  283. } // namespace ge
  284. #endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：