You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

scope_fusion_pass_register.h 17 kB


  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
  17. #define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_
  18. #include <memory>
  19. #include <string>
  20. #include <vector>
  21. #include <map>
  22. #include <unordered_map>
  23. #include "ge/ge_api_error_codes.h"
  24. #include "register/register_error_codes.h"
  25. #include "register/register_types.h"
  26. #include "graph/operator.h"
  // Guard for use inside functions returning void: when `cond` is false, tags `fusion_rlt`
  // (if it is non-null) with ge::kScopeInvalidType through SetType() and then returns from
  // the enclosing function. The do { ... } while (0) wrapper makes the expansion behave as a
  // single statement. `fusion_rlt` is evaluated twice on the failure path.
  27. #define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \
  28. do { \
  29. if (!(cond)) { \
  30. if ((fusion_rlt) != nullptr) { \
  31. (fusion_rlt)->SetType(ge::kScopeInvalidType); \
  32. } \
  33. return; \
  34. } \
  35. } while (0)
  // Forward declaration only; the class definition is not part of this header.
  // NOTE(review): the `friend class TensorFlowModelParser;` declarations further down are
  // unqualified inside namespace ge, so they name ge::TensorFlowModelParser rather than this
  // domi:: class - confirm which namespace the parser actually lives in.
  36. namespace domi {
  37. class TensorFlowModelParser;
  38. } // namespace domi
  39. namespace ge {
  // Namespace-scope constants. Being `const` at namespace scope they have internal linkage,
  // so each translation unit that includes this header gets its own copy.
  // Where kFusionDisableIndex is consumed is not visible in this header.
  40. const int32_t kFusionDisableIndex = 99999;
  // String tags. kScopeInvalidType is the value CHECK_INNER_NODE_CONDITION passes to
  // SetType(); the other three are not referenced anywhere else in this header.
  41. const char *const kScopeToMultiNodes = "ScopeToMultiNodes";
  42. const char *const kScopeInvalidType = "ScopeInvalidType";
  43. const char *const kInputFromFusionScope = "InputFromFusionScope";
  44. const char *const kOutputToFusionScope = "OutputToFusionScope";
  // Forward declaration so the alias below can hold ScopePattern pointers.
  45. class ScopePattern;
  // Two-level list of raw ScopePattern pointers; ScopeUtil::FreeOneBatchPattern() calls one
  // inner vector "one batch". NOTE(review): ScopeUtil::FreeScopePatterns() presumably
  // releases these pointers - confirm ownership in the implementation.
  46. using ScopeFusionPatterns = std::vector<std::vector<ScopePattern *>>;
  // Forward declaration; referenced by friend declarations in ScopeGraph and ScopeBasePass.
  47. class ScopePassManager;
  // One scope node. Declares accessors for its name, sub-type, operator map, sub-scopes and
  // father scope. All state sits behind the private ScopeImpl pointer and every member
  // function is only declared here; the definitions are outside this header.
  //
  // Each ATTRIBUTED_DEPRECATED(...) line sits directly in front of a std::string-based
  // overload and spells out the const char * / AscendString overload declared right after
  // it. The macro is defined outside this header - presumably it flags the std::string
  // overload as deprecated in favor of the spelled-out one (confirm).
  //
  // No copy or move operations are declared: with the user-declared destructor and the
  // std::unique_ptr member, implicit copying is deleted and no implicit move is generated.
  48. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope {
  49. public:
  50. Scope();
  // Initialization is separate from construction and reports a Status.
  51. ATTRIBUTED_DEPRECATED(Status Init(const char *, const char *, Scope *))
  52. Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr);
  // Unlike the std::string overload above, sub_type has no default value here.
  53. Status Init(const char *name, const char *sub_type, Scope *father_scope = nullptr);
  54. ~Scope();
  // Accessor pairs: the std::string form returns the value directly, the AscendString form
  // writes into an out-parameter and returns a Status.
  55. ATTRIBUTED_DEPRECATED(Status Name(AscendString &) const)
  56. const std::string &Name() const;
  57. Status Name(AscendString &name) const;
  58. ATTRIBUTED_DEPRECATED(Status SubType(AscendString &) const)
  59. const std::string &SubType() const;
  60. Status SubType(AscendString &sub_type) const;
  // Map from a string key to ge::OperatorPtr; what the key holds is not visible here.
  61. ATTRIBUTED_DEPRECATED(Status AllNodesMap(std::unordered_map<AscendString, ge::OperatorPtr> &) const)
  62. const std::unordered_map<std::string, ge::OperatorPtr> &AllNodesMap() const;
  63. Status AllNodesMap(std::unordered_map<AscendString, ge::OperatorPtr> &node_map) const;
  // Returns a raw Scope pointer; the not-found behavior is not visible in this header.
  64. ATTRIBUTED_DEPRECATED(Scope *GetSubScope(const char *scope_name) const)
  65. Scope *GetSubScope(const std::string &scope_name) const;
  66. Scope *GetSubScope(const char *scope_name) const;
  // Returns a std::string by value, whereas Name() and SubType() return references.
  67. ATTRIBUTED_DEPRECATED(Status LastName(AscendString &) const)
  68. const std::string LastName() const;
  69. Status LastName(AscendString &name) const;
  70. const std::vector<Scope *> &GetAllSubScopes() const;
  71. const Scope *GetFatherScope() const;
  72. private:
  // Pimpl: ScopeImpl is only forward-declared in this header.
  73. class ScopeImpl;
  74. std::unique_ptr<ScopeImpl> impl_;
  // Friends get direct access to impl_.
  75. friend class ScopeBasePass;
  76. friend class ScopeTree;
  77. friend class NodeOpTypeFeature;
  78. friend class NodeAttrFeature;
  79. friend class ScopeFeature;
  80. };
  // Holder for a fusion result: setters for name, type and description, InsertInputs /
  // InsertOutputs taking an inner op name with an index map, and a list of InnerNodeInfo
  // entries. State sits behind FusionScopesResultImpl; definitions are outside this header.
  //
  // Each ATTRIBUTED_DEPRECATED(...) line sits directly in front of a std::string-based
  // overload and spells out the const char * / AscendString overload declared right after
  // it. The macro is defined outside this header - presumably it flags the std::string
  // overload as deprecated in favor of the spelled-out one (confirm).
  81. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult {
  82. public:
  83. FusionScopesResult();
  // Initialization is separate from construction and reports a Status.
  84. Status Init();
  85. ~FusionScopesResult();
  86. ATTRIBUTED_DEPRECATED(void SetName(const char *))
  87. void SetName(const std::string &name);
  88. void SetName(const char *name);
  // SetType() is what CHECK_INNER_NODE_CONDITION calls with ge::kScopeInvalidType.
  89. ATTRIBUTED_DEPRECATED(void SetType(const char *))
  90. void SetType(const std::string &type);
  91. void SetType(const char *type);
  92. ATTRIBUTED_DEPRECATED(void SetDescription(const char *))
  93. void SetDescription(const std::string &description);
  94. void SetDescription(const char *description);
  95. ATTRIBUTED_DEPRECATED(const Status Name(AscendString &) const)
  96. const std::string &Name() const;
  // The top-level const on the by-value Status return type has no effect for callers.
  97. const Status Name(AscendString &name) const;
  98. const std::vector<ge::OperatorPtr> &Nodes() const;
  // The meaning of index_map entries is not visible in this header.
  99. ATTRIBUTED_DEPRECATED(void InsertInputs(const char *, const std::vector<int32_t> &))
  100. void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  101. void InsertInputs(const char *inner_op_name, const std::vector<int32_t> &index_map);
  102. ATTRIBUTED_DEPRECATED(void InsertOutputs(const char *, const std::vector<int32_t> &))
  103. void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map);
  104. void InsertOutputs(const char *inner_op_name, const std::vector<int32_t> &index_map);
  // Description of one inner node of the fusion result. Move-only: the move operations are
  // declared noexcept and the copy operations are deleted.
  105. class InnerNodeInfo {
  106. public:
  107. ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char *))
  108. explicit InnerNodeInfo(const std::string &fusion_node_name);
  109. explicit InnerNodeInfo(const char *fusion_node_name);
  110. ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char *, const char *, const char *))
  111. InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type);
  112. InnerNodeInfo(const char *fusion_node_name, const char *name, const char *type);
  113. InnerNodeInfo(InnerNodeInfo &&other) noexcept;
  114. InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept;
  115. InnerNodeInfo(const InnerNodeInfo &) = delete;
  116. InnerNodeInfo &operator=(const InnerNodeInfo &) = delete;
  117. ~InnerNodeInfo();
  // The setters and Insert* functions below return InnerNodeInfo & (presumably *this, so
  // that calls can be chained - confirm in the implementation).
  118. ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetName(const char *))
  119. InnerNodeInfo &SetName(const std::string &name);
  120. InnerNodeInfo &SetName(const char *name);
  121. ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetType(const char *))
  122. InnerNodeInfo &SetType(const std::string &type);
  123. InnerNodeInfo &SetType(const char *type);
  124. ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertInput(const char *, int32_t))
  125. InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx);
  126. InnerNodeInfo &InsertInput(const char *input_node, int32_t peer_out_idx);
  127. ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertOutput(const char *, int32_t))
  128. InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx);
  129. InnerNodeInfo &InsertOutput(const char *output_node, int32_t peer_in_idx);
  // NOTE(review): presumably has to be called before the format setters and
  // MutableOperator() below are usable - confirm the required call order.
  130. ge::graphStatus BuildInnerNode();
  // Format setters take the format as a string and report a ge::graphStatus.
  131. ATTRIBUTED_DEPRECATED(ge::graphStatus SetInputFormat(const char *, const char *))
  132. ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format);
  133. ge::graphStatus SetInputFormat(const char *input_name, const char *format);
  134. ATTRIBUTED_DEPRECATED(ge::graphStatus SetOutputFormat(const char *, const char *))
  135. ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format);
  136. ge::graphStatus SetOutputFormat(const char *output_name, const char *format);
  137. ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicInputFormat(const char *, uint32_t index, const char *))
  138. ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format);
  139. ge::graphStatus SetDynamicInputFormat(const char *input_name, uint32_t index, const char *format);
  140. ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicOutputFormat(const char *, uint32_t, const char *))
  141. ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format);
  142. ge::graphStatus SetDynamicOutputFormat(const char *output_name, uint32_t index, const char *format);
  // Returns a raw ge::Operator pointer; ownership and null cases are not visible here.
  143. ge::Operator *MutableOperator();
  // Getter pairs: the std::string form returns by value, the AscendString form writes into
  // an out-parameter and returns a ge::graphStatus.
  144. ATTRIBUTED_DEPRECATED(ge::graphStatus GetName(AscendString &) const)
  145. std::string GetName() const;
  146. ge::graphStatus GetName(AscendString &name) const;
  147. ATTRIBUTED_DEPRECATED(ge::graphStatus GetType(AscendString &) const)
  148. std::string GetType() const;
  149. ge::graphStatus GetType(AscendString &type) const;
  150. ATTRIBUTED_DEPRECATED(ge::graphStatus GetInputs(std::vector<std::pair<AscendString, int32_t>> &) const)
  151. std::vector<std::pair<std::string, int32_t>> GetInputs() const;
  152. ge::graphStatus GetInputs(std::vector<std::pair<AscendString, int32_t>> &inputs) const;
  153. ATTRIBUTED_DEPRECATED(ge::graphStatus GetOutputs(std::vector<std::pair<AscendString, int32_t>> &) const)
  154. std::vector<std::pair<std::string, int32_t>> GetOutputs() const;
  155. ge::graphStatus GetOutputs(std::vector<std::pair<AscendString, int32_t>> &outputs) const;
  156. private:
  // Pimpl: InnerNodeInfoImpl is only forward-declared in this header.
  157. class InnerNodeInfoImpl;
  158. std::unique_ptr<InnerNodeInfoImpl> impl_;
  159. };
  // The functions below hand out raw InnerNodeInfo pointers. NOTE(review): presumably the
  // result object keeps ownership and nullptr signals failure - confirm in the implementation.
  160. ATTRIBUTED_DEPRECATED(InnerNodeInfo *AddInnerNode(const char *, const char *))
  161. InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type);
  162. InnerNodeInfo *AddInnerNode(const char *name, const char *type);
  163. InnerNodeInfo *MutableRecentInnerNode();
  164. InnerNodeInfo *MutableInnerNode(uint32_t index);
  165. ge::graphStatus CheckInnerNodesInfo();
  166. private:
  // Pimpl: FusionScopesResultImpl is only forward-declared in this header.
  167. class FusionScopesResultImpl;
  168. std::unique_ptr<FusionScopesResultImpl> impl_;
  // Friends get direct access to impl_. The unqualified TensorFlowModelParser here names a
  // class in namespace ge.
  169. friend class ScopeGraph;
  170. friend class ScopeBasePass;
  171. friend class TensorFlowModelParser;
  172. };
  // Container exposing a flat list of Scope pointers through GetAllScopes(). Non-copyable
  // (copy operations are deleted); state sits behind ScopeTreeImpl, defined outside this
  // header.
  173. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree {
  174. public:
  175. ScopeTree();
  // Initialization is separate from construction and reports a Status.
  176. Status Init();
  177. ScopeTree(const ScopeTree &scopetree) = delete;
  178. ScopeTree &operator=(const ScopeTree &scopetree) = delete;
  179. ~ScopeTree();
  // Returns a reference to a vector of raw Scope pointers; ownership is not visible here.
  180. const std::vector<Scope *> &GetAllScopes() const;
  181. private:
  182. class ScopeTreeImpl;
  183. std::unique_ptr<ScopeTreeImpl> impl_;
  // Friends get direct access to impl_.
  184. friend class ScopeGraph;
  185. friend class ScopeBasePass;
  186. };
  // Exposes a ScopeTree and a map of operators. Non-copyable (copy operations are deleted);
  // state sits behind ScopeGraphImpl, defined outside this header.
  //
  // The ATTRIBUTED_DEPRECATED(...) line sits directly in front of the std::string-based
  // GetNodesMap() and spells out the AscendString overload declared right after it. The
  // macro is defined outside this header - presumably it flags the std::string overload as
  // deprecated in favor of the spelled-out one (confirm).
  187. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph {
  188. public:
  189. ScopeGraph();
  // Initialization is separate from construction and reports a Status.
  190. Status Init();
  191. ScopeGraph(const ScopeGraph &scope_graph) = delete;
  192. ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete;
  193. ~ScopeGraph();
  // Returns a raw pointer to a const ScopeTree; ownership and null cases are not visible here.
  194. const ScopeTree *GetScopeTree() const;
  195. ATTRIBUTED_DEPRECATED(Status GetNodesMap(std::unordered_map<AscendString, ge::OperatorPtr> &) const)
  196. const std::unordered_map<std::string, ge::OperatorPtr> &GetNodesMap() const;
  197. Status GetNodesMap(std::unordered_map<AscendString, ge::OperatorPtr> &nodes_map) const;
  198. private:
  199. class ScopeGraphImpl;
  200. std::unique_ptr<ScopeGraphImpl> impl_;
  // Friends get direct access to impl_. The unqualified TensorFlowModelParser here names a
  // class in namespace ge.
  201. friend class ScopePassManager;
  202. friend class ScopeBasePass;
  203. friend class TensorFlowModelParser;
  204. };
  // Attribute value with setters for int64_t, float, string and bool. Copy construction and
  // copy assignment are declared explicitly (defined outside this header). No getters are
  // declared; the friend NodeAttrFeature has direct access to impl_.
  205. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue {
  206. public:
  207. ScopeAttrValue();
  208. ScopeAttrValue(ScopeAttrValue const &attr_value);
  209. ScopeAttrValue &operator=(ScopeAttrValue const &attr_value);
  210. ~ScopeAttrValue();
  // Whether one object can hold several value kinds at once is not visible in this header.
  211. void SetIntValue(int64_t value);
  212. void SetFloatValue(float value);
  // The ATTRIBUTED_DEPRECATED line precedes the std::string overload and names the
  // const char * overload; the macro is defined outside this header - presumably it marks
  // the std::string overload as deprecated (confirm).
  213. ATTRIBUTED_DEPRECATED(void SetStringValue(const char *))
  214. void SetStringValue(std::string value);
  215. void SetStringValue(const char *value);
  216. void SetBoolValue(bool value);
  217. private:
  218. class ScopeAttrValueImpl;
  219. std::unique_ptr<ScopeAttrValueImpl> impl_;
  220. friend class NodeAttrFeature;
  221. };
  // Abstract interface for features that can be matched against a Scope.
  222. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature {
  223. public:
  // Pure virtual; non-const, so implementations may change their own state while matching.
  224. virtual bool Match(const Scope *scope) = 0;
  // Virtual destructor with an empty body; the `;` after the body is a redundant empty
  // declaration.
  225. virtual ~ScopeBaseFeature(){};
  226. };
  // Feature built from a node type string plus `num` and `step`; the meaning of `num` and
  // `step` is not documented in this header. Copyable (copy operations are declared).
  // NOTE(review): the base class is listed without an access specifier, which for `class`
  // means private inheritance - code outside this class and its friends cannot convert a
  // NodeOpTypeFeature to ScopeBaseFeature. Confirm this is intended.
  227. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature {
  228. public:
  // The ATTRIBUTED_DEPRECATED line precedes the std::string constructor and names the
  // const char * constructor; the macro is defined outside this header - presumably it
  // marks the std::string constructor as deprecated (confirm).
  229. ATTRIBUTED_DEPRECATED(NodeOpTypeFeature(const char *, int, int))
  230. NodeOpTypeFeature(std::string nodeType, int num, int step = 0);
  231. NodeOpTypeFeature(const char *node_type, int num, int step = 0);
  232. NodeOpTypeFeature(NodeOpTypeFeature const &feature);
  233. NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature);
  234. ~NodeOpTypeFeature();
  // Overrides ScopeBaseFeature::Match.
  235. bool Match(const Scope *scope) override;
  236. private:
  // Pimpl: NodeOpTypeFeatureImpl is only forward-declared in this header.
  237. class NodeOpTypeFeatureImpl;
  238. std::unique_ptr<NodeOpTypeFeatureImpl> impl_;
  239. };
  // Feature built from a node type, an attribute name, a ge::DataType and a ScopeAttrValue.
  // Copyable (copy operations are declared).
  // NOTE(review): the base class is listed without an access specifier, which for `class`
  // means private inheritance - code outside this class and its friends cannot convert a
  // NodeAttrFeature to ScopeBaseFeature. Confirm this is intended.
  240. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature {
  241. public:
  // The ATTRIBUTED_DEPRECATED line precedes the std::string constructor and names the
  // const char * constructor; the macro is defined outside this header - presumably it
  // marks the std::string constructor as deprecated (confirm).
  // attr_value is taken by non-const reference in both constructors; whether it is modified
  // or only read is not visible here.
  242. ATTRIBUTED_DEPRECATED(NodeAttrFeature(const char *, const char *, ge::DataType, ScopeAttrValue &))
  243. NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value);
  244. NodeAttrFeature(const char *node_type, const char *attr_name, ge::DataType datatype, ScopeAttrValue &attr_value);
  245. NodeAttrFeature(NodeAttrFeature const &feature);
  246. NodeAttrFeature &operator=(NodeAttrFeature const &feature);
  247. ~NodeAttrFeature();
  // Overrides ScopeBaseFeature::Match.
  248. bool Match(const Scope *scope) override;
  249. private:
  // Pimpl: NodeAttrFeatureImpl is only forward-declared in this header.
  250. class NodeAttrFeatureImpl;
  251. std::unique_ptr<NodeAttrFeatureImpl> impl_;
  252. };
  // Feature built from a sub-type, a count, a suffix, a sub-scope mask and a step; the exact
  // matching rules are not documented in this header. Copyable (copy operations are declared).
  // NOTE(review): the base class is listed without an access specifier, which for `class`
  // means private inheritance - code outside this class and its friends cannot convert a
  // ScopeFeature to ScopeBaseFeature. Confirm this is intended.
  253. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature {
  254. public:
  // The ATTRIBUTED_DEPRECATED line precedes the std::string constructor and names the
  // const char * constructor; the macro is defined outside this header - presumably it
  // marks the std::string constructor as deprecated (confirm).
  255. ATTRIBUTED_DEPRECATED(ScopeFeature(const char *, int32_t, const char *, const char *, int))
  256. ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "",
  257. int step = 0);
  // Unlike the std::string constructor, suffix and sub_scope_mask have no defaults here.
  258. ScopeFeature(const char *sub_type, int32_t num, const char *suffix, const char *sub_scope_mask, int step = 0);
  259. ScopeFeature(ScopeFeature const &feature);
  260. ScopeFeature &operator=(ScopeFeature const &feature);
  261. ~ScopeFeature();
  // Overrides ScopeBaseFeature::Match.
  262. bool Match(const Scope *scope) override;
  263. private:
  // Pimpl: ScopeFeatureImpl is only forward-declared in this header.
  264. class ScopeFeatureImpl;
  265. std::unique_ptr<ScopeFeatureImpl> impl_;
  266. };
  // A pattern assembled from a sub-type plus node-op-type, node-attr and scope features.
  // No copy or move operations are declared: with the user-declared destructor and the
  // std::unique_ptr member, implicit copying is deleted and no implicit move is generated.
  267. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern {
  268. public:
  269. ScopePattern();
  270. ~ScopePattern();
  // The ATTRIBUTED_DEPRECATED line precedes the std::string overload and names the
  // const char * overload; the macro is defined outside this header - presumably it marks
  // the std::string overload as deprecated (confirm).
  // All four functions below return ScopePattern & (presumably *this, so that calls can be
  // chained - confirm in the implementation).
  271. ATTRIBUTED_DEPRECATED(ScopePattern &SetSubType(const char *))
  272. ScopePattern &SetSubType(const std::string &sub_type);
  273. ScopePattern &SetSubType(const char *sub_type);
  // Features are taken by value, so the argument is copied (or moved) at each call.
  274. ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature);
  275. ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature);
  276. ScopePattern &AddScopeFeature(ScopeFeature feature);
  277. private:
  // Pimpl: ScopePatternImpl is only forward-declared in this header.
  278. class ScopePatternImpl;
  279. std::unique_ptr<ScopePatternImpl> impl_;
  280. friend class ScopeBasePass;
  281. };
  // Result carrier with setters for a list of Scope pointers and a list of operators.
  // Copyable (copy operations are declared). No getters are declared; the friend
  // ScopeBasePass has direct access to impl_.
  282. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult {
  283. public:
  284. ScopesResult();
  285. ScopesResult(ScopesResult const &result);
  286. ScopesResult &operator=(ScopesResult const &result);
  287. ~ScopesResult();
  // Both setters take non-const lvalue references, so temporaries cannot be passed; whether
  // the arguments are modified is not visible in this header.
  288. void SetScopes(std::vector<Scope *> &scopes);
  289. void SetNodes(std::vector<ge::OperatorPtr> &nodes);
  290. private:
  291. class ScopesResultImpl;
  292. std::unique_ptr<ScopesResultImpl> impl_;
  293. friend class ScopeBasePass;
  294. };
  // Abstract base class for scope fusion passes. Subclasses override the four protected pure
  // virtual functions below. Because those hooks are protected, outside code cannot call
  // them directly; presumably they are driven by the friends ge::ScopePassManager and
  // ScopeBasePassImpl (confirm in the implementation).
  295. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass {
  296. public:
  297. ScopeBasePass();
  298. virtual ~ScopeBasePass();
  299. protected:
  300. // Subclasses implement respective fusion strategies and build the Patterns
  301. virtual std::vector<ScopeFusionPatterns> DefinePatterns() = 0;
  302. // Define the name of the scope pass
  303. virtual std::string PassName() = 0;
  304. // Subclasses implement respective multi-scope or operator fusion methods across scopes
  305. virtual Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph,
  306. std::vector<ScopesResult> &results) = 0;
  307. // Subclasses implement their own results and set the input and output of the final fusion operator
  // Returns void, so CHECK_INNER_NODE_CONDITION (which expands to a bare `return;`) can be
  // used inside overrides of this function.
  308. virtual void GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) = 0;
  309. private:
  // Pimpl: ScopeBasePassImpl is only forward-declared in this header.
  310. class ScopeBasePassImpl;
  311. std::unique_ptr<ScopeBasePassImpl> impl_;
  312. friend class ge::ScopePassManager;
  313. friend class ScopeBasePassImpl;
  314. };
  // Process-wide registry of scope fusion passes. The constructor is private; the only way
  // to obtain an object is GetInstance(), which returns a function-local static instance.
  315. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry {
  316. public:
  // Factory signature: a function taking no arguments and returning a raw ScopeBasePass *.
  317. using CreateFn = ScopeBasePass *(*)();
  318. ~ScopeFusionPassRegistry();
  // The instance is created on the first call and the same object is returned afterwards.
  319. static ScopeFusionPassRegistry &GetInstance() {
  320. static ScopeFusionPassRegistry instance;
  321. return instance;
  322. }
  // The ATTRIBUTED_DEPRECATED line precedes the std::string overload and names the
  // const char * overload; the macro is defined outside this header - presumably it marks
  // the std::string overload as deprecated (confirm).
  // The effect of is_general and the handling of duplicate names are not visible here.
  323. ATTRIBUTED_DEPRECATED(void RegisterScopeFusionPass(const char *, CreateFn, bool))
  324. void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general);
  325. void RegisterScopeFusionPass(const char *pass_name, CreateFn create_fn, bool is_general);
  326. private:
  327. ScopeFusionPassRegistry();
  // Pimpl: ScopeFusionPassRegistryImpl is only forward-declared in this header.
  328. class ScopeFusionPassRegistryImpl;
  329. std::unique_ptr<ScopeFusionPassRegistryImpl> impl_;
  // The unqualified TensorFlowModelParser here names a class in namespace ge.
  330. friend class TensorFlowModelParser;
  331. };
  // Collection of static helper functions; the class declares no data members.
  332. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil {
  333. public:
  // The ATTRIBUTED_DEPRECATED line precedes the std::string overload and names the
  // const char * overload; the macro is defined outside this header - presumably it marks
  // the std::string overload as deprecated (confirm).
  // The std::string overload takes `str` by value and returns a std::string; the
  // const char * overload returns an AscendString.
  334. ATTRIBUTED_DEPRECATED(static AscendString StringReplaceAll(const char *, const char *, const char *))
  335. static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value);
  336. static AscendString StringReplaceAll(const char *str, const char *old_value, const char *new_value);
  // NOTE(review): judging by the names, these release the raw ScopePattern pointers held in
  // the containers - confirm against the implementation, including whether the containers
  // are cleared afterwards.
  337. static void FreeScopePatterns(ScopeFusionPatterns &patterns);
  338. static void FreeOneBatchPattern(std::vector<ScopePattern *> &one_batch_pattern);
  339. };
  // Helper whose constructor takes a pass name, a factory function pointer and the
  // is_general flag. The REGISTER_SCOPE_FUSION_PASS macros create static objects of this
  // type. NOTE(review): the constructor presumably forwards its arguments to
  // ScopeFusionPassRegistry::RegisterScopeFusionPass() - confirm in the implementation.
  340. class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar {
  341. public:
  342. ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general);
  // Empty inline destructor.
  343. ~ScopeFusionPassRegistrar() {}
  344. };
  // Registration macros. REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general)
  // defines a static ::ge::ScopeFusionPassRegistrar object whose factory is a capture-less
  // lambda returning `new (std::nothrow) scope_pass()`, i.e. nullptr when allocation fails.
  // The names are fully qualified with ::ge::, so the macro can be used from any namespace.
  // __COUNTER__ and __attribute__((unused)) are compiler extensions, not standard C++.
  // NOTE(review): std::nothrow is declared in <new>, which is not among the includes
  // visible in this header - presumably it is pulled in transitively; confirm.
  345. #define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \
  346. REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general)
  // Extra expansion level: lets __COUNTER__ expand to a number before the `##` paste below.
  347. #define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \
  348. REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general)
  // Pastes the counter value onto the variable name so each use gets a distinct identifier.
  349. #define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \
  350. static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \
  351. ::ge::ScopeFusionPassRegistrar( \
  352. pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general)
  353. } // namespace ge
  354. #endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。