matrix_calculation_ops.h 44 kB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*!
* \file matrix_calculation_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

#include "graph/operator_reg.h"
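/* Editor's note: the illustrative reference implementations added throughout
 * this file are guarded by the hypothetical macro OPS_DOC_EXAMPLES and are not
 * compiled by default. Their standard-library includes live here, outside
 * namespace ge. */
#ifdef OPS_DOC_EXAMPLES
#include <cstddef>
#include <vector>
#endif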

namespace ge {
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC] . \n

*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n

*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ] . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
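/* Editor's illustrative sketch (not part of the original header): reference
 * semantics of MatMul for 2-D row-major float data, covering the transpose_x1 /
 * transpose_x2 attributes and the optional bias, which is broadcast along M.
 * MatMulReference is a hypothetical helper for documentation only. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
// Computes y[i][j] = sum_p x1[i][p] * x2[p][j] (+ bias[j]).
// transpose_x1 means x1 is stored as [K, M]; transpose_x2 means x2 is stored as [N, K].
inline std::vector<float> MatMulReference(const std::vector<float> &x1, const std::vector<float> &x2,
                                          const std::vector<float> *bias, std::size_t m, std::size_t k,
                                          std::size_t n, bool transpose_x1, bool transpose_x2) {
  std::vector<float> y(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {
        const float a = transpose_x1 ? x1[p * m + i] : x1[i * k + p];
        const float b = transpose_x2 ? x2[j * k + p] : x2[p * n + j];
        acc += a * b;
      }
      if (bias != nullptr) {
        acc += (*bias)[j];  // 1-D bias of length N, broadcast over rows
      }
      y[i * n + j] = acc;
    }
  }
  return y;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES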
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li offset_w: An optional 1D Tensor for quantized inference. Type is int8. Reserved. \n

*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M].
*@li offset_x: An optional integer for quantization.
*The negative offset added to the input image for int8 type. Ensure offset_x is within the
*effective range of int8 [-128, 127]. Defaults to "0". \n

*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC, FRACTAL_NZ] . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be of type int8.
* @li x2: A matrix Tensor. 2D. Must be of type int8.
* @li compress_index: A compress index matrix of type int8.
* @li bias: An optional 1D Tensor. Must be one of the following types: int32, float16. \n

*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n

*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* int32. \n
*/
REG_OP(MatMulV2Compress)
    .INPUT(x1, TensorType({DT_INT8}))
    .INPUT(x2, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2Compress)
/**
*@brief Performs matrix-to-matrix multiplication, producing c = alpha[0] * a * b + beta[0] * c . \n

*@attention Constraints:
* For better performance, the k-axis must be aligned to 16 (input type
* is float16) or 32 (input type is int8). \n

*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
*@li b: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format [ND, FRACTAL_NZ, FRACTAL_Z]. 2D(ND) or 4D(FRACTAL_NZ, FRACTAL_Z).
*@li c: A matrix Tensor. Must be one of the following types: float16, int32,
* float32. Has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
*@li alpha: A 1D Tensor. The shape of alpha is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
* The formats of a, b and c are restricted as follows:\n
* When the type of a is int8 and the type of c is int32, the formats of a, b and c
* must all be ND, or a is FRACTAL_NZ, b is FRACTAL_Z and c is ND.\n
* When the type of a is int8 and the type of c is float32, the formats of a, b and c
* must all be ND, or a is FRACTAL_NZ, b is FRACTAL_Z and c is FRACTAL_NZ.\n
* When the type of a is float16 and the type of c is float16, the formats of a, b and c
* must all be ND or FRACTAL_NZ.\n
* When the type of a is float16 and the type of c is float32, the formats of a, b and c
* must all be ND or FRACTAL_NZ . \n

*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool. If True, changes the shape of "a" from
* [M, K] to [K, M].
*@li transpose_b: Optional. A bool. If True, changes the shape of "b" from
* [K, N] to [N, K] . \n

*@par Outputs:
*y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, FRACTAL_NZ]; the format must match that of "a".
* 2D(ND) or 4D(FRACTAL_NZ).
*/
REG_OP(GEMM)
    .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(GEMM)
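/* Editor's illustrative sketch: the scalar semantics of GEMM as documented
 * above, c = alpha[0] * a * b + beta[0] * c, for 2-D row-major ND data with
 * both transpose flags false. GemmReference is a hypothetical helper. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline void GemmReference(const std::vector<float> &a, const std::vector<float> &b, std::vector<float> &c,
                          float alpha, float beta, std::size_t m, std::size_t k, std::size_t n) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) {
        acc += a[i * k + p] * b[p * n + j];
      }
      c[i * n + j] = alpha * acc + beta * c[i * n + j];  // c is updated in place
    }
  }
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES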
/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ] . \n

*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n

*@par Outputs:
*y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ]. Has the same rank as "x1" and "x2" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatmul.
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n

* @par Inputs:
* Four inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li bias: An optional matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
* @li offset_w: An optional Tensor of type int8. Reserved. \n

* @par Attributes:
* @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
* @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n

* @par Outputs:
* y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ]. Has the same rank as "x1" and "x2" . \n

* @par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatmul.
*/
REG_OP(BatchMatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(BatchMatMulV2)
/**
*@brief Computes half the L2 norm of a tensor without the sqrt . \n

*@par Inputs:
* x: A Tensor.
* TensorType::FloatingDataType() . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x".

*@par Third-party framework compatibility
*Compatible with the TensorFlow operator L2Loss.
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
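/* Editor's illustrative sketch: L2Loss reduces the whole tensor to a single
 * scalar, sum(x * x) / 2, i.e. half the squared L2 norm, with no sqrt.
 * L2LossReference is a hypothetical helper. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline float L2LossReference(const std::vector<float> &x) {
  float sum = 0.0f;
  for (const float v : x) {
    sum += v * v;  // accumulate squared elements
  }
  return sum / 2.0f;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES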
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values . \n

*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)
/**
*@brief Returns the batched diagonal part of a batched tensor . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)
/**
*@brief Returns the batched diagonal part of a batched tensor . \n

*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@li diagonal: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)
/**
*@brief Applies sparse "updates" to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
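/* Editor's illustrative sketch: ScatterNdUpdate semantics for the simplest
 * case, a 2-D "var" with indices of shape [N, 1], where each index selects a
 * whole row to overwrite: var[indices[i]] = updates[i]. Out-of-range indices
 * are not checked here. ScatterNdUpdateReference is a hypothetical helper. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline void ScatterNdUpdateReference(std::vector<float> &var, const std::vector<int> &indices,
                                     const std::vector<float> &updates, std::size_t row_len) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t row = static_cast<std::size_t>(indices[i]);
    for (std::size_t j = 0; j < row_len; ++j) {
      var[row * row_len + j] = updates[i * row_len + j];  // overwrite the selected row
    }
  }
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES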
/**
*@brief Applies sparse "updates" to individual values or slices in a tensor . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8

*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)
/**
*@brief Uses "updates" to update tensor "data" by "indices". \n

*@par Inputs:
* Three inputs, including:
*@li data: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64.
*@li updates: A Tensor with the same shape as "indices". Format: NCHW, NHWC. \n
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Attributes:
*@li axis: An optional attribute. Defaults to 0.

*@par Outputs:
*y: A Tensor. Has the same type and format as input "data" . \n

*@par Third-party framework compatibility
* Compatible with the ONNX operator ScatterElements.
*/
REG_OP(ScatterElements)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(ScatterElements)
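/* Editor's illustrative sketch: ONNX ScatterElements semantics for 2-D data
 * with axis == 0. The output starts as a copy of "data", and for every element
 * of "updates": y[indices[i][j]][j] = updates[i][j]. ScatterElementsReference
 * is a hypothetical helper. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline std::vector<float> ScatterElementsReference(std::vector<float> data, const std::vector<int> &indices,
                                                   const std::vector<float> &updates, std::size_t upd_rows,
                                                   std::size_t cols) {
  for (std::size_t i = 0; i < upd_rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const std::size_t row = static_cast<std::size_t>(indices[i * cols + j]);
      data[row * cols + j] = updates[i * cols + j];  // scatter along axis 0
    }
  }
  return data;  // "data" was taken by value, so the input tensor is unchanged
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES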
/**
*@brief Adds sparse "updates" to a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64.
*@li updates: A Tensor. Format: NCHW, NHWC.
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
/**
*@brief Divides a variable reference by sparse updates . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)
/**
*@brief Applies sparse addition to individual values or slices in a tensor . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8

*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)
/**
*@brief Applies sparse subtraction to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)
/**
*@brief Applies sparse subtraction to individual values or slices in a tensor . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)
/**
*@brief Subtracts sparse "updates" from a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief Returns the batched diagonal part of a batched tensor with "assist" . \n

*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)
/**
*@brief Returns the batched diagonal part of a batched tensor . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128 . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
*@brief Also known as a "fully-connected" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n

*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16, int8.
*@li w: A weight matrix of type float16, int8.
*@li b: An optional Tensor of type float16, int32, float32.
*@li offset_w: An optional Tensor of type int8 . \n

*@par Attributes:
*@li num_output: Reserved.
*@li transpose: A bool, specifying whether to transpose the weight matrix, either "true" or "false". Defaults to "false".
*@li axis: Optional. An int, 1 or 2, specifying the dimension at which the input "K" starts.
* Defaults to 1. The product of the dimensions from "axis" onward is "K".
*@li offset_x: Reserved . \n

*@par Outputs:
*y: The result tensor of type float16, int32, float32 . \n

*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct . \n

*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
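/* Editor's illustrative sketch: FullyConnection as an inner product, assuming
 * the Caffe InnerProduct convention of a [num_output, K] weight matrix, so
 * y[i][o] = sum_p x[i][p] * w[o][p] + b[o] for a [batch, K] input.
 * FullyConnectionReference is a hypothetical helper. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline std::vector<float> FullyConnectionReference(const std::vector<float> &x, const std::vector<float> &w,
                                                   const std::vector<float> *b, std::size_t batch,
                                                   std::size_t k, std::size_t num_output) {
  std::vector<float> y(batch * num_output, 0.0f);
  for (std::size_t i = 0; i < batch; ++i) {
    for (std::size_t o = 0; o < num_output; ++o) {
      float acc = (b != nullptr) ? (*b)[o] : 0.0f;  // optional bias
      for (std::size_t p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[o * k + p];
      }
      y[i * num_output + o] = acc;
    }
  }
  return y;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES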
/**
*@brief Also known as a "fully-connected-compress" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n

*@par Inputs:
* Five inputs, including:
*@li x: A Tensor of type uint8, int8.
*@li w: A weight matrix of type int8.
*@li compress_index: A compress index matrix of type int8.
*@li b: An optional Tensor of type int32.
*@li offset_w: An optional Tensor of type int8.

*@par Attributes:
*@li num_output: Reserved.
*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
*@li axis: Reserved.
*@li offset_x: Reserved . \n

*@par Outputs:
*y: The result tensor of type int32 . \n

*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct . \n

*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)
/**
*@brief Computes the confusion matrix from predictions and labels . \n

*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8 . \n

*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix.
* No default value.
*@li dtype: Data type of the confusion matrix. No default value . \n

*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".

*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096 . \n

*@see Region()

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
/**
*@brief Multiplies sparse updates into a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)
/**
*@brief Reduces sparse updates into a variable reference using
* the "min" operation . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)
/**
*@brief Reduces sparse updates into a variable reference using the "max" operation . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32 or int64
*@li updates: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input` . \n

*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n

*@par Outputs:
*diagonal: The extracted diagonal(s) . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1` tensor, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n

*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape` . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values . \n

*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols:
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n

*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
/**
* @brief Add updates to var_out according to axis and indices.

* @par Inputs:
* Three inputs, including:
* @li var: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8.
* @li indices: A Tensor of the indices; the type should be int32.
* @li updates: A Tensor of the same type as "var".

* @par Attributes:
* @li axis: An optional int specifying the axis along which the index add is performed. Defaults to 0.

* @par Outputs:
* @li var_out: A Tensor. Same as input "var".

* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_add.

* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexAdd)
    .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(IndexAdd)
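/* Editor's illustrative sketch: index_add semantics for axis == 0 on 2-D data.
 * Rows of "updates" are accumulated into the rows of "var" selected by
 * "indices": var_out[indices[i]] += updates[i]. IndexAddReference is a
 * hypothetical helper that updates "var" in place. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline void IndexAddReference(std::vector<float> &var, const std::vector<int> &indices,
                              const std::vector<float> &updates, std::size_t row_len) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t row = static_cast<std::size_t>(indices[i]);
    for (std::size_t j = 0; j < row_len; ++j) {
      var[row * row_len + j] += updates[i * row_len + j];  // accumulate, not overwrite
    }
  }
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES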
/**
*@brief Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices "x" . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n

*@par Attributes:
*diagonal: An optional int. The diagonal to consider. Defaults to 0. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator Triu.
*/
REG_OP(Triu)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Triu)
/**
*@brief Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices "x" . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n

*@par Attributes:
*diagonal: An optional int. The diagonal to consider. Defaults to 0. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator Tril.
*/
REG_OP(Tril)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Tril)
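/* Editor's illustrative sketch: the diagonal rule shared by Triu and Tril on a
 * single 2-D matrix. Triu keeps elements with j - i >= diagonal; Tril keeps
 * those with j - i <= diagonal; everything else is zeroed. TriuReference is a
 * hypothetical helper (Tril is the mirrored comparison). */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline std::vector<float> TriuReference(const std::vector<float> &x, std::size_t rows, std::size_t cols,
                                        int diagonal) {
  std::vector<float> y(rows * cols, 0.0f);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      if (static_cast<int>(j) - static_cast<int>(i) >= diagonal) {
        y[i * cols + j] = x[i * cols + j];  // keep the upper triangular band
      }
    }
  }
  return y;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES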
/**
*@brief Sums the product of the elements of the input operands along dimensions
* specified using a notation based on the Einstein summation convention. \n

*@par Inputs:
*x: A list of input tensors. Must be one of the following types: int32, float16, float32.
* It's a dynamic input. \n

*@par Attributes:
*@li equation: A String describing the subscripts for the Einstein summation. \n
*@li N: An Int giving the number of input tensors. \n

*@par Outputs:
*y: The result of the Einstein summation. \n

*@attention Constraints:
*Input N must be Int. \n

*@par Third-party framework compatibility
*Compatible with the Pytorch einsum operator.
*/
REG_OP(Einsum)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .REQUIRED_ATTR(equation, String)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(Einsum)
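/* Editor's illustrative sketch: in the Einsum equation string, each operand's
 * axes are named by letters and repeated letters are summed over, so
 * "ij,jk->ik" is an ordinary matrix multiply. The helper below hard-codes that
 * one equation for illustration; EinsumIjJkToIk is a hypothetical name. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline std::vector<float> EinsumIjJkToIk(const std::vector<float> &a, const std::vector<float> &b,
                                         std::size_t i_dim, std::size_t j_dim, std::size_t k_dim) {
  std::vector<float> y(i_dim * k_dim, 0.0f);
  for (std::size_t i = 0; i < i_dim; ++i) {
    for (std::size_t k = 0; k < k_dim; ++k) {
      for (std::size_t j = 0; j < j_dim; ++j) {
        y[i * k_dim + k] += a[i * j_dim + j] * b[j * k_dim + k];  // sum over the repeated index j
      }
    }
  }
  return y;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES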
/**
*@brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n

*@par Inputs:
*No inputs.

*@par Attributes:
*@li num_rows: A required int. \n
*@li num_columns: An optional int. Defaults to 0. \n
*@li batch_shape: An optional ListInt. Defaults to []. \n
*@li dtype: An optional int. Defaults to 0. \n

*@par Outputs:
*y: A Tensor with the targeted type and shape. \n

*@par Third-party framework compatibility
*Compatible with the Pytorch operator Eye. \n
*/
REG_OP(Eye)
    .OUTPUT(y, TensorType::BasicType()) /* "Result, has targeted element type" */
    .REQUIRED_ATTR(num_rows, Int)
    .ATTR(num_columns, Int, 0)
    .ATTR(batch_shape, ListInt, {})
    .ATTR(dtype, Int, 0)
    .OP_END_FACTORY_REG(Eye)
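/* Editor's illustrative sketch: Eye produces a num_rows x num_columns matrix
 * with ones on the main diagonal; num_columns == 0 (the default) means "same
 * as num_rows". EyeReference is a hypothetical helper; batch_shape and dtype
 * are omitted for brevity. */
#ifdef OPS_DOC_EXAMPLES
namespace doc_examples {
inline std::vector<float> EyeReference(int num_rows, int num_columns) {
  if (num_columns <= 0) {
    num_columns = num_rows;  // 0 requests a square matrix
  }
  std::vector<float> y(static_cast<std::size_t>(num_rows) * static_cast<std::size_t>(num_columns), 0.0f);
  const int d = (num_rows < num_columns) ? num_rows : num_columns;
  for (int i = 0; i < d; ++i) {
    y[static_cast<std::size_t>(i) * static_cast<std::size_t>(num_columns) + static_cast<std::size_t>(i)] = 1.0f;
  }
  return y;
}
}  // namespace doc_examples
#endif  // OPS_DOC_EXAMPLES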
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module (ME) and the underlying hardware, serving as the bridge between the two. GE takes the graph issued by ME as input, performs a series of deep graph-optimization passes, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit the processor's compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.