You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

register.h 6.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
(line numbers 1–161)
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
  17. #define INC_EXTERNAL_REGISTER_REGISTER_H_
  18. #include <functional>
  19. #include <initializer_list>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <utility>
  25. #include <unordered_map>
  26. #include <vector>
  27. #include "graph/operator.h"
  28. #include "register/register_error_codes.h"
  29. #include "register/register_fmk_types.h"
  30. #include "register/register_types.h"
  31. using std::make_shared;
  32. using std::map;
  33. using std::pair;
  34. using std::string;
  35. using std::to_string;
  36. using std::unique_ptr;
  37. using std::vector;
  // Forward declarations of graph-engine (ge) types named in this header.
  // Declaring them here lets the declarations below mention these types
  // without this header depending on their full definitions.
  // NOTE(review): Tensor and TensorDesc are not referenced anywhere else in
  // this header; they may be kept only for includers' convenience.
  namespace ge {
  class Operator;
  class Tensor;
  class TensorDesc;
  class TBEPluginManager;  // granted friend access by domi::OpRegistrationData
  }  // namespace ge
  // Forward declaration of the protobuf message base class. The parse
  // callbacks and AutoMappingFn* below only take `const Message *`, so the
  // protobuf headers are not needed here.
  namespace google {
  namespace protobuf {
  class Message;
  }  // namespace protobuf
  }  // namespace google
  49. namespace domi {
  50. const int64_t kMaxNameLength = 1048576; // 1M
  /// Which side of an operator a DynamicInputOutputInfo entry describes.
  enum DynamicType {
    kInvalid = 0,  // unset; this is what the default constructor stores
    kInput = 1,    // an input port
    kOutput = 2    // an output port
  };

  /// Describes one dynamic input or output port and an attribute name paired
  /// with it. A list of these is passed to AutoMappingByOpFnDynamic.
  ///
  /// Both names are stored as a raw pointer plus an explicit length. The
  /// struct neither copies nor owns the characters, so they must outlive every
  /// use of the struct.
  /// NOTE(review): whether the names must be NUL-terminated, and whether the
  /// lengths are checked against kMaxNameLength, is not visible here — confirm
  /// in the implementation.
  struct DynamicInputOutputInfo {
    DynamicType type = kInvalid;      // input or output side
    const char *port_name = nullptr;  // port name characters (not owned)
    int64_t port_name_len = 0;        // presumably the length of port_name
    const char *attr_name = nullptr;  // attribute name characters (not owned)
    int64_t attr_name_len = 0;        // presumably the length of attr_name

    /// Produces an empty entry: kInvalid type, null names, zero lengths.
    DynamicInputOutputInfo() = default;

    /// Stores all five fields exactly as given; nothing is copied or validated.
    DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name,
                           int64_t attr_name_len)
        : type(type),
          port_name(port_name),
          port_name_len(port_name_len),
          attr_name(attr_name),
          attr_name_len(attr_name_len) {}
  };
  68. Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op);
  69. Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op,
  70. const vector<DynamicInputOutputInfo> &dynamic_name_attr_value);
  71. Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
  72. Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
  73. std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
  74. int in_pos = -1, int out_pos = -1);
  75. Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input,
  76. const std::function<int(int netoutput_index)> &output);
  77. Status AutoMappingSubgraphIndex(const ge::Graph &graph,
  78. const std::function<Status(int data_index, int &parent_input_index)> &input,
  79. const std::function<Status(int netoutput_index, int &parent_output_index)> &output);
  80. using google::protobuf::Message;
  81. class OpRegistrationDataImpl;
  82. using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
  83. using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>;
  84. using FusionParseParamFunc =
  85. std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
  86. using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
  87. using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;
  88. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
  89. public:
  90. OpRegistrationData(const std::string &om_optype);
  91. ~OpRegistrationData();
  92. OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);
  93. OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);
  94. OpRegistrationData &OriginOpType(const std::string &ori_optype);
  95. OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);
  96. OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn);
  97. OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);
  98. OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn);
  99. OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn);
  100. OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);
  101. OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);
  102. OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);
  103. OpRegistrationData &InputReorderVector(const vector<int> &input_order);
  104. domi::ImplyType GetImplyType() const;
  105. std::string GetOmOptype() const;
  106. std::set<std::string> GetOriginOpTypeSet() const;
  107. domi::FrameworkType GetFrameworkType() const;
  108. ParseParamFunc GetParseParamFn() const;
  109. ParseParamByOpFunc GetParseParamByOperatorFn() const;
  110. FusionParseParamFunc GetFusionParseParamFn() const;
  111. FusionParseParamByOpFunc GetFusionParseParamByOpFn() const;
  112. ParseSubgraphFunc GetParseSubgraphPostFn() const;
  113. private:
  114. std::shared_ptr<OpRegistrationDataImpl> impl_;
  115. friend class OpRegistry;
  116. friend class OpRegistrationTbe;
  117. friend class ge::TBEPluginManager;
  118. };
  119. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
  120. public:
  121. OpReceiver(OpRegistrationData &reg_data);
  122. ~OpReceiver() {}
  123. };
  // REGISTER_CUSTOM_OP(name) defines a uniquely named static OpReceiver
  // initialized from OpRegistrationData(name). The expansion deliberately stops
  // right after that constructor call. The caller chains setters and supplies
  // the final semicolon:
  //
  //   REGISTER_CUSTOM_OP("MyOp")
  //       .FrameworkType(...)
  //       .OriginOpType("...")
  //       .ParseParamsFn(...);
  //
  // At least one setter must be chained. OpReceiver's constructor takes a
  // non-const lvalue reference, which cannot bind to the bare temporary.
  //
  // The types are written fully qualified (::domi::...) because a macro expands
  // in the caller's scope. With unqualified names, the macro compiled only
  // inside namespace domi or ge (or under a using-directive). Qualification
  // keeps those uses working and also allows registration from any namespace.
  //
  // __COUNTER__ produces a fresh integer on each expansion. The extra
  // ..._HELPER level makes the preprocessor expand it before ## pastes it onto
  // `register_op`. __attribute__((unused)) silences unused-variable warnings
  // (GCC/Clang syntax).
  #define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name)
  #define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name)
  #define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \
    static ::domi::OpReceiver register_op##ctr __attribute__((unused)) = ::domi::OpRegistrationData(name)
  128. } // namespace domi
  129. namespace ge {
  130. using OpRegistrationData = domi::OpRegistrationData;
  131. using OpReceiver = domi::OpReceiver;
  132. } // namespace ge
  133. #endif // INC_EXTERNAL_REGISTER_REGISTER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：