You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

operator_reg.h 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
  17. #define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "graph/graph.h"
#include "graph/operator.h"
#include "graph/operator_factory.h"
#include "graph/tensor.h"
#include "graph/types.h"
  27. namespace ge {
  28. using std::function;
  29. using std::string;
  30. using std::vector;
  31. class OpReg {
  32. public:
  33. OpReg &N() { return *this; }
  34. OpReg &ATTR() { return *this; }
  35. OpReg &REQUIRED_ATTR() { return *this; }
  36. OpReg &INPUT() { return *this; }
  37. OpReg &OPTIONAL_INPUT() { return *this; }
  38. OpReg &OUTPUT() { return *this; }
  39. OpReg &GRAPH() { return *this; }
  40. OpReg &DYNAMIC_GRAPH() { return *this; }
  41. OpReg &INFER_SHAPE_AND_TYPE() { return *this; }
  42. };
// Opens the definition of operator class 'x' (derived from ge::Operator).
// NOTE: the expansion is deliberately unbalanced -- it leaves both the class
// body and the private registration method __##x() open.  The braces are
// closed piecewise by the chained ATTR/INPUT/OUTPUT/... macros and finally
// by OP_END_FACTORY_REG, which also registers the operator with the factory.
#define REG_OP(x) \
  namespace op { \
  class x : public Operator { \
    typedef x _THIS_TYPE; \
 \
   public: \
    explicit x(const string &name) : Operator(name, #x) { __##x(); } \
    x() : Operator(#x) { __##x(); } \
 \
   private: \
    void __##x() { \
      OpReg()
  55. #define ATTR(x, Type, ...) \
  56. N(); \
  57. __attr_##x(); \
  58. } \
  59. \
  60. public: \
  61. static const string name_attr_##x() { return #x; } \
  62. Op##Type get_attr_##x() const { \
  63. Op##Type ret = __VA_ARGS__; \
  64. if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
  65. return ret; \
  66. } \
  67. return ret; \
  68. } \
  69. _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
  70. Operator::SetAttr(#x, v); \
  71. return *this; \
  72. } \
  73. _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
  74. \
  75. private: \
  76. void __attr_##x() { \
  77. Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \
  78. string attr_name(#x); \
  79. (void)OpReg()
  80. #define REQUIRED_ATTR(x, Type) \
  81. N(); \
  82. __required_attr_##x(); \
  83. } \
  84. \
  85. public: \
  86. static const string name_attr_##x() { return #x; } \
  87. Op##Type get_attr_##x() const { \
  88. Op##Type ret; \
  89. if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \
  90. return ret; \
  91. } \
  92. return ret; \
  93. } \
  94. _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
  95. Operator::SetAttr(#x, v); \
  96. return *this; \
  97. } \
  98. _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
  99. \
  100. private: \
  101. void __required_attr_##x() { \
  102. Operator::RequiredAttrRegister(#x); \
  103. string attr_name(#x); \
  104. (void)OpReg()
// Registers required input 'x' and generates its accessors: set_input_x
// (connect another operator's output, optionally by source-output name),
// get_input_desc_x and update_input_desc_x.  The 't' parameter is
// documentation-only; it does not appear in the expansion.
#define INPUT(x, t) \
  N(); \
  __input_##x(); \
  } \
 \
 public: \
  static const string name_in_##x() { return #x; } \
  _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
    Operator::SetInput(#x, v, srcName); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v) { \
    Operator::SetInput(#x, v); \
    return *this; \
  } \
  TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
  graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateInputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __input_##x() { \
    Operator::InputRegister(#x); \
    (void)OpReg()
// Registers optional input 'x'.  Identical accessor surface to INPUT; the
// only difference is the OptionalInputRegister call.  The 't' parameter is
// documentation-only; it does not appear in the expansion.
#define OPTIONAL_INPUT(x, t) \
  N(); \
  __optional_input_##x(); \
  } \
 \
 public: \
  static const string name_in_##x() { return #x; } \
  _THIS_TYPE &set_input_##x(Operator &v) { \
    Operator::SetInput(#x, v); \
    return *this; \
  } \
  _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \
    Operator::SetInput(#x, v, srcName); \
    return *this; \
  } \
  TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \
  graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateInputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __optional_input_##x() { \
    Operator::OptionalInputRegister(#x); \
    (void)OpReg()
// Registers output 'x' and generates get_output_desc_x /
// update_output_desc_x.  The 't' parameter is documentation-only; it does
// not appear in the expansion.
#define OUTPUT(x, t) \
  N(); \
  __out_##x(); \
  } \
 \
 public: \
  static const string name_out_##x() { return #x; } \
  TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \
  graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \
    return Operator::UpdateOutputDesc(#x, tensorDesc); \
  } \
 \
 private: \
  void __out_##x() { \
    Operator::OutputRegister(#x); \
    (void)OpReg()
// Registers dynamic (variable-count) input 'x'.  Unlike INPUT, nothing is
// registered eagerly: __dy_input_##x() is empty and instances are created at
// graph-build time via create_dynamic_input_##x / _byindex_##x.  The 't'
// parameter is documentation-only; it does not appear in the expansion.
#define DYNAMIC_INPUT(x, t) \
  N(); \
  __dy_input_##x(); \
  } \
 \
 public: \
  _THIS_TYPE &create_dynamic_input_##x(unsigned int num, bool isPushBack = true) { \
    Operator::DynamicInputRegister(#x, num, isPushBack); \
    return *this; \
  } \
  _THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \
    Operator::DynamicInputRegisterByIndex(#x, num, index); \
    return *this; \
  } \
  TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \
  graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \
    return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
  } \
  _THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v) { \
    Operator::SetInput(#x, dstIndex, v); \
    return *this; \
  } \
  _THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v, const string &srcName) { \
    Operator::SetInput(#x, dstIndex, v, srcName); \
    return *this; \
  } \
 \
 private: \
  void __dy_input_##x() { \
    (void)OpReg()
// Registers dynamic (variable-count) output 'x'.  As with DYNAMIC_INPUT,
// __dy_output_##x() registers nothing eagerly; instances are created via
// create_dynamic_output_##x.  The 't' parameter is documentation-only; it
// does not appear in the expansion.
#define DYNAMIC_OUTPUT(x, t) \
  N(); \
  __dy_output_##x(); \
  } \
 \
 public: \
  _THIS_TYPE &create_dynamic_output_##x(unsigned int num, bool isPushBack = true) { \
    Operator::DynamicOutputRegister(#x, num, isPushBack); \
    return *this; \
  } \
  TensorDesc get_dynamic_output_desc_##x(unsigned int index) const { \
    return Operator::GetDynamicOutputDesc(#x, index); \
  } \
  graphStatus update_dynamic_output_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \
    return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
  } \
 \
 private: \
  void __dy_output_##x() { \
    (void)OpReg()
