You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

register.h 8.0 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
  17. #define INC_EXTERNAL_REGISTER_REGISTER_H_
  18. #include <google/protobuf/message.h>
  19. #include <functional>
  20. #include <initializer_list>
  21. #include <map>
  22. #include <memory>
  23. #include <set>
  24. #include <string>
  25. #include <utility>
  26. #include <unordered_map>
  27. #include <vector>
  28. #include "graph/operator.h"
  29. #include "register/register_error_codes.h"
  30. #include "register/register_fmk_types.h"
  31. #include "register/register_types.h"
  32. using std::unique_ptr;
  33. using std::map;
  34. using std::make_shared;
  35. using std::to_string;
  36. using std::string;
  37. using std::pair;
  38. using std::vector;
  // Forward declarations of the GE graph types used by the registration API.
  // TBEPluginManager must be declared here because OpRegistrationData names it
  // in a qualified friend declaration (friend class ge::TBEPluginManager), and a
  // qualified name has to refer to an existing declaration.
  namespace ge {
  class Operator;
  class Tensor;
  class TensorDesc;
  class TBEPluginManager;
  }  // namespace ge
  45. namespace domi {
  46. struct OpOutput {
  47. ge::Operator op;
  48. // The output name of op
  49. std::string outputName;
  50. };
  51. struct InferShapeContext {
  52. ge::Operator op;
  53. // Input name, input
  54. std::map<std::string, OpOutput> inputs;
  55. };
  56. struct InferShapeOutput {
  57. std::vector<ge::TensorDesc> outputDescs;
  58. std::vector<uint32_t> realDimCnt;
  59. };
  // Kind of input rewrite stored in MoveInputToAttrStu::moveType; passed to
  // OpRegistrationData::MoveInputToAttr(). The values are written out
  // explicitly and match the implicit numbering 0..5. OMG_REMOVE_TYPE_WITH_COND
  // is set apart at 1000.
  // NOTE(review): presumably used by DelInputWithCond() — confirm in the .cpp.
  enum OmgMoveTypeToAttr {
    OMG_MOVE_TYPE_DTYPE = 0,
    OMG_MOVE_TYPE_VALUE = 1,
    OMG_MOVE_TYPE_SHAPE = 2,
    OMG_MOVE_TYPE_FORMAT = 3,
    OMG_MOVE_TYPE_AXIS = 4,
    OMG_MOVE_TYPE_SCALAR_VALUE = 5,
    OMG_REMOVE_TYPE_WITH_COND = 1000,
  };
  69. struct MoveInputToAttrStu {
  70. int inputIdx;
  71. std::string attrName;
  72. OmgMoveTypeToAttr moveType;
  73. bool attrValue;
  74. };
  75. Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
  76. Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
  77. std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
  78. int in_pos = -1, int out_pos = -1);
  79. using google::protobuf::Message;
  80. using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
  81. using InferShapeFunc = std::function<domi::Status(const ge::Operator &, std::vector<ge::TensorDesc> &)>;
  82. using InferShapeFuncV2 = std::function<domi::Status(const InferShapeContext &, InferShapeOutput &)>;
  83. using GetWorkspaceSizeFunc = std::function<domi::Status(const ge::Operator &, std::vector<int64_t> &)>;
  84. using UpdateOpDescFunc = std::function<domi::Status(ge::Operator &)>;
  85. using BuildTeBinFunc = std::function<domi::Status(const ge::Operator &, TEBinInfo &)>;
  86. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
  87. public:
  88. OpRegistrationData(const std::string &om_optype);
  89. ~OpRegistrationData();
  90. OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);
  91. OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);
  92. OpRegistrationData &OriginOpType(const std::string &ori_optype);
  93. OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);
  94. OpRegistrationData &InferShapeAndTypeFn(const InferShapeFunc &inferShapeFn);
  95. OpRegistrationData &InferShapeAndTypeFn(const InferShapeFuncV2 &inferShapeFn);
  96. OpRegistrationData &UpdateOpDescFn(const UpdateOpDescFunc &updateOpDescFn);
  97. OpRegistrationData &GetWorkspaceSizeFn(const GetWorkspaceSizeFunc &getWorkspaceSizeFn);
  98. OpRegistrationData &TEBinBuildFn(const BuildTeBinFunc &buildTeBinFn);
  99. OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);
  100. OpRegistrationData &Formats(const std::initializer_list<domi::tagDomiTensorFormat> &input_formats,
  101. const std::initializer_list<domi::tagDomiTensorFormat> &output_formats);
  102. OpRegistrationData &WeightFormats(const std::initializer_list<domi::tagDomiTensorFormat> &weight_formats);
  103. OpRegistrationData &InputFormat(const std::initializer_list<std::initializer_list<ge::Format>> &inputFormats);
  104. OpRegistrationData &OutputFormat(const std::initializer_list<std::initializer_list<ge::Format>> &outputFormats);
  105. OpRegistrationData &InputDataType(const std::initializer_list<std::initializer_list<ge::DataType>> &inputDataTypes);
  106. OpRegistrationData &OutputDataType(const std::initializer_list<std::initializer_list<ge::DataType>> &outputDataTypes);
  107. OpRegistrationData &InputLimitedTensorDescInfo(
  108. const std::initializer_list<std::initializer_list<ge::TensorDescInfo>> &limitedTensorDescs);
  109. OpRegistrationData &OutputLimitedTensorDescInfo(
  110. const std::initializer_list<std::initializer_list<ge::TensorDescInfo>> &limitedTensorDescs);
  111. OpRegistrationData &MoveInputToAttr(int inputIdx, const std::string &attrName, OmgMoveTypeToAttr moveType);
  112. OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);
  113. private:
  114. domi::FrameworkType fmk_type_; // Framework type
  115. std::set<std::string> ori_optype_set_; // OP type in the original model, there may be multiple
  116. std::string om_optype_; // OP type in OM model
  117. domi::ImplyType imply_type_; // Execution type
  118. std::vector<domi::tagDomiTensorFormat> input_formats_; // Data formats supported by operator input
  119. std::vector<domi::tagDomiTensorFormat> output_formats_; // Data formats supported by operator output
  120. std::vector<domi::tagDomiTensorFormat> weight_formats_; // Data format supported by operator weight
  121. ParseParamFunc parseParamFn_; // ParseParam function
  122. InferShapeFunc inferShapeFn_; // InferShape function
  123. InferShapeFuncV2 inferShapeFnV2_; // InferShape function
  124. GetWorkspaceSizeFunc getWorkspaceSizeFn_; // GetWorkspaceSizeFunc function
  125. UpdateOpDescFunc updateOpDescFn_;
  126. BuildTeBinFunc buildTeBinFn_;
  127. // Input formats list supported by tbe operators
  128. std::vector<std::vector<ge::Format>> supportedInputFormats_;
  129. // Output formats list supported by tbe operators
  130. std::vector<std::vector<ge::Format>> supportedOutputFormats_;
  131. // Input datatypes list supported by tbe operators
  132. std::vector<std::vector<ge::DataType>> supportedInputDataTypes_;
  133. // Output datatypes list supported by tbe operators
  134. std::vector<std::vector<ge::DataType>> supportedOutputDataTypes_;
  135. // Input tensordesinfo list supported by tbe operator
  136. std::vector<std::vector<ge::TensorDescInfo>> inputLimitedTensorDescs_;
  137. // Output tensordesinfo list supported by tbe operator
  138. std::vector<std::vector<ge::TensorDescInfo>> outputLimitedTensorDescs_;
  139. std::vector<MoveInputToAttrStu> moveInputToAttrVec_;
  140. friend class OpRegistry;
  141. friend class OpRegistrationTbe;
  142. friend class ge::TBEPluginManager;
  143. };
  144. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
  145. public:
  146. OpReceiver(OpRegistrationData &reg_data);
  147. ~OpReceiver() {}
  148. };
  // REGISTER_CUSTOM_OP(name) defines a static OpReceiver variable initialized
  // from OpRegistrationData(name). The caller appends the setter chain and the
  // terminating semicolon, e.g.
  //   REGISTER_CUSTOM_OP("MyOp").FrameworkType(...).OriginOpType("MyOp");
  // At least one setter call is required: a bare OpRegistrationData(name)
  // temporary cannot bind to OpReceiver's non-const reference parameter.
  // __COUNTER__ gives every expansion a unique variable name. The extra
  // _UNIQ_HELPER level makes __COUNTER__ expand before ## pastes it.
  // The types are fully qualified so the macro also works outside namespace
  // domi; ge::OpReceiver / ge::OpRegistrationData alias the same types, so
  // existing uses from domi or ge are unaffected.
  #define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name)
  #define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name)
  #define REGISTER_CUSTOM_OP_UNIQ(ctr, name)   \
    static ::domi::OpReceiver register_op##ctr \
        __attribute__((unused)) =              \
            ::domi::OpRegistrationData(name)
  155. } // namespace domi
  156. namespace ge {
  157. using OpOutput = domi::OpOutput;
  158. using InferShapeContext = domi::InferShapeContext;
  159. using InferShapeOutput = domi::InferShapeOutput;
  160. using OmgMoveTypeToAttr = domi::OmgMoveTypeToAttr;
  161. using MoveInputToAttrStu = domi::MoveInputToAttrStu;
  162. using OpRegistrationData = domi::OpRegistrationData;
  163. using OpReceiver = domi::OpReceiver;
  164. }
  165. #endif // INC_EXTERNAL_REGISTER_REGISTER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：

Contributors (1)