
operator_reg.h (26 kB)

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "graph/operator.h"
#include "graph/operator_factory.h"
#include "graph/tensor.h"
#include "graph/types.h"
#include "graph/graph.h"

namespace ge {
using std::function;
using std::string;
using std::vector;

class OpReg {
 public:
  OpReg &N() { return *this; }
  OpReg &ATTR() { return *this; }
  OpReg &REQUIRED_ATTR() { return *this; }
  OpReg &INPUT() { return *this; }
  OpReg &OPTIONAL_INPUT() { return *this; }
  OpReg &OUTPUT() { return *this; }
  OpReg &GRAPH() { return *this; }
  OpReg &DYNAMIC_GRAPH() { return *this; }
  OpReg &INFER_SHAPE_AND_TYPE() { return *this; }
};

#define REG_OP(x) \
  namespace op { \
  class x : public Operator { \
    typedef x _THIS_TYPE; \
 \
   public: \
    explicit x(const string &name) : Operator(name, #x) { __##x(); } \
    x() : Operator(#x) { __##x(); } \
 \
   private: \
    void __##x() { \
      OpReg()

#define ATTR(x, Type, ...) \
  N(); \
  __attr_##x(); \
  } \
 \
 public: \
  static const string name_attr_##x() { return #x; } \
  Op##Type get_attr_##x() const { \
    Op##Type ret = __VA_ARGS__; \
    if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
      return ret; \
    } \
    return ret; \
  } \
  _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
    Operator::SetAttr(#x, v); \
    return *this; \
  } \
  _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
 \
 private: \
  void __attr_##x() { \
    Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \
    string attr_name(#x); \
    (void)OpReg()

#define REQUIRED_ATTR(x, Type) \
  N(); \
  __required_attr_##x(); \
  } \
 \
 public: \
  static const string name_attr_##x() { return #x; } \
  Op##Type get_attr_##x() const { \
    Op##Type ret; \
    if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
      return ret; \
    } \
    return ret; \
  } \
  _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
    Operator::SetAttr(#x, v); \
    return *this; \
  } \
  _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
 \
 private: \
  void __required_attr_##x() { \
    Operator::RequiredAttrRegister(#x); \
    string attr_name(#x); \
    (void)OpReg()

#define INPUT(x, t) \
  N(); \
  __input_##x(); \
  } \
 \
 public: \
  static const string name_in_##x() { return #x; } \
  _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
    Operator::SetInput(#x, v, srcName); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
    Operator::SetInput(#x, v, index); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v) { \
    Operator::SetInput(#x, v); \
    return *this; \
  } \
  TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
  graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateInputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __input_##x() { \
    Operator::InputRegister(#x); \
    (void)OpReg()

#define OPTIONAL_INPUT(x, t) \
  N(); \
  __optional_input_##x(); \
  } \
 \
 public: \
  static const string name_in_##x() { return #x; } \
  _THIS_TYPE &set_input_##x(Operator &v) { \
    Operator::SetInput(#x, v); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
    Operator::SetInput(#x, v, srcName); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \
    Operator::SetInput(#x, v, index); \
    return *this; \
  } \
  TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
  graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateInputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __optional_input_##x() { \
    Operator::OptionalInputRegister(#x); \
    (void)OpReg()

#define OUTPUT(x, t) \
  N(); \
  __out_##x(); \
  } \
 \
 public: \
  static const string name_out_##x() { return #x; } \
  TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \
  graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateOutputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __out_##x() { \
    Operator::OutputRegister(#x); \
    (void)OpReg()

#define DYNAMIC_INPUT(x, t) \
  N(); \
  __dy_input_##x(); \
  } \
 \
 public: \
  _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \
    Operator::DynamicInputRegister(#x, num, isPushBack); \
    return *this; \
  } \
  _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \
    Operator::DynamicInputRegisterByIndex(#x, num, index); \
    return *this; \
  } \
  TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \
  graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
    return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
  } \
  _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \
    Operator::SetInput(#x, dstIndex, v); \
    return *this; \
  } \
  _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \
    Operator::SetInput(#x, dstIndex, v, srcName); \
    return *this; \
  } \
 \
 private: \
  void __dy_input_##x() { \
    (void)OpReg()

#define DYNAMIC_OUTPUT(x, t) \
  N(); \
  __dy_output_##x(); \
  } \
 \
 public: \
  _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \
    Operator::DynamicOutputRegister(#x, num, isPushBack); \
    return *this; \
  } \
  TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \
  graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
    return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
  } \
 \
 private: \
  void __dy_output_##x() { \
    (void)OpReg()

#define GRAPH(x) \
  N(); \
  __graph_##x(); \
  } \
 \
 public: \
  static const string name_graph_##x() { return #x; } \
  SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \
  _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \
    Operator::SetSubgraphBuilder(#x, 0, v); \
    return *this; \
  } \
  Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \
 \
 private: \
  void __graph_##x() { \
    Operator::SubgraphRegister(#x, false); \
    Operator::SubgraphCountRegister(#x, 1); \
    (void)OpReg()

#define DYNAMIC_GRAPH(x) \
  N(); \
  __graph_##x(); \
  } \
 \
 public: \
  static const string name_graph_##x() { return #x; } \
  _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \
    Operator::SubgraphCountRegister(#x, num); \
    return *this; \
  } \
  SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \
    return Operator::GetDynamicSubgraphBuilder(#x, index); \
  } \
  Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \
  _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \
    Operator::SetSubgraphBuilder(#x, index, v); \
    return *this; \
  } \
 \
 private: \
  void __graph_##x() { \
    Operator::SubgraphRegister(#x, true); \
    (void)OpReg()

#define PASTE(g_register, y) g_register##y
#define __OP_END_IMPL__(x, y) \
  N(); \
  } \
  static_assert( \
    std::is_same<x, _THIS_TYPE>::value, \
    "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
  } \
  ; \
  static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \
  }
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)

// Specialized shape inferencer macro
#define IMPLEMT_INFERFUNC(op_name, func_name) \
  GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)
#define IMPLEMT_COMMON_INFERFUNC(func_name) \
  GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op)
#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \
  GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)

// Specialized verifier macro
#define IMPLEMT_VERIFIER(op_name, func_name) \
  GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op)

#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }
#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); }
#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }

#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x)
#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x)

// Infer format func register
#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \
  static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x)

// Shape inferencer & verifier register macro
#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)
#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__)
#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)

// Infer format func reg
#define INFER_FORMAT_FUNC_REG(op_name, x) \
  __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__)

// Common shape inferencer
#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
  [](Operator op) -> graphStatus { \
    auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \
    auto x_type = op.GetInputDesc(in_name).GetDataType(); \
    TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
    op_output_desc.SetShape(ge::Shape(x_shape)); \
    op_output_desc.SetOriginShape(ge::Shape(x_shape)); \
    op_output_desc.SetDataType(x_type); \
    return op.UpdateOutputDesc(out_name, op_output_desc); \
  }

graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
                           const function<vector<int64_t>()> &get_in2_shape,
                           const function<void(const vector<int64_t> &y_shape)> &set_out_shape);

#define BROADCAST_INFER(in1_name, in2_name, out_name) \
  [](Operator op) -> graphStatus { \
    return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \
                          [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \
                          [&](const vector<int64_t> &y_shape) { \
                            TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
                            op_output_desc.SetShape(ge::Shape(y_shape)); \
                            (void)op.UpdateOutputDesc(out_name, op_output_desc); \
                          }); \
  }
}  // namespace ge
#endif  // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
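For reference, a typical use of these macros looks like the minimal sketch below. The operator name MyAdd, the attribute alpha, and the function name MyAddVerify are hypothetical illustrations, not part of this header; TensorType, the DT_* data types, GRAPH_SUCCESS, and the OpFloat attribute alias are assumed to be provided by the included graph headers (graph/operator.h, graph/types.h). In practice the REG_OP declaration usually lives in an operator prototype header and the infer/verify registrations in an implementation file.

#include "graph/operator_reg.h"

namespace ge {
// Hypothetical element-wise operator: two inputs, one output, one optional attribute.
// Each macro closes the brace left open by the previous one and chains through OpReg(),
// so the declaration must end with OP_END_FACTORY_REG using the same class name.
REG_OP(MyAdd)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(alpha, Float, 1.0)
    .OP_END_FACTORY_REG(MyAdd)

// Shape/type inference for the hypothetical op, reusing the BROADCAST_INFER helper above:
// output "y" gets the broadcast of the shapes of inputs "x1" and "x2".
COMMON_INFER_FUNC_REG(MyAdd, BROADCAST_INFER("x1", "x2", "y"));

// Minimal verifier for the hypothetical op; a real verifier would check attributes and dtypes.
IMPLEMT_VERIFIER(MyAdd, MyAddVerify) { return GRAPH_SUCCESS; }
VERIFY_FUNC_REG(MyAdd, MyAddVerify);
}  // namespace ge

OP_END_FACTORY_REG closes the generated class, static_asserts that the name passed to it matches the class declared by REG_OP, and emits a static OperatorCreatorRegister so the operator can later be created by name through the factory. The registration macros use __COUNTER__ to give each static registrar a unique identifier, which allows several operators and functions to be registered in the same translation unit.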

The Graph Engine (GE) module is a submodule of MindSpore implemented in C++. It sits between the front-end module (ME) and the underlying hardware and serves as the bridge between them: GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE is optimized specifically for the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists of two main parts, GE API and GE Core; the detailed architecture is shown in the diagram below.