You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

register.h 8.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
  17. #define INC_EXTERNAL_REGISTER_REGISTER_H_
  18. #include <functional>
  19. #include <initializer_list>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <utility>
  25. #include <unordered_map>
  26. #include <vector>
  27. #include "graph/operator.h"
  28. #include "register/register_error_codes.h"
  29. #include "register/register_fmk_types.h"
  30. #include "register/register_types.h"
  31. using std::make_shared;
  32. using std::map;
  33. using std::pair;
  34. using std::string;
  35. using std::to_string;
  36. using std::unique_ptr;
  37. using std::vector;
  38. namespace ge {
  39. class Operator;
  40. class TensorDesc;
  41. class Tensor;
  42. class TBEPluginManager;
  43. } // namespace ge
  44. namespace google {
  45. namespace protobuf {
  46. class Message;
  47. }
  48. } // namespace google
  49. namespace domi {
  // Name-length limit constant: 1048576 (1M).
  // NOTE(review): its consumer is not visible in this header; presumably it bounds
  // port_name_len / attr_name_len below — confirm against the implementation.
  const int64_t kMaxNameLength = 1048576;  // 1M

  // Kind of a dynamic port described by DynamicInputOutputInfo.
  // Kept as an unscoped enum: callers refer to kInvalid / kInput / kOutput unqualified.
  enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 };

  /**
   * Describes one dynamic input or output port together with an attribute name.
   * A vector of these is what AutoMappingByOpFnDynamic accepts.
   *
   * The struct stores the raw pointers only; it never copies or frees the characters,
   * so the caller must keep port_name / attr_name alive while the struct is in use.
   */
  struct DynamicInputOutputInfo {
    DynamicType type = kInvalid;      // input or output
    const char *port_name = nullptr;  // not owned
    int64_t port_name_len = 0;        // length paired with port_name (terminator handling: TODO confirm)
    const char *attr_name = nullptr;  // not owned
    int64_t attr_name_len = 0;        // length paired with attr_name

    // All fields take the default member initializers above (kInvalid / nullptr / 0).
    DynamicInputOutputInfo() = default;

    // Parameter names differ from the member names on purpose: the previous spelling
    // (`type(type)`, `port_name(port_name)`, ...) shadowed every member and tripped -Wshadow
    // in every translation unit that includes this public header.
    DynamicInputOutputInfo(DynamicType dynamic_type, const char *port, int64_t port_len, const char *attr,
                           int64_t attr_len)
        : type(dynamic_type), port_name(port), port_name_len(port_len), attr_name(attr), attr_name_len(attr_len) {}
  };
  68. Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op);
  69. Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op,
  70. const vector<DynamicInputOutputInfo> &dynamic_name_attr_value);
  71. ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFn(const ge::Operator &, ge::Operator &))
  72. Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
  73. ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFnDynamic(const ge::Operator &, ge::Operator &,
  74. const vector<DynamicInputOutputInfo> &))
  75. Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
  76. std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
  77. int in_pos = -1, int out_pos = -1);
  78. Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input,
  79. const std::function<int(int netoutput_index)> &output);
  80. Status AutoMappingSubgraphIndex(const ge::Graph &graph,
  81. const std::function<Status(int data_index, int &parent_input_index)> &input,
  82. const std::function<Status(int netoutput_index, int &parent_output_index)> &output);
  83. using google::protobuf::Message;
  84. class OpRegistrationDataImpl;
  85. using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
  86. using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
  87. using FusionParseParamFunc =
  88. std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
  89. using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
  90. using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;
  91. using ParseOpToGraphFunc = std::function<Status(const ge::Operator &, ge::Graph &)>;
  92. using ParseSubgraphFuncV2 = std::function<Status(const ge::AscendString &subgraph_name, const ge::Graph &graph)>;
  93. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
  94. public:
  95. ATTRIBUTED_DEPRECATED(OpRegistrationData(const char *))
  96. OpRegistrationData(const std::string &om_optype);
  97. OpRegistrationData(const char *om_optype);
  98. ~OpRegistrationData();
  99. OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);
  100. ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const std::vector<ge::AscendString> &))
  101. OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);
  102. OpRegistrationData &OriginOpType(const std::vector<ge::AscendString> &ori_op_type_list);
  103. ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const char *))
  104. OpRegistrationData &OriginOpType(const std::string &ori_optype);
  105. OpRegistrationData &OriginOpType(const char *ori_op_type);
  106. OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);
  107. OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn);
  108. OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);
  109. OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn);
  110. ATTRIBUTED_DEPRECATED(OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &))
  111. OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn);
  112. OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &subgraph_post_fn);
  113. OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);
  114. ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithCond(int, const char *, bool))
  115. OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);
  116. OpRegistrationData &DelInputWithCond(int input_idx, const char *attr_name, bool attr_value);
  117. ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithOriginalType(int, const char *))
  118. OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);
  119. OpRegistrationData &DelInputWithOriginalType(int input_idx, const char *ori_type);
  120. OpRegistrationData &InputReorderVector(const vector<int> &input_order);
  121. OpRegistrationData &ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn);
  122. domi::ImplyType GetImplyType() const;
  123. ATTRIBUTED_DEPRECATED(Status GetOmOptype(ge::AscendString &) const)
  124. std::string GetOmOptype() const;
  125. Status GetOmOptype(ge::AscendString &om_op_type) const;
  126. ATTRIBUTED_DEPRECATED(GetOriginOpTypeSet(std::set<ge::AscendString> &) const)
  127. std::set<std::string> GetOriginOpTypeSet() const;
  128. Status GetOriginOpTypeSet(std::set<ge::AscendString> &ori_op_type) const;
  129. domi::FrameworkType GetFrameworkType() const;
  130. ParseParamFunc GetParseParamFn() const;
  131. ParseParamByOpFunc GetParseParamByOperatorFn() const;
  132. FusionParseParamFunc GetFusionParseParamFn() const;
  133. FusionParseParamByOpFunc GetFusionParseParamByOpFn() const;
  134. ParseSubgraphFunc GetParseSubgraphPostFn() const;
  135. ParseOpToGraphFunc GetParseOpToGraphFn() const;
  136. Status GetParseSubgraphPostFn(ParseSubgraphFuncV2 &func) const;
  137. private:
  138. std::shared_ptr<OpRegistrationDataImpl> impl_;
  139. friend class OpRegistry;
  140. friend class OpRegistrationTbe;
  141. friend class ge::TBEPluginManager;
  142. };
  143. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
  144. public:
  145. OpReceiver(OpRegistrationData &reg_data);
  146. ~OpReceiver() {}
  147. };
  // REGISTER_CUSTOM_OP(op_type) expands to the start of a definition of a file-static OpReceiver
  // initialized from OpRegistrationData(op_type). The expansion has no trailing semicolon, so callers
  // append chained setters and the semicolon, e.g. REGISTER_CUSTOM_OP("X").FrameworkType(...);
  // At least one chained setter is needed in practice: OpReceiver's constructor takes a non-const
  // OpRegistrationData&, which cannot bind to the bare OpRegistrationData temporary.
  //
  // __COUNTER__ yields a distinct suffix per expansion, so several registrations can share one file.
  // The extra helper level makes __COUNTER__ expand to a number before it reaches ## pasting.
  #define REGISTER_CUSTOM_OP(op_type) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, op_type)
  #define REGISTER_CUSTOM_OP_UNIQ_HELPER(uid, op_type) REGISTER_CUSTOM_OP_UNIQ(uid, op_type)
  #define REGISTER_CUSTOM_OP_UNIQ(uid, op_type) \
    static OpReceiver register_op##uid __attribute__((unused)) = OpRegistrationData(op_type)
  152. } // namespace domi
  153. namespace ge {
  154. using OpRegistrationData = domi::OpRegistrationData;
  155. using OpReceiver = domi::OpReceiver;
  156. } // namespace ge
  157. #endif // INC_EXTERNAL_REGISTER_REGISTER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户无需感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：