You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

operator_reg.h 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
  17. #define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_
  #include <functional>
  #include <memory>
  #include <string>
  #include <type_traits>
  #include <vector>
  #include "graph/operator.h"
  #include "graph/operator_factory.h"
  #include "graph/tensor.h"
  #include "graph/types.h"
  #include "graph/graph.h"
  27. namespace ge {
  28. using std::function;
  29. using std::string;
  30. using std::vector;
  // OpReg is the inert, self-returning object that REG_OP(...) and every chained
  // prototype macro leave at the end of a generated registration method. The macros in
  // this header continue it with `.N();`. The remaining members mirror the DSL keywords
  // and, like N(), do nothing except hand back the same object so calls can be chained.
  class OpReg {
   public:
    OpReg &N() {
      return *this;
    }

    OpReg &INPUT() {
      return *this;
    }
    OpReg &OPTIONAL_INPUT() {
      return *this;
    }
    OpReg &OUTPUT() {
      return *this;
    }

    OpReg &ATTR() {
      return *this;
    }
    OpReg &REQUIRED_ATTR() {
      return *this;
    }

    OpReg &GRAPH() {
      return *this;
    }
    OpReg &DYNAMIC_GRAPH() {
      return *this;
    }

    OpReg &INFER_SHAPE_AND_TYPE() {
      return *this;
    }
  };
  // REG_OP(op_type) opens the prototype of operator `op_type`.
  //
  // - It starts `namespace op { class op_type : public Operator {` and defines
  //   _THIS_TYPE as an alias of the class.
  // - It emits the constructors, each of which calls the private method __<op_type>().
  // - It leaves that method open with a trailing `OpReg()` expression.
  //
  // Every chained `.INPUT(...)`, `.ATTR(...)`, ... macro continues that expression with
  // `.N();`, calls the next helper, closes the current method and opens the next one.
  // Constructing the op therefore runs the whole sequence of *Register calls.
  // OP_END_FACTORY_REG(op_type) closes the last method, the class and the namespace.
  #define REG_OP(op_type) \
    namespace op { \
    class op_type : public Operator { \
      using _THIS_TYPE = op_type; \
     \
     public: \
      op_type() : Operator(#op_type) { __##op_type(); } \
      explicit op_type(const AscendString &name) : Operator(name, #op_type) { __##op_type(); } \
      explicit op_type(const char *name) : Operator(name, #op_type) { __##op_type(); } \
      ATTRIBUTED_DEPRECATED(op_type(const AscendString &)) \
      explicit op_type(const string &name) : Operator(name, #op_type) { __##op_type(); } \
     \
     private: \
      void __##op_type() { \
        OpReg()
  // .ATTR(attr_id, AttrType, default...) declares an optional attribute of type
  // Op<AttrType>. The trailing arguments give its default value.
  //
  // Generated members:
  // - name_attr_<attr_id>() and the AscendString overload of it;
  // - get_attr_<attr_id>(), which starts from the default and returns whatever
  //   Operator::GetAttr leaves in it;
  // - set_attr_<attr_id>(), whose std::function overload ignores its argument.
  //
  // Registration goes through Operator::AttrRegister with the default value.
  // The `attr_name` local in the registration method is kept as-is.
  #define ATTR(attr_id, AttrType, ...) \
        N(); \
        __attr_##attr_id(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_attr_##attr_id(AscendString &)) \
      static const string name_attr_##attr_id() { return #attr_id; } \
      static const void name_attr_##attr_id(AscendString &attr) { attr = AscendString(#attr_id); } \
      Op##AttrType get_attr_##attr_id() const { \
        Op##AttrType value = __VA_ARGS__; \
        /* The status is not inspected: value is returned on every path. */ \
        (void)Operator::GetAttr(#attr_id, value); \
        return value; \
      } \
      _THIS_TYPE &set_attr_##attr_id(const Op##AttrType &v) { \
        Operator::SetAttr(#attr_id, v); \
        return *this; \
      } \
      _THIS_TYPE &set_attr_##attr_id(const function<Op##AttrType()> & /* ignored */) { return *this; } \
     \
     private: \
      void __attr_##attr_id() { \
        Operator::AttrRegister(#attr_id, Op##AttrType(__VA_ARGS__)); \
        string attr_name(#attr_id); \
        (void)OpReg()
  // .REQUIRED_ATTR(x, Type) declares a mandatory attribute of type Op<Type> with no
  // default. It generates name_attr_<x>(), get_attr_<x>() and set_attr_<x>(), and
  // registers the name via Operator::RequiredAttrRegister.
  //
  // get_attr_<x>() value-initializes its result (`{}`) before calling GetAttr.
  // With plain default-initialization, a scalar Op<Type> would hold an indeterminate
  // value, and that value was returned whenever GetAttr did not write it. This applies
  // if aliases such as OpInt/OpFloat/OpBool are arithmetic types; they are declared
  // outside this header, so confirm. Class types are unaffected: `{}` still calls their
  // default constructor.
  #define REQUIRED_ATTR(x, Type) \
        N(); \
        __required_attr_##x(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \
      static const string name_attr_##x() { return #x; } \
      static const void name_attr_##x(AscendString &attr_name) { attr_name = AscendString(#x); } \
      Op##Type get_attr_##x() const { \
        Op##Type ret{}; \
        /* Returned whether or not GetAttr succeeds; ret is never indeterminate. */ \
        (void)Operator::GetAttr(#x, ret); \
        return ret; \
      } \
      _THIS_TYPE &set_attr_##x(const Op##Type &v) { \
        Operator::SetAttr(#x, v); \
        return *this; \
      } \
      _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \
     \
     private: \
      void __required_attr_##x() { \
        Operator::RequiredAttrRegister(#x); \
        string attr_name(#x); \
        (void)OpReg()
  // .INPUT(port, dtypes) declares the required input `port`. The dtypes argument is
  // accepted but not used by this expansion.
  //
  // Generated members:
  // - name_in_<port>();
  // - the set_input_<port>() overloads, which forward a source Operator (optionally
  //   with a srcName or an index) to Operator::SetInput;
  // - get_input_desc_<port>() and update_input_desc_<port>(), which act on index 0.
  //
  // The port is registered with Operator::InputRegister.
  #define INPUT(port, dtypes) \
        N(); \
        __input_##port(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_in_##port(AscendString &)) \
      static const string name_in_##port() { return #port; } \
      static const void name_in_##port(AscendString &out) { out = AscendString(#port); } \
      _THIS_TYPE &set_input_##port(Operator &src) { \
        Operator::SetInput(#port, src); \
        return *this; \
      } \
      _THIS_TYPE &set_input_##port(Operator &src, uint32_t index) { \
        Operator::SetInput(#port, src, index); \
        return *this; \
      } \
      _THIS_TYPE &set_input_##port(Operator &src, const char *src_name) { \
        Operator::SetInput(#port, src, src_name); \
        return *this; \
      } \
      ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##port(Operator &, const char *)) \
      _THIS_TYPE &set_input_##port(Operator &src, const string &src_name) { \
        Operator::SetInput(#port, src, src_name); \
        return *this; \
      } \
      TensorDesc get_input_desc_##port() const { return Operator::GetInputDesc(#port, 0); } \
      graphStatus update_input_desc_##port(const TensorDesc &desc) { \
        return Operator::UpdateInputDesc(#port, desc); \
      } \
     \
     private: \
      void __input_##port() { \
        Operator::InputRegister(#port); \
        (void)OpReg()
  // .OPTIONAL_INPUT(port, dtypes) declares an input that may be left unconnected. It
  // generates the same accessors as INPUT but registers the port with
  // Operator::OptionalInputRegister. The dtypes argument is not used by this expansion.
  #define OPTIONAL_INPUT(port, dtypes) \
        N(); \
        __optional_input_##port(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_in_##port(AscendString &)) \
      static const string name_in_##port() { return #port; } \
      static const void name_in_##port(AscendString &out) { out = AscendString(#port); } \
      _THIS_TYPE &set_input_##port(Operator &src) { \
        Operator::SetInput(#port, src); \
        return *this; \
      } \
      _THIS_TYPE &set_input_##port(Operator &src, uint32_t index) { \
        Operator::SetInput(#port, src, index); \
        return *this; \
      } \
      _THIS_TYPE &set_input_##port(Operator &src, const char *src_name) { \
        Operator::SetInput(#port, src, src_name); \
        return *this; \
      } \
      ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##port(Operator &, const char *)) \
      _THIS_TYPE &set_input_##port(Operator &src, const string &src_name) { \
        Operator::SetInput(#port, src, src_name); \
        return *this; \
      } \
      TensorDesc get_input_desc_##port() const { return Operator::GetInputDesc(#port, 0); } \
      graphStatus update_input_desc_##port(const TensorDesc &desc) { \
        return Operator::UpdateInputDesc(#port, desc); \
      } \
     \
     private: \
      void __optional_input_##port() { \
        Operator::OptionalInputRegister(#port); \
        (void)OpReg()
  // .OUTPUT(port, dtypes) declares the output `port`. It generates name_out_<port>(),
  // plus get_output_desc_<port>() and update_output_desc_<port>(), which act on index 0.
  // The port is registered with Operator::OutputRegister. The dtypes argument is not
  // used by this expansion.
  #define OUTPUT(port, dtypes) \
        N(); \
        __out_##port(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_out_##port(AscendString &)) \
      static const string name_out_##port() { return #port; } \
      static const void name_out_##port(AscendString &out) { out = AscendString(#port); } \
      TensorDesc get_output_desc_##port() const { return Operator::GetOutputDesc(#port, 0); } \
      graphStatus update_output_desc_##port(const TensorDesc &desc) { \
        return Operator::UpdateOutputDesc(#port, desc); \
      } \
     \
     private: \
      void __out_##port() { \
        Operator::OutputRegister(#port); \
        (void)OpReg()
  // .DYNAMIC_INPUT(port, dtypes) declares a variadic input group `port`.
  // - Registration calls DynamicInputRegister(#port, 0, true).
  // - The group can be extended afterwards through create_dynamic_input_<port>() or
  //   create_dynamic_input_byindex_<port>().
  // - Individual entries are read, updated and connected by index.
  // The dtypes argument is not used by this expansion.
  #define DYNAMIC_INPUT(port, dtypes) \
        N(); \
        __dy_input_##port(); \
      } \
     \
     public: \
      _THIS_TYPE &create_dynamic_input_##port(uint32_t count, bool append = true) { \
        Operator::DynamicInputRegister(#port, count, append); \
        return *this; \
      } \
      _THIS_TYPE &create_dynamic_input_byindex_##port(uint32_t count, size_t position) { \
        Operator::DynamicInputRegisterByIndex(#port, count, position); \
        return *this; \
      } \
      TensorDesc get_dynamic_input_desc_##port(uint32_t index) const { \
        return Operator::GetDynamicInputDesc(#port, index); \
      } \
      graphStatus update_dynamic_input_desc_##port(uint32_t index, const TensorDesc &desc) { \
        return Operator::UpdateDynamicInputDesc(#port, index, desc); \
      } \
      _THIS_TYPE &set_dynamic_input_##port(uint32_t dst_index, Operator &src) { \
        Operator::SetInput(#port, dst_index, src); \
        return *this; \
      } \
      _THIS_TYPE &set_dynamic_input_##port(uint32_t dst_index, Operator &src, const char *src_name) { \
        Operator::SetInput(#port, dst_index, src, src_name); \
        return *this; \
      } \
      ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_dynamic_input_##port(uint32_t, Operator &, const char *)) \
      _THIS_TYPE &set_dynamic_input_##port(uint32_t dst_index, Operator &src, const string &src_name) { \
        Operator::SetInput(#port, dst_index, src, src_name); \
        return *this; \
      } \
     \
     private: \
      void __dy_input_##port() { \
        Operator::DynamicInputRegister(#port, 0, true); \
        (void)OpReg()
  // .DYNAMIC_OUTPUT(port, dtypes) declares a variadic output group `port`.
  // - Registration calls DynamicOutputRegister(#port, 0, true).
  // - The group can be extended afterwards with create_dynamic_output_<port>().
  // - Entries are read and updated by index.
  // The dtypes argument is not used by this expansion.
  #define DYNAMIC_OUTPUT(port, dtypes) \
        N(); \
        __dy_output_##port(); \
      } \
     \
     public: \
      _THIS_TYPE &create_dynamic_output_##port(uint32_t count, bool append = true) { \
        Operator::DynamicOutputRegister(#port, count, append); \
        return *this; \
      } \
      TensorDesc get_dynamic_output_desc_##port(uint32_t index) const { \
        return Operator::GetDynamicOutputDesc(#port, index); \
      } \
      graphStatus update_dynamic_output_desc_##port(uint32_t index, const TensorDesc &desc) { \
        return Operator::UpdateDynamicOutputDesc(#port, index, desc); \
      } \
     \
     private: \
      void __dy_output_##port() { \
        Operator::DynamicOutputRegister(#port, 0, true); \
        (void)OpReg()
  // .GRAPH(graph_id) declares a single, non-dynamic subgraph slot. Registration calls
  // SubgraphRegister(#graph_id, false) followed by SubgraphCountRegister(#graph_id, 1).
  // Generated members:
  // - name_graph_<graph_id>();
  // - get_subgraph_builder_<graph_id>();
  // - set_subgraph_builder_<graph_id>(), which writes slot 0;
  // - get_subgraph_<graph_id>().
  #define GRAPH(graph_id) \
        N(); \
        __graph_##graph_id(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_graph_##graph_id(AscendString &)) \
      static const string name_graph_##graph_id() { return #graph_id; } \
      static const void name_graph_##graph_id(AscendString &out) { out = AscendString(#graph_id); } \
      SubgraphBuilder get_subgraph_builder_##graph_id() const { \
        return Operator::GetSubgraphBuilder(#graph_id); \
      } \
      _THIS_TYPE &set_subgraph_builder_##graph_id(const SubgraphBuilder &builder) { \
        Operator::SetSubgraphBuilder(#graph_id, 0, builder); \
        return *this; \
      } \
      Graph get_subgraph_##graph_id() const { return Operator::GetSubgraph(#graph_id); } \
     \
     private: \
      void __graph_##graph_id() { \
        Operator::SubgraphRegister(#graph_id, false); \
        Operator::SubgraphCountRegister(#graph_id, 1); \
        (void)OpReg()
  // .DYNAMIC_GRAPH(graph_id) declares a dynamic subgraph group; registration calls
  // SubgraphRegister(#graph_id, true). The instance count is set with
  // create_dynamic_subgraph_<graph_id>(num), and instances are accessed by index.
  // It defines the same name_graph_<graph_id>() and private __graph_<graph_id>()
  // members as GRAPH, so a given name can be declared with only one of the two.
  #define DYNAMIC_GRAPH(graph_id) \
        N(); \
        __graph_##graph_id(); \
      } \
     \
     public: \
      ATTRIBUTED_DEPRECATED(static const void name_graph_##graph_id(AscendString &)) \
      static const string name_graph_##graph_id() { return #graph_id; } \
      static const void name_graph_##graph_id(AscendString &out) { out = AscendString(#graph_id); } \
      _THIS_TYPE &create_dynamic_subgraph_##graph_id(uint32_t count) { \
        Operator::SubgraphCountRegister(#graph_id, count); \
        return *this; \
      } \
      SubgraphBuilder get_dynamic_subgraph_builder_##graph_id(uint32_t index) const { \
        return Operator::GetDynamicSubgraphBuilder(#graph_id, index); \
      } \
      Graph get_dynamic_subgraph_##graph_id(uint32_t index) const { \
        return Operator::GetDynamicSubgraph(#graph_id, index); \
      } \
      _THIS_TYPE &set_dynamic_subgraph_builder_##graph_id(uint32_t index, const SubgraphBuilder &builder) { \
        Operator::SetSubgraphBuilder(#graph_id, index, builder); \
        return *this; \
      } \
     \
     private: \
      void __graph_##graph_id() { \
        Operator::SubgraphRegister(#graph_id, true); \
        (void)OpReg()
  // Concatenates two tokens into one identifier. Arguments are not macro-expanded here
  // because they are operands of ##. Callers therefore pass an already-expanded
  // counter, e.g. the `y` of __OP_END_IMPL__.
  #define PASTE(prefix, suffix) prefix##suffix
  // Closes a REG_OP(x) chain; it is reached through OP_END_FACTORY_REG. In order:
  // - `N(); }` ends the last open registration method;
  // - the static_assert checks that `x` names the class REG_OP opened (_THIS_TYPE);
  // - `} ;` closes that class;
  // - a `static const` OperatorCreatorRegister named g_register<y> registers, under the
  //   type string #x, a lambda that builds an `x` from an AscendString;
  // - the final `}` closes `namespace op`.
  // The expansion uses std::is_same, so the header includes <type_traits> directly
  // instead of relying on <functional>/<memory> to pull it in transitively.
  #define __OP_END_IMPL__(x, y) \
        N(); \
      } \
      static_assert( \
          std::is_same<x, _THIS_TYPE>::value, \
          "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \
    } \
    ; \
    static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const AscendString &name) { return x(name); }); \
    }
  // Ends an operator prototype. __COUNTER__ is expanded before it reaches PASTE, because
  // `y` is not an operand of # or ## inside __OP_END_IMPL__. This gives every registrar
  // in a translation unit a distinct name such as g_register0.
  #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__)
  // Signature helpers for hand-written prototype callbacks. Each expands to the head of
  // a `static` function `fn` returning graphStatus, with a parameter named `op` and
  // GE_FUNC_DEV_VISIBILITY / GE_FUNC_HOST_VISIBILITY applied. The expansion stops after
  // the parameter list, so the use site supplies the function body.

  // Shape inferencer for the concrete prototype class op::op_type.
  #define IMPLEMT_INFERFUNC(op_type, fn) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus fn(op::op_type &op)
  // Shape inferencer that works on the generic Operator.
  #define IMPLEMT_COMMON_INFERFUNC(fn) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus fn(Operator &op)
  // Format inferencer for op::op_type.
  #define IMPLEMT_INFERFORMAT_FUNC(op_type, fn) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus fn(op::op_type &op)
  // Verifier for op::op_type; unlike the inferencers it takes the operator by value.
  #define IMPLEMT_VERIFIER(op_type, fn) \
    GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus fn(op::op_type op)
  // Adapters that wrap a callback into a `graphStatus(Operator &)`-style callable for
  // the *_FUNC_REG registrars.
  //
  // The lambdas capture nothing: `x` names a function (or another entity with static
  // storage) and `v` is the parameter. The previous `[&]` capture-default was therefore
  // unnecessary, and it is harmful:
  // - The standard forbids a capture-default on a lambda whose innermost enclosing scope
  //   is not a block scope. The registrar macros expand these at namespace scope, and
  //   compilers such as Clang reject the form there.
  // - A capture-default also disables the lambda's conversion to a plain function
  //   pointer.
  //
  // INFER_VERIFY_FUNC / INFER_FORMAT_FUNC downcast to the concrete prototype class.
  // static_cast replaces the C-style cast so the compiler checks that op::op_name
  // publicly derives from Operator, as every REG_OP class does.
  #define INFER_VERIFY_FUNC(op_name, x) [](Operator &v) { return x(static_cast<op::op_name &>(v)); }
  #define COMMON_INFER_VERIFY_FUNC(x) [](Operator &v) { return x(v); }
  #define INFER_FORMAT_FUNC(op_name, x) [](Operator &v) { return x(static_cast<op::op_name &>(v)); }
  // Registrar implementations. Each defines a `static const` registrar object whose
  // name is made unique with PASTE(prefix, counter). The object is constructed with the
  // operator type string #op_name and the callable `func`.
  #define __INFER_FUNC_REG_IMPL__(op_name, func, counter) \
    static const InferShapeFuncRegister PASTE(if_register, counter)(#op_name, func)
  #define __VERIFY_FUNC_REG_IMPL__(op_name, func, counter) \
    static const VerifyFuncRegister PASTE(vf_register, counter)(#op_name, func)
  #define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, func, counter) \
    static const InferFormatFuncRegister PASTE(ff_register, counter)(#op_name, func)

  // Public entry points. They wrap `fn` with the matching adapter macro and pass
  // __COUNTER__, so several registrations can share a translation unit.
  #define INFER_FUNC_REG(op_name, fn) \
    __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, fn), __COUNTER__)
  #define COMMON_INFER_FUNC_REG(op_name, fn) \
    __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(fn), __COUNTER__)
  #define VERIFY_FUNC_REG(op_name, fn) \
    __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, fn), __COUNTER__)
  #define INFER_FORMAT_FUNC_REG(op_name, fn) \
    __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, fn), __COUNTER__)
  // Shape/type inferencer for element-wise ops. It copies the dims and data type of
  // input `in_name` (index 0) to output `out_name` (index 0), setting both the shape and
  // the origin shape. It returns the status of UpdateOutputDesc. The generated lambda
  // takes its Operator by value.
  #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \
    [](Operator op) -> graphStatus { \
      auto in_dims = op.GetInputDesc(in_name, 0).GetShape().GetDims(); \
      auto in_dtype = op.GetInputDesc(in_name, 0).GetDataType(); \
      TensorDesc out_desc = op.GetOutputDesc(out_name, 0); \
      out_desc.SetShape(ge::Shape(in_dims)); \
      out_desc.SetOriginShape(ge::Shape(in_dims)); \
      out_desc.SetDataType(in_dtype); \
      return op.UpdateOutputDesc(out_name, out_desc); \
    }
  350. graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape,
  351. const function<vector<int64_t>()> &get_in2_shape,
  352. const function<void(const vector<int64_t> &y_shape)> &set_out_shape);
  353. #define BROADCAST_INFER(in1_name, in2_name, out_name) \
  354. [](Operator op) -> graphStatus { \
  355. return BroadCastInfer([&]() { return op.GetInputDesc(in1_name, 0).GetShape().GetDims(); }, \
  356. [&]() { return op.GetInputDesc(in2_name, 0).GetShape().GetDims(); }, \
  357. [&](const vector<int64_t> &y_shape) { \
  358. TensorDesc op_output_desc = op.GetOutputDesc(out_name, 0); \
  359. op_output_desc.SetShape(ge::Shape(y_shape)); \
  360. (void)op.UpdateOutputDesc(out_name, op_output_desc); \
  361. }); \
  362. }
  363. } // namespace ge
  364. #endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。