// Registers a single (non-dynamic) subgraph attachment point 'x':
// SubgraphRegister(#x, false) plus a fixed count of 1.  The builder setter
// always targets index 0 since exactly one instance exists.
#define GRAPH(x) \
  N(); \
  __graph_##x(); \
  } \
 \
 public: \
  static const string name_graph_##x() { return #x; } \
  SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \
  _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \
    /* Single subgraph, so the only valid instance index is 0. */ \
    Operator::SetSubgraphBuilder(#x, 0, v); \
    return *this; \
  } \
  Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \
 \
 private: \
  void __graph_##x() { \
    Operator::SubgraphRegister(#x, false); \
    Operator::SubgraphCountRegister(#x, 1); \
    (void)OpReg()
// Registers a dynamic (variable-count) subgraph attachment point 'x':
// SubgraphRegister(#x, true); the instance count is supplied later via
// create_dynamic_subgraph_##x, and all accessors take an instance index.
#define DYNAMIC_GRAPH(x) \
  N(); \
  __graph_##x(); \
  } \
 \
 public: \
  static const string name_graph_##x() { return #x; } \
  _THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \
    Operator::SubgraphCountRegister(#x, num); \
    return *this; \
  } \
  SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \
    return Operator::GetDynamicSubgraphBuilder(#x, index); \
  } \
  Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \
  _THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \
    Operator::SetSubgraphBuilder(#x, index, v); \
    return *this; \
  } \
 \
 private: \
  void __graph_##x() { \
    Operator::SubgraphRegister(#x, true); \
    (void)OpReg()
// Token-pasting helper used to mint unique registrar variable names.
#define PASTE(g_register, y) g_register##y
// Closes the scopes left open by the preceding chained macro (the pending
// registration method, then the operator class), checks at compile time that
// the name passed to OP_END_FACTORY_REG is the class opened by REG_OP, and
// registers a creator lambda with the operator factory.  'y' (always
// __COUNTER__) keeps each registrar variable name unique per expansion.
#define __OP_END_IMPL__(x, y) \
  N(); \
  } \
  static_assert( \
    std::is_same<x, _THIS_TYPE>::value, \
    "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
  } \
  ; \
  static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \
  }
#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)
  274. // Specialized shape inferencer macro
  275. #define IMPLEMT_INFERFUNC(op_name, func_name) \
  276. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)
  277. #define IMPLEMT_COMMON_INFERFUNC(func_name) \
  278. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op)
  279. #define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \
  280. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op)
  281. // Specialized verifier macro
  282. #define IMPLEMT_VERIFIER(op_name, func_name) \
  283. GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op)
  284. #define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }
  285. #define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); }
  286. #define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); }
  287. #define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x)
  288. #define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x)
  289. // Infer format func register
  290. #define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \
  291. static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x)
  292. // Shape inferencer & verifier register macro
  293. #define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)
  294. #define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__)
  295. #define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__)
  296. // Infer format func reg
  297. #define INFER_FORMAT_FUNC_REG(op_name, x) \
  298. __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__)
// Common shape inferencer
// Element-wise inference: copies the shape and dtype of input 'in_name' onto
// output 'out_name'.  Note the lambda takes its Operator by value.
#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
  [](Operator op) -> graphStatus { \
    auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \
    auto x_type = op.GetInputDesc(in_name).GetDataType(); \
    TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
    op_output_desc.SetShape(ge::Shape(x_shape)); \
    op_output_desc.SetDataType(x_type); \
    return op.UpdateOutputDesc(out_name, op_output_desc); \
  }
// Broadcast-style inference helper (defined elsewhere): reads the two input
// shapes through the lazy getters and reports the resulting output shape via
// 'set_out_shape'.
graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
                           const function<vector<int64_t>()> &get_in2_shape,
                           const function<void(const vector<int64_t> &y_shape)> &set_out_shape);
// Produces an inference lambda that delegates to BroadCastInfer: in1/in2
// shapes are read lazily, and the computed shape is written to 'out_name'
// (dtype is left untouched; UpdateOutputDesc's status is ignored -- the
// lambda's return value comes from BroadCastInfer itself).
#define BROADCAST_INFER(in1_name, in2_name, out_name) \
  [](Operator op) -> graphStatus { \
    return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \
                          [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \
                          [&](const vector<int64_t> &y_shape) { \
                            TensorDesc op_output_desc = op.GetOutputDesc(out_name); \
                            op_output_desc.SetShape(ge::Shape(y_shape)); \
                            (void)op.UpdateOutputDesc(out_name, op_output_desc); \
                          }); \
  }
  322. } // namespace ge
  323. #endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示