/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file matrix_calculation_ops.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

#include "graph/operator_reg.h"

namespace ge {

/**
* @brief Fused computation of a LayerNorm followed by three parallel
* (MatMul -> ConfusionTransposeD) branches:
*              / (MatMul -> ConfusionTransposeD)
* LayerNorm -- (MatMul -> ConfusionTransposeD)
*              \ (MatMul -> ConfusionTransposeD) \n
* @par Inputs:
* Nine inputs, including:
* @li x: A Tensor. Must be one of the following types: float16.
* @li kernel_query: A Tensor. Must be one of the following types: float16.
* @li kernel_key: A Tensor. Must be one of the following types: float16.
* @li kernel_value: A Tensor. Must be one of the following types: float16.
* @li gamma: A Tensor. Must be one of the following types: float16.
* @li beta: A Tensor. Must be one of the following types: float16.
* @li bias_query: A Tensor. Must be one of the following types: float16.
* @li bias_key: A Tensor. Must be one of the following types: float16.
* @li bias_value: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li epsilon: An optional attribute, the type is float32. Defaults to 1e-7.
* @li trans_a: An optional attribute, the type is bool. Defaults to False.
* @li trans_b: An optional attribute, the type is bool. Defaults to False. \n
* @par Outputs:
* Six outputs, including:
* @li norm: A Tensor. Must be one of the following types: float16.
* @li query_output: A Tensor. Must be one of the following types: float16.
* @li key_output: A Tensor. Must be one of the following types: float16.
* @li value_output: A Tensor. Must be one of the following types: float16.
* @li mean: A Tensor. Must be one of the following types: float16.
* @li variance: A Tensor. Must be one of the following types: float16. \n
*/
REG_OP(AttentionLnQKV)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(kernel_query, TensorType({DT_FLOAT16}))
    .INPUT(kernel_key, TensorType({DT_FLOAT16}))
    .INPUT(kernel_value, TensorType({DT_FLOAT16}))
    .INPUT(gamma, TensorType({DT_FLOAT16}))
    .INPUT(beta, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_query, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_key, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_value, TensorType({DT_FLOAT16}))
    .OUTPUT(norm, TensorType({DT_FLOAT16}))
    .OUTPUT(query_output, TensorType({DT_FLOAT16}))
    .OUTPUT(key_output, TensorType({DT_FLOAT16}))
    .OUTPUT(value_output, TensorType({DT_FLOAT16}))
    .OUTPUT(mean, TensorType({DT_FLOAT16}))
    .OUTPUT(variance, TensorType({DT_FLOAT16}))
    .ATTR(epsilon, Float, 0.0000001)
    .ATTR(trans_a, Bool, false)
    .ATTR(trans_b, Bool, false)
    .OP_END_FACTORY_REG(AttentionLnQKV)

/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
*@par Inputs:
*Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC] . \n
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC] . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
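
/* Usage sketch (illustrative only, not part of the registration): building a
 * MatMul node through the set_input_/set_attr_ accessors that REG_OP generates
 * for the GE IR. The upstream operators "data_x1" and "data_x2" are assumed
 * placeholders created elsewhere in the graph.
 *
 *   auto mm = ge::op::MatMul("matmul_node")
 *                 .set_input_x1(data_x1)     // left operand, e.g. [M, K]
 *                 .set_input_x2(data_x2)     // right operand, e.g. [K, N]
 *                 .set_attr_transpose_x1(false)
 *                 .set_attr_transpose_x2(false);
 */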

/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
*@par Inputs:
*Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32, int8. Has format [ND, NHWC].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32, int8. Has format [ND, NHWC].
* @li bias: An optional 1D Tensor. Must be one of the following types: float32,
* float16, int32. Has format [ND, NHWC].
* @li offset_w: An optional 1D Tensor for quantized inference. Type is int8.
* Reserved. \n
*@par Attributes:
* @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
* [M, K].
* @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
* [K, N].
* @li offset_x: An optional integer for quantized MatMulV2.
* The negative offset added to the input x1 for int8 type. Ensure offset_x is
* within the effective range of int8 [-128, 127]. Defaults to "0". \n
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32. Has format [ND, NHWC]. \n
*@attention Constraints:
* If performance is better in format NZ, disable "MatmulTransdataFusionPass"
* in the fusion configuration. \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)

/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
*@par Inputs:
*Five inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
* @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
* @li compress_index: A compress index matrix of type int8.
* @li bias: An optional Tensor. 1D. Must be one of the following types: int32,
* float16.
* @li offset_w: An optional matrix Tensor. 2D. Must be one of the following
* types: int8. \n
*@par Attributes:
*@li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
* [M, K].
*@li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
* [K, N].
*@li offset_x: An optional integer for quantized MatMulV2Compress.
*The negative offset added to the input x1 for int8 type. Ensure offset_x is
* within the effective range of int8 [-128, 127]. Defaults to "0". \n
*@par Outputs:
*y: The result matrix Tensor. 2D. Must be one of the following types: int32,
* float16. \n
*@attention Constraints:
* If performance is better in format NZ, disable "MatmulTransdataFusionPass"
* in the fusion configuration.
*/
REG_OP(MatMulV2Compress)
    .INPUT(x1, TensorType({DT_INT8}))
    .INPUT(x2, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2Compress)

/**
*@brief Performs matrix-to-matrix multiplication, producing y = alpha[0] * a * b + beta[0] * c . \n
*@attention Constraints:
* For better performance, the k-axis must be aligned to 16 (input type
* is float16) or 32 (input type is int8). \n
*@par Inputs:
*Five inputs, including:
*@li a: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format [ND].
*@li b: A matrix Tensor. Must be one of the following types: float16, int8.
* Has format ND.
*@li c: A matrix Tensor. Must be one of the following types: float16, int32,
* float32. Has format ND.
*@li alpha: A 1D Tensor. The shape of alpha is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
*@li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
* types: float16, int32, float32. Has format [ND].
* The formats of a, b and c are restricted: for every supported type combination
* (a of type int8 with c of type int32 or float32, and a of type float16 with c
* of type float16 or float32), a, b and c must all have format ND. \n
*@par Attributes:
*Two attributes, including:
*@li transpose_a: Optional. A bool. If True, changes the shape of "a" from
* [M, K] to [K, M].
*@li transpose_b: Optional. A bool. If True, changes the shape of "b" from
* [K, N] to [N, K] . \n
*@par Outputs:
*y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32. Has format [ND]; the format must match that of "a".
*/
REG_OP(GEMM)
    .INPUT(a, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(GEMM)
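
/* Semantics sketch, following the brief above: with transpose_a = transpose_b
 * = false and shapes a: [M, K], b: [K, N], c: [M, N], alpha: [1], beta: [1],
 * the output is
 *   y[i][j] = alpha[0] * sum_k(a[i][k] * b[k][j]) + beta[0] * c[i][j].
 */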

/**
*@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
*@par Inputs:
*Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC] . \n
*@par Attributes:
*@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
*@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
*@par Outputs:
*y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC]. Has the same rank as "x1" and "x2" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatMul.
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
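
/* Shape sketch, derived from the attribute docs above: with adj_x1 = adj_x2 =
 * false, x1: [B, M, K] and x2: [B, K, N] produce y: [B, M, N]. Setting
 * adj_x1 = true instead expects x1: [B, K, M], and adj_x2 = true expects
 * x2: [B, N, K].
 */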

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
* @par Inputs:
* Four inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC].
* @li bias: An optional matrix Tensor. Must be one of the following types: float16,
* float32, int32. 2D or higher. Has format [ND, NHWC].
* @li offset_w: An optional Tensor for quantized inference. Type is int8. Reserved. \n
* @par Attributes:
* @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
* @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
* @par Outputs:
* y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
* float32, int32. Has format [ND, NHWC]. Has the same rank as "x1" and "x2" . \n
*@attention Constraints:
* If performance is better in format NZ, disable "MatmulTransdataFusionPass"
* in the fusion configuration. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatMul.
*/
REG_OP(BatchMatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(BatchMatMulV2)

/**
*@brief Computes half the L2 norm of a tensor without the sqrt . \n
*@par Inputs:
* x: A Tensor.
* TensorType::FloatingDataType() . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x". \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator L2Loss.
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
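
/* Semantics sketch: y = sum(x * x) / 2, i.e. half the squared L2 norm.
 * For x = [3.0, 4.0], y = (9.0 + 16.0) / 2 = 12.5.
 */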

/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values . \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)

/**
*@brief: Returns a batched diagonal tensor with given batched diagonal values . \n
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x" . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
*@brief: Returns the batched diagonal part of a batched tensor . \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)

/**
*@brief: Returns the batched diagonal part of a batched tensor . \n
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li assist: A Tensor of the same type as "x" . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
*@brief: Returns a batched matrix tensor with new batched diagonal values . \n
*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@li diagonal: A Tensor of the same type as "x" . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)

/**
*@brief: Returns a batched matrix tensor with new batched diagonal values . \n
*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x" . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)

/**
* @brief Function AttentionScore. \n
* @par Inputs:
* Six inputs, including:
* @li query: A matrix Tensor. The type only supports float16.
* @li key: A matrix Tensor. The type only supports float16.
* @li value: A matrix Tensor. The type only supports float16.
* @li padding_mask: A matrix Tensor. The type only supports float16.
* @li scale: A scalar. The type only supports float16.
* @li drop_mask: An optional matrix Tensor. The type only supports int8. \n
* @par Attributes:
* @li keep_prob: A float, the probability of keeping an element. Defaults to "1.0".
* @li query_transpose: A bool. If True, changes the shape of "query" from [K, M] to
* [M, K].
* @li key_transpose: A bool. If True, changes the shape of "key" from [N, K] to
* [K, N].
* @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data" from [K, M] to
* [M, K].
* @li bmm_score_transpose_b: A bool. If True, changes the shape of "value" from [N, K] to
* [K, N].
* @li softmax_axes: A list of ints. The dimensions softmax would be performed on. Defaults
* to "[-1]" . \n
* @par Outputs:
* attention_score: The result matrix Tensor. The type only supports float16.
* softmax_output: The result matrix Tensor. The type only supports float16.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(AttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax_output, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {-1})
    .OP_END_FACTORY_REG(AttentionScore)

/**
*@brief Applies sparse "updates" to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
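
/* Semantics sketch (TensorFlow-style scatter_nd_update): with
 * var = [1, 2, 3, 4, 5, 6, 7, 8], indices = [[1], [4]] and updates = [9, 10],
 * the op writes var[1] = 9 and var[4] = 10, leaving
 * var = [1, 9, 3, 4, 10, 6, 7, 8].
 */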

/**
*@brief Applies sparse addition to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)

/**
*@brief Uses "updates" to update tensor "data" by "indices". \n
*@par Inputs:
* Three inputs, including:
*@li data: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64
*@li updates: A Tensor of the same shape as "indices". Format: NCHW, NHWC. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
*@li axis: An optional attribute. Defaults to 0.
*@par Outputs:
*y: A Tensor. Has the same type and format as input "data" . \n
*@par Third-party framework compatibility
* Compatible with the ONNX operator ScatterElements.
*/
REG_OP(ScatterElements)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(ScatterElements)
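
/* Semantics sketch (ONNX-style ScatterElements with axis = 0): with
 * data = [[0, 0, 0], [0, 0, 0]], indices = [[1, 0]] and updates = [[5, 6]],
 * element updates[0][j] is written to y[indices[0][j]][j], giving
 * y = [[0, 6, 0], [5, 0, 0]].
 */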

/**
*@brief Adds sparse "updates" to a variable reference . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)

/**
*@brief Adds sparse "updates" to a variable reference . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Attributes:
* axis: A required int. The axis along which to index. \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the PyTorch operator ScatterAdd.
*/
REG_OP(ScatterAddWithAxis)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(axis, Int)
    .OP_END_FACTORY_REG(ScatterAddWithAxis)

/**
*@brief Divides a variable reference by sparse updates . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)

/**
*@brief Applies sparse addition to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

/**
*@brief Applies sparse addition to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)

/**
*@brief Applies sparse subtraction to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)

/**
*@brief Applies sparse subtraction to individual values or slices in a Variable . \n
*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)

/**
*@brief Subtracts sparse updates from a variable reference . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)

/**
*@brief: Returns the batched diagonal part of a batched tensor with "assist" . \n
*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x" . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
*@brief: Returns the batched diagonal part of a batched tensor . \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128 . \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)

/**
*@brief Also known as a "fully-connected" layer, computes an inner product with a set of learned weights, and (optionally) adds biases . \n
*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16, int8.
*@li w: A weight matrix of type float16, int8.
*@li b: An optional Tensor of type float16, int32, float32.
*@li offset_w: An optional Tensor of type int8. Reserved. Only None is supported. \n
*@par Attributes:
*@li num_output: Required. An int, the number of output neurons. Reserved.
*@li transpose: A bool, specifying whether to transpose the input w, either "true" or "false". Defaults to "false".
*@li axis: Optional. An int, 1 or 2, specifying which dimension the input "K" starts from. Defaults to 1.
* The product of the dimensions starting from the first or the second dimension is "K".
*@li offset_x: An optional integer for quantized FullyConnection.
*The negative offset added to the input image for int8 type. Ensure offset_x is within the
*effective range of int8 [-128, 127]. Defaults to "0". \n
*@par Outputs:
*y: The result tensor of type float16, int32, float32 . \n
*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct . \n
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT32, DT_BF16}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
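
/* Shape sketch, following the "axis" attribute above: with x: [N, C, H, W] and
 * axis = 1, the inner product is taken over K = C * H * W elements per sample,
 * producing num_output values for each of the N samples; with axis = 2,
 * K = H * W instead.
 */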

/**
*@brief Also known as a "fully-connected-compress" layer, computes an inner
* product with a set of learned weights, and (optionally) adds biases . \n
*@par Inputs:
* Five inputs, including:
*@li x: A Tensor of type uint8, int8.
*@li w: A weight matrix of type int8.
*@li compress_index: A compress index matrix of type int8.
*@li b: An optional Tensor of type int32.
*@li offset_w: An optional Tensor of type int8.
*@par Attributes:
*@li num_output: An int, specifying the number of outputs.
*@li transpose: A bool, specifying whether to transpose the input w, either "true"
* or "false". Defaults to "false".
*@li axis: Optional. An int, 1 or 2, specifying which dimension the input "K"
* starts from. Defaults to "1".
* The product of the dimensions starting from the first or the second dimension is "K".
*@li offset_x: An optional integer for quantized FullyConnectionCompress.
*The negative offset added to the input image for int8 type. Ensure offset_x
* is within the effective range of int8 [-128, 127]. Defaults to "0". \n
*@par Outputs:
*y: The result tensor of type int32. \n
*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct. \n
*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)

/**
*@brief Computes the confusion matrix from predictions and labels . \n
*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8 . \n
*@par Attributes:
*@li num_classes: An integer for the shape of the output matrix.
* No default value.
*@li dtype: Data type of the confusion matrix. No default value . \n
*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels".
*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output has shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096 . \n
*@see Region()
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
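
/* Semantics sketch (TensorFlow-style confusion matrix): with num_classes = 5,
 * labels = [1, 2, 4] and predictions = [2, 2, 4], the 5 x 5 output y has
 * y[1][2] = 1, y[2][2] = 1 and y[4][4] = 1, with zeros elsewhere: rows index
 * labels, columns index predictions.
 */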

/**
*@brief Multiplies sparse updates into a variable reference . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. \n
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)

/**
*@brief Reduces sparse updates into a variable reference using
* the "min" operation . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
*@brief Reduces sparse updates into a variable reference using the "max" operation . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor. \n
*Must be one of the following types: int32 or int64
*@li updates: An NCHW, NHWC, or ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)

/**
*@brief Applies sparse updates to a variable reference . \n
*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)

/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input` . \n
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*diagonal: The extracted diagonal(s) . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)

/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1` tensor, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape` . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n
*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1`, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@par Attributes:
*@li align: An optional string specifying how superdiagonals and subdiagonals \n
*should be aligned, respectively. Defaults to "RIGHT_LEFT". The other options \n
*are "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". \n
*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape` . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV3.
*/
REG_OP(MatrixSetDiagV3)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixSetDiagV3)
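
/*
 * Illustrative note, not part of the original header: "align" only matters
 * when k selects a band, because diagonals of different lengths are packed
 * into one fixed-width "diagonal" tensor, ordered from k[1] down to k[0].
 * For a 3x3 input and k = (-1, 1), the packed [3, 3] "diagonal" under the
 * default "RIGHT_LEFT" alignment is:
 *
 *   [[pad, s0, s1],   // k =  1: superdiagonal, right-aligned
 *    [m0,  m1, m2],   // k =  0: main diagonal
 *    [b0,  b1, pad]]  // k = -1: subdiagonal, left-aligned
 */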
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values . \n
*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols:
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n
*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
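
/*
 * Illustrative note, not part of the original header: a sketch of building an
 * [rows, cols] matrix whose k-th diagonal holds "diag" and whose remaining
 * elements are padding_value (names hypothetical; requires <algorithm> and
 * <vector>):
 *
 *   std::vector<float> MatrixDiagRef(const float *diag, int rows, int cols,
 *                                    int k, float padding_value) {
 *     std::vector<float> m(rows * cols, padding_value);
 *     int len = std::min(rows + std::min(k, 0), cols - std::max(k, 0));
 *     for (int i = 0; i < len; ++i) {
 *       int r = (k >= 0) ? i : i - k;
 *       int c = (k >= 0) ? i + k : i;
 *       m[r * cols + c] = diag[i];
 *     }
 *     return m;
 *   }
 */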
/**
* @brief Adds the slices of "updates" to "var" along "axis" at the positions
* given by "indices", producing "var_out".
* @par Inputs:
* Three inputs, including:
* @li var: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8.
* @li indices: A Tensor of the indices. The type should be int32.
* @li updates: A Tensor of the same type as "var".
* @par Attributes:
* @li axis: An optional int specifying the axis along which the indexed add
* is performed. Defaults to 0.
* @par Outputs:
* @li var_out: A Tensor. Same as input "var".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_add.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexAdd)
    .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(IndexAdd)
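
/*
 * Illustrative note, not part of the original header: a sketch of the
 * axis = 0 case for a 2-D "var"; unlike ScatterUpdate, the indexed rows are
 * accumulated rather than overwritten (names hypothetical; requires <cstdint>):
 *
 *   // var_out starts as a copy of var; indices: [n], updates: [n, row_len]
 *   void IndexAddRef(float *var_out, const int32_t *indices,
 *                    const float *updates, int64_t n, int64_t row_len) {
 *     for (int64_t i = 0; i < n; ++i)
 *       for (int64_t j = 0; j < row_len; ++j)
 *         var_out[indices[i] * row_len + j] += updates[i * row_len + j];
 *   }
 */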
/**
* @brief Replaces the values of "x1" at the positions given by "indices"
* with the values in "x2".
* @par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @li x2: A Tensor of the same type as "x1", holding the replacement values.
* @par Attributes:
* @li indices: A required list of ints, the indices at which "x1" is rewritten.
* @li accumulate: An optional int. If 1, the values in "x2" are added to the
* indexed elements instead of overwriting them. Defaults to 0.
* @par Outputs:
* @li y: A Tensor. Same as input "x1".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_put.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexPut)
    .INPUT(x1, TensorType::BasicType())
    .INPUT(x2, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .REQUIRED_ATTR(indices, ListInt)
    .ATTR(accumulate, Int, 0)
    .OP_END_FACTORY_REG(IndexPut)
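
/*
 * Illustrative note, not part of the original header: a sketch assuming
 * "indices" addresses a flattened view of "x1" (this flat-index reading and
 * all names are assumptions, not confirmed by this header):
 *
 *   // y starts as a copy of x1; requires <cstdint>
 *   void IndexPutRef(float *y, const float *x2, const int64_t *indices,
 *                    int64_t n, int64_t accumulate) {
 *     for (int64_t i = 0; i < n; ++i)
 *       y[indices[i]] = accumulate ? y[indices[i]] + x2[i] : x2[i];
 *   }
 */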
/**
*@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices "x" \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
*@par Attributes:
*diagonal: An optional int indicating the diagonal to consider. Defaults to 0. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the Pytorch operator Triu.
*/
REG_OP(Triu)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Triu)
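
/*
 * Illustrative note, not part of the original header: for a 2-D [rows, cols]
 * input, Triu keeps the elements on and above the "diagonal"-th diagonal and
 * zeroes the rest:
 *
 *   y[r * cols + c] = (c - r >= diagonal) ? x[r * cols + c] : 0;
 */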
/**
*@brief: Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices "x" \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
*@par Attributes:
*diagonal: An optional int indicating the diagonal to consider. Defaults to 0. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the Pytorch operator Tril.
*/
REG_OP(Tril)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Tril)
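
/*
 * Illustrative note, not part of the original header: Tril keeps the
 * complementary region of Triu, i.e. the elements on and below the
 * "diagonal"-th diagonal:
 *
 *   y[r * cols + c] = (c - r <= diagonal) ? x[r * cols + c] : 0;
 */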
/**
*@brief Computes a tensor contraction of the input operands according to the
* Einstein summation convention. \n
*@par Inputs:
*x: A list of N tensors, the operands to contract. Must be one of the
* following types: float16, float32, int32. It's a dynamic input. \n
*@par Attributes:
*@li equation: The subscripts for the Einstein summation, such as "ij,jk->ik". \n
*@li N: The number of tensors in the input "x". \n
*@par Outputs:
*y: Sums the product of the elements of the input operands along dimensions specified
* using a notation based on the Einstein summation convention. \n
*@attention Constraints:
*Input N must be Int. \n
*@par Third-party framework compatibility
*Compatible with Pytorch einsum operator.
*/
REG_OP(Einsum)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .REQUIRED_ATTR(equation, String)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(Einsum)
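
/*
 * Illustrative note, not part of the original header: in the equation string,
 * indices repeated across operands are summed and the indices after "->"
 * survive, so equation = "ij,jk->ik" is an ordinary matrix product. A sketch
 * (names hypothetical):
 *
 *   // a: [I, J], b: [J, K], y: [I, K]
 *   void EinsumMatmulRef(float *y, const float *a, const float *b,
 *                        int I, int J, int K) {
 *     for (int i = 0; i < I; ++i)
 *       for (int k = 0; k < K; ++k) {
 *         float s = 0.0f;
 *         for (int j = 0; j < J; ++j)
 *           s += a[i * J + j] * b[j * K + k];
 *         y[i * K + k] = s;
 *       }
 *   }
 */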
/**
*@brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n
*@par Inputs:
*No inputs
*@par Attributes:
*@li num_rows: A required int. \n
*@li num_columns: An optional int. Defaults to 0. \n
*@li batch_shape: An optional ListInt. Defaults to []. \n
*@li dtype: An optional int specifying the output data type. Defaults to 0. \n
*@par Outputs:
*y: A Tensor with targeted type and shape. \n
*@par Third-party framework compatibility
*Compatible with the Pytorch operator Eye. \n
*/
REG_OP(Eye)
    .OUTPUT(y, TensorType::BasicType())    /* "Result, has targeted element type" */
    .REQUIRED_ATTR(num_rows, Int)
    .ATTR(num_columns, Int, 0)
    .ATTR(batch_shape, ListInt, {})
    .ATTR(dtype, Int, 0)
    .OP_END_FACTORY_REG(Eye)
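
/*
 * Illustrative note, not part of the original header: with num_rows = 3 and
 * num_columns = 4, the output is the non-square identity pattern
 *
 *   [[1, 0, 0, 0],
 *    [0, 1, 0, 0],
 *    [0, 0, 1, 0]]
 *
 * and a non-empty batch_shape prepends batch dimensions, each batch entry
 * holding the same pattern.
 */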
/**
*@brief: Fills the main diagonal of a tensor that has at least 2 dimensions with a value . \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float32, int32, int64 . \n
*@par Attributes:
*fill_value: The value to fill the diagonal with.
*wrap: An optional bool. Defaults to "False". If "True", the diagonal wraps
* around and restarts after a spacer row for tall matrices. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the Pytorch operator FillDiagonal.
*/
REG_OP(FillDiagonal)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .REQUIRED_ATTR(fill_value, Float)
    .ATTR(wrap, Bool, false)
    .OP_END_FACTORY_REG(FillDiagonal)
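
/*
 * Illustrative note, not part of the original header: a sketch for a 2-D
 * [rows, cols] tensor. With wrap = false a tall matrix fills only its first
 * min(rows, cols) diagonal entries, while wrap = true keeps filling every
 * (cols + 1)-th element of the flattened tensor, restarting the diagonal
 * after a spacer row as in Pytorch (names hypothetical; requires <algorithm>):
 *
 *   void FillDiagonalRef(float *x, int rows, int cols, float v, bool wrap) {
 *     int end = wrap ? rows * cols : std::min(rows, cols) * cols;
 *     for (int p = 0; p < end; p += cols + 1)
 *       x[p] = v;
 *   }
 */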
/**
*@brief: Returns the sum of the elements of the diagonal of the input 2-D matrix. \n
*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float. \n
*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n
*@par Third-party framework compatibility
* Compatible with the Pytorch operator Trace.
*/
REG_OP(Trace)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(Trace)
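
/*
 * Illustrative note, not part of the original header: a sketch for an
 * [rows, cols] input (name hypothetical; requires <algorithm>):
 *
 *   float TraceRef(const float *x, int rows, int cols) {
 *     float s = 0.0f;
 *     for (int i = 0; i < std::min(rows, cols); ++i)
 *       s += x[i * cols + i];
 *     return s;
 *   }
 */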
/**
*@brief Computes the generalized inverse (Moore-Penrose pseudoinverse) of a matrix. \n
*@par Inputs:
* @li x: The input matrix. Must be one of the following types:
* double, float. \n
*@par Attributes:
* @li rcond: An optional float >= 0, the cutoff for small singular values. Defaults to 1e-15. \n
*@par Outputs:
* y: A Tensor with the same type as "x" and the shape of x's transpose. \n
*/
REG_OP(Pinverse)
    .INPUT(x, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .ATTR(rcond, Float, 1e-15)
    .OP_END_FACTORY_REG(Pinverse)
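
/*
 * Illustrative note, not part of the original header: the Moore-Penrose
 * pseudoinverse is conventionally obtained from the SVD x = U diag(s) V^T as
 *
 *   pinv(x) = V diag(s_i > rcond * max(s) ? 1 / s_i : 0) U^T,
 *
 * where "rcond" is the cutoff below which singular values are treated as
 * zero; this is also why the result has the shape of x's transpose.
 */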
}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
