You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

register.h 4.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
  17. #define INC_EXTERNAL_REGISTER_REGISTER_H_
  18. #include <functional>
  19. #include <initializer_list>
  20. #include <map>
  21. #include <memory>
  22. #include <set>
  23. #include <string>
  24. #include <utility>
  25. #include <unordered_map>
  26. #include <vector>
  27. #include "graph/operator.h"
  28. #include "register/register_error_codes.h"
  29. #include "register/register_fmk_types.h"
  30. #include "register/register_types.h"
  31. using std::make_shared;
  32. using std::map;
  33. using std::pair;
  34. using std::string;
  35. using std::to_string;
  36. using std::unique_ptr;
  37. using std::vector;
  namespace ge {
  // Forward declarations of GE classes so this header can name them without
  // including their full definitions.
  class Operator;
  class TensorDesc;
  class Tensor;
  // ge::Graph is used by const reference in this header (AutoMappingSubgraphIndex,
  // ParseSubgraphFunc); declare it here instead of relying on a transitive include.
  class Graph;
  // Named as a friend of OpRegistrationData.
  class TBEPluginManager;
  }  // namespace ge
  // Forward declaration of the protobuf message base class, so the header can refer to
  // google::protobuf::Message without pulling in the protobuf headers.
  namespace google {
  namespace protobuf {
  class Message;  // declared only; the definition comes from the protobuf library
  }  // namespace protobuf
  }  // namespace google
  49. namespace domi {
  50. Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
  51. Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
  52. std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
  53. int in_pos = -1, int out_pos = -1);
  54. Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input,
  55. const std::function<int(int netoutput_index)> &output);
  56. Status AutoMappingSubgraphIndex(const ge::Graph &graph,
  57. const std::function<Status(int data_index, int &parent_input_index)> &input,
  58. const std::function<Status(int netoutput_index, int &parent_output_index)> &output);
  59. using google::protobuf::Message;
  60. class OpRegistrationDataImpl;
  61. using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
  62. using FusionParseParamFunc =
  63. std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
  64. using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;
  65. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
  66. public:
  67. OpRegistrationData(const std::string &om_optype);
  68. ~OpRegistrationData();
  69. OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);
  70. OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);
  71. OpRegistrationData &OriginOpType(const std::string &ori_optype);
  72. OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);
  73. OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);
  74. OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn);
  75. OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);
  76. OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);
  77. OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);
  78. domi::ImplyType GetImplyType() const;
  79. std::string GetOmOptype() const;
  80. std::set<std::string> GetOriginOpTypeSet() const;
  81. domi::FrameworkType GetFrameworkType() const;
  82. ParseParamFunc GetParseParamFn() const;
  83. FusionParseParamFunc GetFusionParseParamFn() const;
  84. ParseSubgraphFunc GetParseSubgraphPostFn() const;
  85. private:
  86. std::shared_ptr<OpRegistrationDataImpl> impl_;
  87. friend class OpRegistry;
  88. friend class OpRegistrationTbe;
  89. friend class ge::TBEPluginManager;
  90. };
  91. class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
  92. public:
  93. OpReceiver(OpRegistrationData &reg_data);
  94. ~OpReceiver() {}
  95. };
  96. #define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name)
  97. #define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name)
  98. #define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \
  99. static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name)
  100. } // namespace domi
  101. namespace ge {
  102. using OpRegistrationData = domi::OpRegistrationData;
  103. using OpReceiver = domi::OpReceiver;
  104. } // namespace ge
  105. #endif // INC_EXTERNAL_REGISTER_REGISTER_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：