
control_flow_ops.h
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file control_flow_ops.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"

namespace ge {
/**
*@brief Forwards the value of an available tensor from input "x" to output "y".
* Merge waits for at least one of the input tensors to become available.
* It is usually combined with Switch to implement branching.
* Merge forwards the first tensor to become available to output "y",
* and sets "value_index" to the index of that tensor in the inputs . \n
*@par Inputs:
*x: The input tensors, one of which will become available.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input
* tensor . \n
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Merge.
*/
REG_OP(Merge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(Merge)
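To make the documented dataflow rule concrete, here is a minimal host-side sketch of the Merge behaviour: the first input to become available is forwarded, and its position is reported as "value_index". It only models the behaviour described in the comment above using standard C++ types; it is not the GE kernel, and the name MergeModel is illustrative.

// Host-side model of Merge: forward the first available input and report
// its index. Illustrative only; not part of this header.
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

template <typename T>
std::pair<T, int32_t> MergeModel(const std::vector<std::optional<T>>& x) {
  for (int32_t i = 0; i < static_cast<int32_t>(x.size()); ++i) {
    if (x[i].has_value()) {
      return {*x[i], i};  // y = the available tensor, value_index = i
    }
  }
  throw std::runtime_error("Merge: no input has become available");
}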
/**
*@brief Forwards the value of an available tensor from input "x" to output "y".
* Merge waits for at least one of the input tensors to become available.
* It is usually combined with Switch to implement branching.
* Merge forwards the first tensor to become available to output "y",
* and sets "value_index" to the index of that tensor in the inputs . \n
*@par Inputs:
*x: The input tensors, one of which will become available.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . It's a dynamic input. \n
*@par Outputs:
*@li y: The available tensor. Has the same type as "x".
*@li value_index: A scalar of type int32, for the index of the chosen input
* tensor . \n
*@see Switch() | Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefMerge.
*/
REG_OP(RefMerge)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(value_index, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(RefMerge)
/**
*@brief Forwards "data" to the output port determined by "pred".
* If "pred" is "true", the data input is forwarded to "output_true".
* Otherwise, the data is forwarded to "output_false" . \n
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Determines the output port that will receive the data . \n
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output.
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output.
* Has the same type as "data" . \n
*@see Merge()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Switch.
*/
REG_OP(Switch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Switch)
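As a rough illustration of the routing rule described above: the value goes to exactly one of the two outputs depending on "pred", and the untaken output stays empty (a "dead" branch). The following host-side sketch models only that behaviour; SwitchModel and SwitchOut are illustrative names, not part of this header.

// Host-side model of Switch: route "data" to output_true when pred is true,
// otherwise to output_false; the untaken output stays disengaged ("dead").
#include <optional>

template <typename T>
struct SwitchOut {
  std::optional<T> output_false;
  std::optional<T> output_true;
};

template <typename T>
SwitchOut<T> SwitchModel(const T& data, bool pred) {
  SwitchOut<T> out;
  if (pred) {
    out.output_true = data;   // taken branch receives the value
  } else {
    out.output_false = data;
  }
  return out;
}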
/**
*@brief Forwards "data" to the output port determined by "pred".
* If "pred" is "true", the data input is forwarded to "output_true".
* Otherwise, the data is forwarded to "output_false" . \n
*@par Inputs:
*@li data: The ref tensor to be forwarded.
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred: A boolean scalar. Determines the output port that will receive the data . \n
*@par Outputs:
*@li output_false: If "pred" is "false", data will be forwarded to this output.
* Has the same type as "data".
*@li output_true: If "pred" is "true", data will be forwarded to this output.
* Has the same type as "data" . \n
*@see Merge() | Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefSwitch.
*/
REG_OP(RefSwitch)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred, TensorType({DT_BOOL}))
    .OUTPUT(output_false, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(output_true, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefSwitch)
/**
*@brief Forwards "data" to the output port determined by "pred_value" . \n
*@par Inputs:
*@li data: The tensor to be forwarded. \n
* Must be one of the following types: float16, float32, float64,
* int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool.
*@li pred_value: An int64 tensor that determines the output port that will receive the data . \n
*@par Outputs:
*output: The output tensors, one of which will become available.
* Has the same type as "data".
*/
REG_OP(SwitchN)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .INPUT(pred_value, TensorType({DT_INT64}))
    .DYNAMIC_OUTPUT(output, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(SwitchN)
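SwitchN generalizes Switch to N ports: "pred_value" selects which dynamic output receives the data. A minimal host-side sketch of that selection rule follows; SwitchNModel is an illustrative name and the sketch is not the GE kernel.

// Host-side model of SwitchN: forward "data" to the output port selected by
// "pred_value"; every other port stays empty. Illustrative only.
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

template <typename T>
std::vector<std::optional<T>> SwitchNModel(const T& data, int64_t pred_value,
                                           int64_t num_outputs) {
  if (pred_value < 0 || pred_value >= num_outputs) {
    throw std::out_of_range("SwitchN: pred_value out of range");
  }
  std::vector<std::optional<T>> output(num_outputs);
  output[pred_value] = data;  // only the selected port becomes available
  return output;
}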
/**
*@brief Creates or finds a child frame, and makes "x" available to the child
* frame. This op is used together with Exit to create loops in the graph.
* The Executor uses the unique "frame_name" to identify frames.
* If "is_constant" is "true", output "y" is a constant in the child
* frame; otherwise it may be changed in the child frame . \n
*@par Inputs:
*x: The tensor to be made available to the child frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in
* the child frame . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@see Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Enter.
*/
REG_OP(Enter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(Enter)
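At graph-construction time, an Enter node is configured through the accessors generated by the macros above (set_input_x from INPUT, set_attr_frame_name and set_attr_is_constant from REQUIRED_ATTR). The sketch below is written under that assumption; the umbrella header "all_ops.h", the node names, and the surrounding Data op are illustrative, so treat the exact calls as an assumption and check the generated op class rather than taking this as the definitive API.

// Sketch: wiring an Enter node in a GE graph. Accessor names follow the
// pattern generated by REG_OP/INPUT/REQUIRED_ATTR in this header (assumed).
#include "all_ops.h"      // assumed umbrella header for the generated op classes
#include "graph/graph.h"

ge::Graph BuildEnterExample() {
  auto data = ge::op::Data("loop_input");               // value fed into the frame
  auto enter = ge::op::Enter("enter")
                   .set_input_x(data)                   // tensor made available to the child frame
                   .set_attr_frame_name("while_frame")  // unique frame identifier
                   .set_attr_is_constant(false);        // value may change inside the frame
  ge::Graph graph("enter_example");
  graph.SetInputs({data}).SetOutputs({enter});
  return graph;
}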
/**
*@brief Creates or finds a child frame, and makes "x" available to the child
* frame. This op is used together with Exit to create loops in the graph.
* The Executor uses the unique "frame_name" to identify frames.
* If "is_constant" is "true", output "y" is a constant in the child
* frame; otherwise it may be changed in the child frame . \n
*@par Inputs:
*x: The tensor to be made available to the child frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Attributes:
*@li frame_name: A required string. The name of the child frame.
*@li is_constant: A required bool. If true, the output is constant in
* the child frame . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@see Exit() | Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefEnter.
*/
REG_OP(RefEnter)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .REQUIRED_ATTR(frame_name, String)
    .REQUIRED_ATTR(is_constant, Bool)
    .OP_END_FACTORY_REG(RefEnter)
/**
*@brief Forwards the input to the output. This op represents the loop
* termination condition . \n
*@par Inputs:
*x: A boolean scalar. The condition of the Switch op . \n
*@par Outputs:
*y: The tensor "x" . \n
*@see Switch()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator LoopCond.
*/
REG_OP(LoopCond)
    .INPUT(x, TensorType({DT_BOOL}))
    .OUTPUT(y, TensorType({DT_BOOL}))
    .OP_END_FACTORY_REG(LoopCond)
/**
*@brief Makes the input available to the next iteration . \n
*@par Inputs:
*x: The tensor to be made available to the next iteration.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator NextIteration.
*/
REG_OP(NextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(NextIteration)
/**
*@brief Makes the input available to the next iteration . \n
*@par Inputs:
*x: The tensor to be made available to the next iteration.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefNextIteration.
*/
REG_OP(RefNextIteration)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefNextIteration)
/**
*@brief Exits the current frame to its parent frame . \n
*@par Inputs:
*x: The tensor to be made available to the parent frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@see Enter()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator Exit.
*/
REG_OP(Exit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(Exit)
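Enter, Merge, LoopCond, Switch, NextIteration and Exit are the building blocks of a while loop, wired the same way as TensorFlow's control-flow lowering: Enter brings the initial value into the frame, Merge picks between the initial value and the value fed back by NextIteration, and LoopCond plus Switch decide whether to run the body again or Exit to the parent frame. The host-side model below mirrors only that documented wiring; WhileLoopModel is an illustrative name, not part of this header.

// Host-side model of the while-loop dataflow built from these ops:
//   Enter -> Merge -> LoopCond -> Switch -> body -> NextIteration -> Merge
//                                        \-> Exit (when the condition is false)
#include <functional>

template <typename T>
T WhileLoopModel(const T& x,
                 const std::function<bool(const T&)>& cond,
                 const std::function<T(const T&)>& body) {
  T state = x;              // Enter: make "x" available to the child frame
  while (cond(state)) {     // Merge + LoopCond + Switch: iterate or leave
    state = body(state);    // loop body, fed back through NextIteration
  }
  return state;             // Exit: hand the value back to the parent frame
}

For example, WhileLoopModel<int>(0, [](const int& i) { return i < 10; }, [](const int& i) { return i + 1; }) returns 10, matching a graph that increments a counter until the LoopCond input becomes false.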
/**
*@brief Exits the current frame to its parent frame . \n
*@par Inputs:
*x: The tensor to be made available to the parent frame.
* Must be one of the following types: float16, float32, float64, int8,
* int16, int32, int64, uint8, uint16, uint32, uint64, bool . \n
*@par Outputs:
*y: A tensor. Has the same type as "x" . \n
*@see Enter() | Exit()
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator RefExit.
*/
REG_OP(RefExit)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE,
        DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32,
        DT_UINT64, DT_BOOL}))
    .OP_END_FACTORY_REG(RefExit)
/**
*@brief Only useful as a placeholder for control edges.
* It is similar to a no-op that always produces a live control output
* even when some control inputs are dead . \n
*@par Third-party framework compatibility
*@Compatible with the TensorFlow operator ControlTrigger.
*/
REG_OP(ControlTrigger)
    .OP_END_FACTORY_REG(ControlTrigger)
/**
*@brief Returns the index of a shape in the map.
*@par Inputs:
* Three inputs, including:
*@li x: A one-dimensional tensor of type int32, specifying the queried shape; its max size is 8.
*@li data_seq: A one-dimensional tensor of type int32, specifying the mapping table to be queried.
*@li level_index: A one-dimensional tensor of type int32, specifying the secondary index. \n
*@par Outputs:
*@li y: A Tensor with shape [batch, 8], of type int32, specifying the index of the shape in the map.
*@par Third-party framework compatibility
* It is a custom operator. It has no corresponding operator in Caffe.
*/
REG_OP(MapIndex)
    .INPUT(x, TensorType({DT_INT32}))
    .INPUT(data_seq, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(level_index, TensorType({DT_INT32}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .OP_END_FACTORY_REG(MapIndex)
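The comment above is terse; one plausible reading is that "data_seq" stores candidate shapes back to back (8 int32 values per entry) and MapIndex returns the position of the queried shape in that table. The sketch below encodes only that assumed reading, ignores "level_index", and is not a description of the actual kernel; MapIndexModel is an illustrative name.

// Assumed host-side reading of MapIndex: "data_seq" is a flat table with 8
// int32 values per candidate shape (shorter shapes padded with 0); return the
// index of the entry that matches the padded query "x", or -1 if absent.
#include <algorithm>
#include <cstdint>
#include <vector>

int32_t MapIndexModel(const std::vector<int32_t>& x,
                      const std::vector<int32_t>& data_seq) {
  constexpr int32_t kShapeSize = 8;
  std::vector<int32_t> query(kShapeSize, 0);  // pad the queried shape to 8 values
  for (size_t i = 0; i < x.size() && i < static_cast<size_t>(kShapeSize); ++i) {
    query[i] = x[i];
  }
  const int32_t entries = static_cast<int32_t>(data_seq.size()) / kShapeSize;
  for (int32_t i = 0; i < entries; ++i) {
    if (std::equal(query.begin(), query.end(),
                   data_seq.begin() + i * kShapeSize)) {
      return i;  // index of the queried shape in the map
    }
  }
  return -1;  // shape not present in the table
}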
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_CONTROL_FLOW_OPS_H_

The Graph Engine (GE) module is a submodule of MindSpore implemented in C++. It sits between the front-end module ME and the underlying hardware and serves as the bridge between them: GE takes the graph issued by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.