You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

matrix_calculation_ops.h 42 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file matrix_calculation_ops.h
  18. * \brief
  19. */
  20. #ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
  21. #define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
  22. #include "graph/operator_reg.h"
  23. namespace ge {
  24. /**
  25. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  26. *@par Inputs:
  27. *Three inputs, including:
  28. * @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
  29. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
  30. * @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
  31. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
  32. * @li bias: A optional 1D Tensor. Must be one of the following types: float16,
  33. * float32, int32. Has format [ND, NHWC] . \n
  34. *@par Attributes:
  35. *@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
  36. *@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
  37. *@par Outputs:
  38. *y: The result matrix Tensor. 2D. Must be one of the following types: float16,
  39. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ] . \n
  40. *@par Third-party framework compatibility
  41. * Compatible with the TensorFlow operator BatchMatmul.
  42. */
  43. REG_OP(MatMul)
  44. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  45. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  46. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  47. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  48. .ATTR(transpose_x1, Bool, false)
  49. .ATTR(transpose_x2, Bool, false)
  50. .OP_END_FACTORY_REG(MatMul)
  51. /**
  52. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  53. *@par Inputs:
  54. *Two inputs, including:
  55. * @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
  56. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
  57. * @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
  58. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ].
  59. * @li bias: A 1D Tensor. Must be one of the following types: float16,
  60. * float32, int32. Has format [ND, NHWC] . \n
  61. *@par Attributes:
  62. *@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
  63. *@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
  64. *@par Outputs:
  65. *y: The result matrix Tensor. 2D. Must be one of the following types: float16,
  66. * float32, int32. Has format [ND, NHWC, FRACTAL_NZ] . \n
  67. *@par Third-party framework compatibility
  68. * Compatible with the TensorFlow operator BatchMatmul.
  69. */
  70. REG_OP(MatMulV2)
  71. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
  72. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
  73. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  74. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  75. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  76. .ATTR(transpose_x1, Bool, false)
  77. .ATTR(transpose_x2, Bool, false)
  78. .ATTR(offset_x, Int, 0)
  79. .OP_END_FACTORY_REG(MatMulV2)
  80. /**
  81. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  82. *@par Inputs:
  83. *Two inputs, including:
  84. * @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
  85. * @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
  86. * @li compress_index: A compress index matrix of type int8.
  87. * @li bias: A 1D Tensor. Must be one of the following types: int32, float16.
  88. *@par Attributes:
  89. *@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
  90. *@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
  91. *@par Outputs:
  92. *y: The result matrix Tensor. 2D. Must be one of the following types: float16,
  93. * int32. \n
  94. */
  95. REG_OP(MatMulV2Compress)
  96. .INPUT(x1, TensorType({DT_INT8}))
  97. .INPUT(x2, TensorType({DT_INT8}))
  98. .INPUT(compress_index, TensorType({DT_INT8}))
  99. .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
  100. .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
  101. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  102. .ATTR(transpose_x1, Bool, false)
  103. .ATTR(transpose_x2, Bool, false)
  104. .ATTR(offset_x, Int, 0)
  105. .OP_END_FACTORY_REG(MatMulV2Compress)
  106. /**
  107. *@brief Performs Matrix-to-matrix Multiply, producing c=alpha[0]*a*b+beta[0]*c . \n
  108. *@attention Constraints:
  109. * For better performance, The k-axis must be aligned to 16 (input type
  110. * is float16) or 32 (input type is int8). \n
  111. *@par Inputs:
  112. *Five inputs, including:
  113. *@li a: A matrix Tensor. Must be one of the following types: float16, int8.
  114. * Has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
  115. *@li b: A matrix Tensor. Must be one of the following types: float16, int8.
  116. * Has format [ND, FRACTAL_NZ, FRACTAL_Z]. 2D(ND) or 4D(FRACTAL_NZ, FRACTAL_Z).
  117. *@li c: A matrix Tensor. Must be one of the following types: float16, int32,
  118. * float32. has format [ND, FRACTAL_NZ]. 2D(ND) or 4D(FRACTAL_NZ).
  119. *@li alpha: A 1D Tensor. The shape of alpha is [1].Must be one of the following
  120. * types: float16, int32, float32. Has format [ND].
  121. *@li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
  122. * types: float16, int32, float32. Has format [ND].
  123. * The format of a, b, c has restriction:\n
  124. * When type of a is int8 and type of c is int32, the format of a, b, c should
  125. * all be ND, or a is FRACTAL_NZ and b is FRACTAL_Z and c is ND.\n
  126. * When type of a is int8 and type of c is float32, the format of a, b, c should
  127. * all be ND or a is FRACTAL_NZ and b is FRACTAL_Z and c is FRACTAL_NZ.\n
  128. * When type of a is float16 and type of c is float16, the format of a, b, c
  129. * should all be ND or FRACTAL_NZ.\n
  130. * When type of a is float16 and type of c is float32, the format of a, b, c
  131. * should all be ND or FRACTAL_NZ . \n
  132. *@par Attributes:
  133. *Two attributes, including:
  134. *@li transpose_a: Optional. A bool. If True, changes the shape of "a" from
  135. * [M, K] to [K, M].
  136. *@li transpose_b: Optional. A bool. If True, changes the shape of "b" from
  137. * [K, N] to [N, K] . \n
  138. *@par Outputs:
  139. *y: The result matrix Tensor. Must be one of the following types: float16,
  140. * float32, int32. Has format [ND, FRACTAL_NZ], the format should be equal to a.
  141. * 2D(ND) or 4D(FRACTAL_NZ).
  142. */
  143. REG_OP(GEMM)
  144. .INPUT(a, TensorType({DT_FLOAT16, DT_INT8}))
  145. .INPUT(b, TensorType({DT_FLOAT16, DT_INT8}))
  146. .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  147. .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  148. .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  149. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  150. .ATTR(transpose_a, Bool, false)
  151. .ATTR(transpose_b, Bool, false)
  152. .OP_END_FACTORY_REG(GEMM)
  153. /**
  154. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  155. *@par Inputs:
  156. *Three inputs, including:
  157. * @li x1: A matrix Tensor. Must be one of the following types: float16,
  158. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
  159. * @li x2: A matrix Tensor. Must be one of the following types: float16,
  160. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ] . \n
  161. *@par Attributes:
  162. *@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
  163. *@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
  164. *@par Outputs:
  165. *y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
  166. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2" . \n
  167. *@par Third-party framework compatibility
  168. * Compatible with the TensorFlow operator BatchMatmul.
  169. */
  170. REG_OP(BatchMatMul)
  171. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  172. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  173. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  174. .ATTR(adj_x1, Bool, false)
  175. .ATTR(adj_x2, Bool, false)
  176. .OP_END_FACTORY_REG(BatchMatMul)
  177. /**
  178. * @brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  179. * @par Inputs:
  180. * Three inputs, including:
  181. * @li x1: A matrix Tensor. Must be one of the following types: float16,
  182. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ].
  183. * @li x2: A matrix Tensor. Must be one of the following types: float16,
  184. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ] . \n
  185. * @li bias: A matrix Tensor. Must be one of the following types: float16,
  186. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ] . \n
  187. * @par Attributes:
  188. * @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
  189. * @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
  190. * @par Outputs:
  191. * y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
  192. * float32, int32. 2D or higher. Has format [ND, NHWC, FRACTAL_NZ]. Has the same shape length as "x1" and "x2" . \n
  193. * @par Third-party framework compatibility
  194. * Compatible with the TensorFlow operator BatchMatmul.
  195. */
  196. REG_OP(BatchMatMulV2)
  197. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
  198. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8}))
  199. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  200. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  201. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  202. .ATTR(adj_x1, Bool, false)
  203. .ATTR(adj_x2, Bool, false)
  204. .ATTR(offset_x, Int, 0)
  205. .OP_END_FACTORY_REG(BatchMatMulV2)
  206. /**
  207. *@brief Computes half the L2 norm of a tensor without the sqrt . \n
  208. *@par Inputs:
  209. * x: A Tensor.
  210. * TensorType::FloatingDataType() . \n
  211. *@par Outputs:
  212. *y: A Tensor. Has the same type as "x".
  213. *@par Third-party framework compatibility
  214. *Compatible with the TensorFlow operator L2Loss.
  215. */
  216. REG_OP(L2Loss)
  217. .INPUT(x, TensorType::FloatingDataType())
  218. .OUTPUT(y, TensorType::FloatingDataType())
  219. .OP_END_FACTORY_REG(L2Loss)
  220. /**
  221. *@brief: Returns a batched diagonal tensor with a given batched diagonal values . \n
  222. *@par Inputs:
  223. *x: A Tensor. Must be one of the following types:
  224. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  225. * qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
  226. *@par Outputs:
  227. *y: A Tensor. Has the same type as "x" . \n
  228. *@par Third-party framework compatibility
  229. * Compatible with the TensorFlow operator MatrixDiag.
  230. */
  231. REG_OP(MatrixDiag)
  232. .INPUT(x, TensorType::BasicType())
  233. .OUTPUT(y, TensorType::BasicType())
  234. .OP_END_FACTORY_REG(MatrixDiag)
  235. /**
  236. *@brief: Returns a batched diagonal tensor with a given batched diagonal values . \n
  237. *@par Inputs:
  238. * Two inputs, including:
  239. *@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
  240. *@li assist: A Tensor of the same type as "x" . \n
  241. *@par Outputs:
  242. *y: A Tensor. Has the same type as "x" . \n
  243. *@par Third-party framework compatibility
  244. * Compatible with the TensorFlow operator MatrixDiag.
  245. *
  246. * @par Restrictions:
  247. * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
  248. */
  249. REG_OP(MatrixDiagD)
  250. .INPUT(x, TensorType::BasicType())
  251. .INPUT(assist, TensorType::BasicType())
  252. .OUTPUT(y, TensorType::BasicType())
  253. .OP_END_FACTORY_REG(MatrixDiagD)
  254. /**
  255. *@brief: Returns the batched diagonal part of a batched tensor . \n
  256. *@par Inputs:
  257. *x: A Tensor. Must be one of the following types:
  258. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  259. * qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
  260. *@par Outputs:
  261. *y: A Tensor. Has the same type as "x" . \n
  262. *@par Third-party framework compatibility
  263. * Compatible with the TensorFlow operator MatrixDiagPart.
  264. */
  265. REG_OP(MatrixDiagPart)
  266. .INPUT(x, TensorType::BasicType())
  267. .OUTPUT(y, TensorType::BasicType())
  268. .OP_END_FACTORY_REG(MatrixDiagPart)
  269. /**
  270. *@brief: Returns the batched diagonal part of a batched tensor . \n
  271. *@par Inputs:
  272. * Two inputs, including:
  273. *@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
  274. *@li assist: A Tensor of the same type as "x" . \n
  275. *@par Outputs:
  276. *y: A Tensor. Has the same type as "x" . \n
  277. *@par Third-party framework compatibility
  278. * Compatible with the TensorFlow operator MatrixDiagPart.
  279. *
  280. * @par Restrictions:
  281. * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
  282. */
  283. REG_OP(MatrixDiagPartD)
  284. .INPUT(x, TensorType::BasicType())
  285. .INPUT(assist, TensorType::BasicType())
  286. .OUTPUT(y, TensorType::BasicType())
  287. .OP_END_FACTORY_REG(MatrixDiagPartD)
  288. /**
  289. *@brief: Returns a batched matrix tensor with new batched diagonal values . \n
  290. *@par Inputs:
  291. * Two inputs, including:
  292. *@li x: A Tensor. Must be one of the following types:
  293. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  294. * qint8, quint8, qint32, uint16, complex128, uint32, uint64.
  295. *@li diagonal: A Tensor of the same type as "x" . \n
  296. *@par Outputs:
  297. *y: A Tensor. Has the same type as "x" . \n
  298. *@par Third-party framework compatibility
  299. * Compatible with the TensorFlow operator MatrixSetDiag.
  300. */
  301. REG_OP(MatrixSetDiag)
  302. .INPUT(x, TensorType::BasicType())
  303. .INPUT(diagonal, TensorType::BasicType())
  304. .OUTPUT(y, TensorType::BasicType())
  305. .OP_END_FACTORY_REG(MatrixSetDiag)
  306. /**
  307. *@brief: Returns a batched matrix tensor with new batched diagonal values . \n
  308. *@par Inputs:
  309. * Three inputs, including:
  310. *@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
  311. *@li diagonal: A Tensor of the same type as "x".
  312. *@li assist: A Tensor of the same type as "x" . \n
  313. *@par Outputs:
  314. *y: A Tensor. Has the same type as "x" . \n
  315. *@par Third-party framework compatibility
  316. * Compatible with the TensorFlow operator MatrixSetDiag.
  317. *
  318. * @par Restrictions:
  319. * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
  320. */
  321. REG_OP(MatrixSetDiagD)
  322. .INPUT(x, TensorType::BasicType())
  323. .INPUT(diagonal, TensorType::BasicType())
  324. .INPUT(assist, TensorType::BasicType())
  325. .OUTPUT(y, TensorType::BasicType())
  326. .OP_END_FACTORY_REG(MatrixSetDiagD)
  327. /**
  328. *@brief Applies sparse "updates" to individual values or slices in a Variable . \n
  329. *@par Inputs:
  330. * Three inputs, including:
  331. *@li var: An ND Tensor.
  332. *Must be one of the following types: float16, float32, int8, uint8, double,
  333. * int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
  334. * uint64
  335. *@li indices: An ND Tensor.
  336. *Must be one of the following types: int32 or int64
  337. *@li updates: An ND Tensor.
  338. *Must be one of the following types: float16, float32, int8, uint8, double,
  339. * int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
  340. * uint64
  341. *@par Attributes:
  342. *use_locking: An optional bool. Defaults to "False". If "True",
  343. * the operation will be protected by a lock . \n
  344. *@par Outputs:
  345. *var: A Tensor. Has the same type and format as input "var" . \n
  346. *@par Third-party framework compatibility
  347. * Compatible with the TensorFlow operator ScatterNdUpdate.
  348. */
  349. REG_OP(ScatterNdUpdate)
  350. .INPUT(var, TensorType::BasicType())
  351. .INPUT(indices, TensorType::IndexNumberType())
  352. .INPUT(updates, TensorType::BasicType())
  353. .OUTPUT(var, TensorType::BasicType())
  354. .ATTR(use_locking, Bool, false)
  355. .OP_END_FACTORY_REG(ScatterNdUpdate)
  356. /**
  357. *@brief Applies sparse addition to individual values or slices in a Variable . \n
  358. *@par Inputs:
  359. * Three inputs, including:
  360. *@li x: An ND Tensor. \n
  361. *Must be one of the following types: float16, float32, bool, int8, uint8
  362. *@li indices: An ND Tensor. \n
  363. *Must be one of the following types: int32
  364. *@li updates: An ND Tensor. \n
  365. *Must be one of the following types: float16, float32, bool, int8, uint8
  366. *@par Outputs:
  367. *y: A Tensor. Has the same type and format as input "x" . \n
  368. *@par Third-party framework compatibility
  369. * Compatible with the TensorFlow operator TensorScatterUpdate.
  370. *@par Restrictions:
  371. *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  372. */
  373. REG_OP(TensorScatterUpdate)
  374. .INPUT(x, TensorType::BasicType())
  375. .INPUT(indices, TensorType::IndexNumberType())
  376. .INPUT(updates, TensorType::BasicType())
  377. .OUTPUT(y, TensorType::BasicType())
  378. .OP_END_FACTORY_REG(TensorScatterUpdate)
  379. /**
  380. *@brief Adds sparse "updates" to a variable reference . \n
  381. *@par Inputs:
  382. * Three inputs, including:
  383. *@li var: An ND Tensor . \n
  384. *Must be one of the following types: float16, float32, int32, int8, uint8
  385. *@li indices: An ND Tensor of type int32 or int64
  386. *@li updates: An Tensor. format:NCHW, NHWC . \n
  387. *Must be one of the following types: float16, float32, int32, int8, uint8
  388. *@par Attributes:
  389. * use_locking: An optional bool. Defaults to "False". If "True", the operation
  390. * will be protected by a lock . \n
  391. *@par Outputs:
  392. *var: A Tensor. Has the same type and format as input "var" . \n
  393. *@par Third-party framework compatibility
  394. * Compatible with the TensorFlow operator ScatterAdd.
  395. */
  396. REG_OP(ScatterAdd)
  397. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  398. .INPUT(indices, TensorType::IndexNumberType())
  399. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  400. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  401. .ATTR(use_locking, Bool, false)
  402. .OP_END_FACTORY_REG(ScatterAdd)
  403. /**
  404. *@brief Divides a variable reference by sparse updates . \n
  405. *@par Inputs:
  406. * Three inputs, including:
  407. *@li var: An ND Tensor.
  408. *Must be one of the following types: float16, float, int32, int8, uint8
  409. *@li indices: An ND Tensor.
  410. *Must be one of the following types: int32 or int64
  411. *@li updates: An ND Tensor.
  412. *Must be one of the following types: float16, float, int32, int8, uint8
  413. *@par Attributes:
  414. *@li use_locking: An optional bool. Defaults to "False". If "True",
  415. * the operation will be protected by a lock . \n
  416. *@par Outputs:
  417. *var: A Tensor. Has the same type and format as input "var" . \n
  418. *@par Third-party framework compatibility
  419. * Compatible with the TensorFlow operator ScatterDiv.
  420. */
  421. REG_OP(ScatterDiv)
  422. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  423. .INPUT(indices, TensorType::IndexNumberType())
  424. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  425. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  426. .ATTR(use_locking, Bool, false)
  427. .OP_END_FACTORY_REG(ScatterDiv)
  428. /**
  429. *@brief Applies sparse addition to individual values or slices in a Variable . \n
  430. *@par Inputs:
  431. * Three inputs, including:
  432. *@li var: An ND Tensor.
  433. *Must be one of the following types: float16, float, int32, int8, uint8
  434. *@li indices: An ND Tensor.
  435. *Must be one of the following types: int32 or int64
  436. *@li updates: An ND Tensor.
  437. *Must be one of the following types: float16, float, int32, int8, uint8
  438. *@par Attributes:
  439. *use_locking: An optional bool. Defaults to "False". If "True",
  440. * the operation will be protected by a lock . \n
  441. *@par Outputs:
  442. *var: A Tensor. Has the same type and format as input "var" . \n
  443. *@par Third-party framework compatibility
  444. * Compatible with the TensorFlow operator ScatterNdAdd.
  445. */
  446. REG_OP(ScatterNdAdd)
  447. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  448. .INPUT(indices, TensorType::IndexNumberType())
  449. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  450. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  451. .ATTR(use_locking, Bool, false)
  452. .OP_END_FACTORY_REG(ScatterNdAdd)
  453. /**
  454. *@brief Applies sparse addition to individual values or slices in a Variable . \n
  455. *@par Inputs:
  456. * Three inputs, including:
  457. *@li x: An ND Tensor. \n
  458. *Must be one of the following types: float16, float32, int32, int8, uint8
  459. *@li indices: An ND Tensor. \n
  460. *Must be one of the following types: int32
  461. *@li updates: An ND Tensor. \n
  462. * Must be one of the following types: float16, float32, int32, int8, uint8
  463. *@par Outputs:
  464. *y: A Tensor. Has the same type and format as input "x" . \n
  465. *@par Third-party framework compatibility
  466. * Compatible with the TensorFlow operator TensorScatterAdd.
  467. *@par Restrictions:
  468. *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  469. */
  470. REG_OP(TensorScatterAdd)
  471. .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  472. .INPUT(indices, TensorType::IndexNumberType())
  473. .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  474. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  475. .OP_END_FACTORY_REG(TensorScatterAdd)
  476. /**
  477. *@brief Applies sparse subtraction to individual values or slices in a Variable . \n
  478. *@par Inputs:
  479. * Three inputs, including:
  480. *@li var: An ND Tensor.
  481. *Must be one of the following types: float16, float, int32, int8, uint8
  482. *@li indices: An ND Tensor.
  483. *Must be one of the following types: int32 or int64
  484. *@li updates: An ND Tensor.
  485. *Must be one of the following types: float16, float, int32, int8, uint8
  486. *@par Attributes:
  487. *use_locking: An optional bool. Defaults to "False". If "True",
  488. * the operation will be protected by a lock . \n
  489. *@par Outputs:
  490. * var: A Tensor. Has the same type and format as input "var" . \n
  491. *@par Third-party framework compatibility
  492. * Compatible with the TensorFlow operator ScatterNdSub.
  493. */
  494. REG_OP(ScatterNdSub)
  495. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  496. .INPUT(indices, TensorType::IndexNumberType())
  497. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  498. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  499. .ATTR(use_locking, Bool, false)
  500. .OP_END_FACTORY_REG(ScatterNdSub)
  501. /**
  502. *@brief Applies sparse addition to individual values or slices in a Variable . \n
  503. *@par Inputs:
  504. * Three inputs, including:
  505. *@li x: An ND Tensor. \n
  506. *Must be one of the following types: float16, float32, int32, int8, uint8
  507. *@li indices: An ND Tensor. \n
  508. *Must be one of the following types: int32
  509. *@li updates: An ND Tensor. \n
  510. *Must be one of the following types: float16, float32, int32, int8, uint8
  511. *@par Outputs:
  512. * y: A Tensor. Has the same type and format as input "x" . \n
  513. *@par Third-party framework compatibility
  514. * Compatible with the TensorFlow operator TensorScatterSub.
  515. *@par Restrictions:
  516. *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  517. */
  518. REG_OP(TensorScatterSub)
  519. .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  520. .INPUT(indices, TensorType::IndexNumberType())
  521. .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  522. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  523. .OP_END_FACTORY_REG(TensorScatterSub)
  524. /**
  525. *@brief Subtracts sparse updates to a variable reference . \n
  526. *@par Inputs:
  527. * Three inputs, including:
  528. *@li var: An ND Tensor.
  529. *Must be one of the following types: float16, float, int32, int8, uint8
  530. *@li indices: An ND Tensor.
  531. *Must be one of the following types: int32 or int64
  532. *@li updates: An ND Tensor.
  533. *Must be one of the following types: float16, float, int32, int8, uint8
  534. *@par Attributes:
  535. *use_locking: An optional bool. Defaults to "False". If "True",
  536. * the operation will be protected by a lock . \n
  537. *@par Outputs:
  538. * var: A Tensor. Has the same type and format as input "var" . \n
  539. *@par Third-party framework compatibility
  540. * Compatible with the TensorFlow operator ScatterSub.
  541. */
  542. REG_OP(ScatterSub)
  543. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  544. .INPUT(indices, TensorType::IndexNumberType())
  545. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  546. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  547. .ATTR(use_locking, Bool, false)
  548. .OP_END_FACTORY_REG(ScatterSub)
  549. /**
  550. *@brief: Returns the batched diagonal part of a batched tensor with "assist" . \n
  551. *@par Inputs:
  552. * Two inputs, including:
  553. * @li x: A Tensor of type float16, float32, or int32.
  554. * @li assist: A Tensor of the same type as "x" . \n
  555. *@par Outputs:
  556. *y: A Tensor. Has the same type as "x" . \n
  557. *@par Third-party framework compatibility
  558. * Compatible with the TensorFlow operator DiagPart.
  559. *
  560. * @par Restrictions:
  561. * Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
  562. */
  563. REG_OP(DiagPartD)
  564. .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  565. .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  566. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
  567. .OP_END_FACTORY_REG(DiagPartD)
  568. /**
  569. *@brief: Returns the batched diagonal part of a batched tensor . \n
  570. *@par Inputs:
  571. *x: A Tensor. Must be one of the following types:
  572. * float16, float32, int32, int64, double, complex64, complex128 . \n
  573. *@par Outputs:
  574. *y: A Tensor. Has the same type as "x" . \n
  575. *@par Third-party framework compatibility
  576. * Compatible with the TensorFlow operator DiagPart.
  577. */
  578. REG_OP(DiagPart)
  579. .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
  580. DT_COMPLEX64, DT_COMPLEX128}))
  581. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
  582. DT_COMPLEX64, DT_COMPLEX128}))
  583. .OP_END_FACTORY_REG(DiagPart)
  584. /**
  585. *@brief Also known as a "fully-connected" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n
  586. *@par Inputs:
  587. * Four inputs, including:
  588. *@li x: A Tensor of type float16, int8.
  589. *@li w: A weight matrix of type float16, int8.
  590. *@li b: A Tensor of type float16, int32, float32.
  591. *@li offset_w: A Tensor of type int8 . \n
  592. *@par Attributes:
  593. *@li num_output: Reserved.
  594. *@li transpose: A bool, specifying weight whether to transpose, either "true" or "false". Defaults to "false".
  595. *@li axis: Optional. A int, 1 or 2, specifying which dimension the input "K" starts from. Defaults to 1.
  596. * The product of the subsequent dimensions starting form first dimension or the second dimension is "K".
  597. *@li offset_x: Reserved . \n
  598. *@par Outputs:
  599. *y: The result tensor of type float16, int32, float32 . \n
  600. *@par Third-party framework compatibility
  601. * Compatible with the Caffe operator InnerProduct . \n
  602. *@par Quantization supported or not
  603. * Yes
  604. */
  605. REG_OP(FullyConnection)
  606. .INPUT(x, TensorType({DT_FLOAT16, DT_INT8}))
  607. .INPUT(w, TensorType({DT_FLOAT16, DT_INT8}))
  608. .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32}))
  609. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  610. .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32}))
  611. .REQUIRED_ATTR(num_output, Int)
  612. .ATTR(transpose, Bool, false)
  613. .ATTR(axis, Int, 1)
  614. .ATTR(offset_x, Int, 0)
  615. .OP_END_FACTORY_REG(FullyConnection)
  616. /**
  617. *@brief Also known as a "fully-connected-compress" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n
  618. *@par Inputs:
  619. * Four inputs, including:
  620. *@li x: A Tensor of type uint8, int8.
  621. *@li w: A weight matrix of type int8, int8.
  622. *@li w: A compress index matrix of type int8, int8.
  623. *@li b: A Tensor of type float16, int32, int32.
  624. *@li offset_w: A Tensor of type int8.i
  625. *@par Attributes:
  626. *@li num_output: Reserved.
  627. *@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false".
  628. *@li axis: Reserved.
  629. *@li offset_x: Reserved . \n
  630. *@par Outputs:
  631. *y: The result tensor of type int32 . \n
  632. *@par Third-party framework compatibility
  633. * Compatible with the Caffe operator InnerProduct . \n
  634. *@par Quantization supported or not
  635. * Yes
  636. */
  637. REG_OP(FullyConnectionCompress)
  638. .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
  639. .INPUT(w, TensorType({DT_INT8}))
  640. .INPUT(comress_index, TensorType({DT_INT8}))
  641. .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
  642. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  643. .OUTPUT(y, TensorType({DT_INT32}))
  644. .REQUIRED_ATTR(num_output, Int)
  645. .ATTR(transpose, Bool, false)
  646. .ATTR(axis, Int, 1)
  647. .ATTR(offset_x, Int, 0)
  648. .OP_END_FACTORY_REG(FullyConnectionCompress)
  649. /**
  650. *@brief Computes the confusion matrix from predictions and labels . \n
  651. *@par Inputs:
  652. * Three inputs, including:
  653. *@li labels: A Tensor. Must be one of the following types: float16, float32,
  654. * int32, int8, uint8.
  655. *@li predictions: A Tensor. Must be one of the following types: float16,
  656. * float32, int32, int8, uint8.
  657. *@li weights: A Tensor. Must be one of the following types: float16, float32,
  658. * int32, int8, uint8 . \n
  659. *@par Attributes:
  660. *@li num_classes: An integer for the shape of the output matrix.
  661. * No default value.
  662. *@li dtype: Data type of the confusion matrix. No default value . \n
  663. *@par Outputs:
  664. *y: A Tensor. Has the same type and format as input "labels"
  665. *@attention Constraints:
  666. *@li "weights", "labels", and "predictions" are 1D tensors.
  667. *@li The output is with shape (num_classes, num_classes),
  668. * where, 1 <= num_classes <= 4096 . \n
  669. *@see Region()
  670. *@par Third-party framework compatibility
  671. * Compatible with the TensorFlow operator ConfusionMatrix.
  672. */
  673. REG_OP(ConfusionMatrix)
  674. .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
  675. .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
  676. .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
  677. .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
  678. .REQUIRED_ATTR(num_classes, Int)
  679. .REQUIRED_ATTR(dtype, String)
  680. .OP_END_FACTORY_REG(ConfusionMatrix)
  681. /**
  682. *@brief Multiplies sparse updates into a variable reference . \n
  683. *@par Inputs:
  684. * Three inputs, including:
  685. *@li var: An ND Tensor.
  686. *Must be one of the following types: float16, float, int32, int8, uint8
  687. *@li indices: An ND Tensor.
  688. *Must be one of the following types: int32 or int64
  689. *@li updates: An ND Tensor . \n
  690. *Must be one of the following types: float16, float, int32, int8, uint8
  691. *@par Attributes:
  692. *use_locking: An optional bool. Defaults to "False". If "True", the operation
  693. * will be protected by a lock . \n
  694. *@par Outputs:
  695. *var: A Tensor. Has the same type and format as input "var" . \n
  696. *@par Third-party framework compatibility
  697. * Compatible with the TensorFlow operator ScatterMul.
  698. */
  699. REG_OP(ScatterMul)
  700. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  701. .INPUT(indices, TensorType::IndexNumberType())
  702. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  703. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  704. .ATTR(use_locking, Bool, false)
  705. .OP_END_FACTORY_REG(ScatterMul)
  706. /**
  707. *@brief Reduces sparse updates into a variable reference using
  708. * the "min" operation . \n
  709. *@par Inputs:
  710. * Three inputs, including:
  711. *@li var: An ND Tensor.
  712. *Must be one of the following types: float16, float, int32, int8, uint8
  713. *@li indices: An ND Tensor.
  714. *Must be one of the following types: int32 or int64
  715. *@li updates: An ND Tensor.
  716. *Must be one of the following types: float16, float, int32, int8, uint8
  717. *@par Attributes:
  718. *use_locking: An optional bool. Defaults to "False". If "True", the operation
  719. * will be protected by a lock . \n
  720. *@par Outputs:
  721. *var: A Tensor. Has the same type and format as input "var" . \n
  722. *@par Third-party framework compatibility
  723. * Compatible with the TensorFlow operator ScatterMin.
  724. */
  725. REG_OP(ScatterMin)
  726. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  727. .INPUT(indices, TensorType::IndexNumberType())
  728. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  729. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  730. .ATTR(use_locking, Bool, false)
  731. .OP_END_FACTORY_REG(ScatterMin)
  732. /**
  733. *@brief Reduces sparse updates into a variable reference using the "max" operation . \n
  734. *@par Inputs:
  735. * Three inputs, including:
  736. *@li var: An ND Tensor . \n
  737. *Must be one of the following types: float16, float, int32, int8, uint8
  738. *@li indices: An NCHW, NHWC, or ND Tensor . \n
  739. *Must be one of the following types: int32 or int64
  740. *@li updates: An NCHW, NHWC, or ND Tensor . \n
  741. *Must be one of the following types: float16, float, int32, int8, uint8
  742. *@par Attributes:
  743. *use_locking: An optional bool. Defaults to "False".
  744. * If "True", the operation will be protected by a lock . \n
  745. *@par Outputs:
  746. *var: A Tensor. Has the same type and format as input "var" . \n
  747. *@par Third-party framework compatibility
  748. * Compatible with the TensorFlow operator ScatterMax.
  749. */
  750. REG_OP(ScatterMax)
  751. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  752. .INPUT(indices, TensorType::IndexNumberType())
  753. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  754. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  755. .ATTR(use_locking, Bool, false)
  756. .OP_END_FACTORY_REG(ScatterMax)
  757. /**
  758. *@brief Applies sparse updates to a variable reference . \n
  759. *@par Inputs:
  760. * Three inputs, including:
  761. *@li var: An ND Tensor . \n
  762. *Must be one of the following types: float16, float, int32, int8, uint8
  763. *@li indices: An ND Tensor . \n
  764. *Must be one of the following types: int32 or int64
  765. *@li updates: An ND Tensor . \n
  766. *Must be one of the following types: float16, float, int32, int8, uint8
  767. *@par Attributes:
  768. *use_locking: An optional bool. Defaults to "False". If "True",
  769. * the operation will be protected by a lock . \n
  770. *@par Outputs:
  771. *var: A Tensor. Has the same type and format as input "var" . \n
  772. *@par Third-party framework compatibility
  773. * Compatible with the TensorFlow operator ScatterUpdate.
  774. */
  775. REG_OP(ScatterUpdate)
  776. .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  777. .INPUT(indices, TensorType::IndexNumberType())
  778. .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  779. .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
  780. .ATTR(use_locking, Bool, false)
  781. .OP_END_FACTORY_REG(ScatterUpdate)
  782. /**
  783. *@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input` . \n
  784. *@par Inputs:
  785. * Three inputs, including:
  786. *@li input: Rank `r` tensor where `r >= 2`. \n
  787. *@li k: \n
  788. *Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
  789. *diagonal, and negative value means subdiagonals. `k` can be a single integer \n
  790. *(for a single diagonal) or a pair of integers specifying the low and high ends \n
  791. *of a matrix band. `k[0]` must not be larger than `k[1]`. \n
  792. *@li padding_value: The value to fill the area outside the specified diagonal band with. \n
  793. *@par Outputs:
  794. *diagonal: The extracted diagonal(s) . \n
  795. *@par Third-party framework compatibility
  796. * Compatible with the TensorFlow operator ScatterUpdate.
  797. */
  798. REG_OP(MatrixDiagPartV2)
  799. .INPUT(input, TensorType::BasicType())
  800. .INPUT(k, TensorType({DT_INT32}))
  801. .INPUT(padding_value, TensorType::BasicType())
  802. .OUTPUT(diagonal, TensorType::BasicType())
  803. .OP_END_FACTORY_REG(MatrixDiagPartV2)
  804. /**
  805. *@brief Returns a batched matrix tensor with new batched diagonal values . \n
  806. *@par Inputs:
  807. * Three inputs, including:
  808. *@li input: "Rank `r+1`, where `r >= 1`. \n
  809. *@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
  810. *@li k:
  811. *Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
  812. *diagonal, and negative value means subdiagonals. `k` can be a single integer \n
  813. *(for a single diagonal) or a pair of integers specifying the low and high ends \n
  814. *of a matrix band. `k[0]` must not be larger than `k[1]`. \n
  815. *@par Outputs:
  816. *output: Rank `r+1`, with `output.shape = input.shape` . \n
  817. *@par Third-party framework compatibility
  818. * Compatible with the TensorFlow operator ScatterUpdate.
  819. */
  820. REG_OP(MatrixSetDiagV2)
  821. .INPUT(input, TensorType::BasicType())
  822. .INPUT(diagonal, TensorType::BasicType())
  823. .INPUT(k, TensorType({DT_INT32}))
  824. .OUTPUT(output, TensorType::BasicType())
  825. .OP_END_FACTORY_REG(MatrixSetDiagV2)
  826. /**
  827. *@brief Returns a batched diagonal tensor with given batched diagonal values . \n
  828. *@par Inputs:
  829. * Five inputs, including:
  830. *@li diagonal: Rank `r`, where `r >= 1` \n
  831. *@li k:
  832. *Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
  833. *diagonal, and negative value means subdiagonals. `k` can be a single integer \n
  834. *(for a single diagonal) or a pair of integers specifying the low and high ends \n
  835. *of a matrix band. `k[0]` must not be larger than `k[1]`. \n
  836. *@li num_rows:
  837. *The number of rows of the output matrix. If it is not provided, the op assumes \n
  838. *the output matrix is a square matrix and infers the matrix size from k and the \n
  839. *innermost dimension of `diagonal`. \n
  840. *@li num_cols: An NCHW, NHWC, or ND Tensor.
  841. *The number of columns of the output matrix. If it is not provided, the op \n
  842. *assumes the output matrix is a square matrix and infers the matrix size from \n
  843. *k and the innermost dimension of `diagonal`. \n
  844. *@li padding_value: The number to fill the area outside the specified diagonal band with. \n
  845. *@par Outputs:
  846. *output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise . \n
  847. *@par Third-party framework compatibility
  848. * Compatible with the TensorFlow operator ScatterUpdate.
  849. */
  850. REG_OP(MatrixDiagV2)
  851. .INPUT(diagonal, TensorType::BasicType())
  852. .INPUT(k, TensorType({DT_INT32}))
  853. .INPUT(num_rows, TensorType({DT_INT32}))
  854. .INPUT(num_cols, TensorType({DT_INT32}))
  855. .INPUT(padding_value, TensorType::BasicType())
  856. .OUTPUT(output, TensorType::BasicType())
  857. .OP_END_FACTORY_REG(MatrixDiagV2)
  858. /**
  859. * @brief Add updates to var_out according to axis and indices.
  860. * @par Inputs:
  861. * Three inputs, including:
  862. * @li var: A Tensor. Must be one of the following types:
  863. * float16, float32, int32, int8, uint8.
  864. * @li indices: A Tensor of the indices, type should be int32.
  865. * @li updates: A Tensor of the same type as "var".
  866. * @par Attributes:
  867. * @li axis: An required int to specify the axis to perform indices add.
  868. * @par Outputs:
  869. * @li var_out: A Tensor. Same as input "var".
  870. * @par Third-party framework compatibility
  871. * Compatible with the Pytorch operator index_add.
  872. * @par Restrictions:
  873. * Warning:THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  874. */
  875. REG_OP(IndexAdd)
  876. .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
  877. .INPUT(indices, TensorType({DT_INT32}))
  878. .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
  879. .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
  880. .ATTR(axis, Int, 0)
  881. .OP_END_FACTORY_REG(IndexAdd)
  882. /**
  883. *@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input \n
  884. *@par Inputs:
  885. * Two inputs, including:
  886. *@li x: A Tensor. Must be one of the following types:
  887. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  888. * qint8, quint8, qint32, uint16, complex128, uint32, uint64.
  889. *@li diagonal:(int, optional) – the diagonal to consider。\n
  890. *@par Outputs:
  891. *y: A Tensor. Has the same type as "x" . \n
  892. *@par Third-party framework compatibility
  893. * Compatible with the Pytorch operator Triu.
  894. */
  895. REG_OP(Triu)
  896. .INPUT(x, TensorType::BasicType())
  897. .ATTR(diagonal, Int, 0)
  898. .OUTPUT(y, TensorType::BasicType())
  899. .OP_END_FACTORY_REG(Triu)
  900. /**
  901. *@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input \n
  902. *@par Inputs:
  903. * Two inputs, including:
  904. *@li x: A Tensor. Must be one of the following types:
  905. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  906. * qint8, quint8, qint32, uint16, complex128, uint32, uint64.
  907. *@li diagonal:(int, optional) – the diagonal to consider。\n
  908. *@par Outputs:
  909. *y: A Tensor. Has the same type as "x" . \n
  910. *@par Third-party framework compatibility
  911. * Compatible with the Pytorch operator Tril.
  912. */
  913. REG_OP(Tril)
  914. .INPUT(x, TensorType::BasicType())
  915. .ATTR(diagonal, Int, 0)
  916. .OUTPUT(y, TensorType::BasicType())
  917. .OP_END_FACTORY_REG(Tril)
  918. /**
  919. *@brief Concatenates a list of N tensors along the first dimension.
  920. *@par Inputs:
  921. * Two inputs, including:
  922. * @li values: A list of Tensors. Must be one of the following types: int32, float16, float32.
  923. * Tensors to be concatenated. All must have size 1 in the first dimension and same shape.
  924. * It's a dynamic input.
  925. * @li shape: A Tensor of the same type as "x".
  926. * The final shape of the result. Should be equal to the shapes of any input
  927. * but with the number of input values in the first dimension . \n
  928. *@par Attributes:
  929. *equation: The subscripts for the Einstein summation. \n
  930. *tensor_size: tensor size of input \n
  931. *@par Outputs:
  932. *@li y: Sums the product of the elements of the input operands along dimensions specified
  933. using a notation based on the Einstein summation convention. \n
  934. *@attention Constraints:
  935. *Input tensor_size must be Int. \n
  936. *@par Third-party framework compatibility
  937. *Compatible with Pytorch einsum operator.
  938. */
  939. REG_OP(EinSum)
  940. .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
  941. .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
  942. .REQUIRED_ATTR(equation, String)
  943. .REQUIRED_ATTR(tensor_size, Int)
  944. .OP_END_FACTORY_REG(EinSum)
  945. /**
  946. *@brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n
  947. *@par Inputs:
  948. *No inputs
  949. *@par Attributes:
  950. *@li num_rows: An required int. \n
  951. *@li num_columns: An optional int.Defaults to 0. \n
  952. *@li batch_shape: An optional ListInt.Defaults to []. \n
  953. *@li dtype: An optional int.Defaults to 0. \n
  954. *@par Outputs:
  955. *y: A Tensor with targeted type and shape. \n
  956. *@par Third-party framework compatibility
  957. *Compatible with the Pytorch operator Eye. \n
  958. */
  959. REG_OP(Eye)
  960. .OUTPUT(y, TensorType::BasicType()) /* "Result, has targeted element type" */
  961. .REQUIRED_ATTR(num_rows, Int)
  962. .ATTR(num_columns, Int, 0)
  963. .ATTR(batch_shape, ListInt, {})
  964. .ATTR(dtype, Int, 0)
  965. .OP_END_FACTORY_REG(Eye)
  966. } // namespace ge
  967. #endif // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图可参见项目的官方文档。