You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

matrix_calculation_ops.h 64 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
3 years ago
3 years ago
5 years ago
5 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
3 years ago
5 years ago
3 years ago
5 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. /*!
  17. * \file matrix_calculation_ops.h
  18. * \brief
  19. */
  20. #ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
  21. #define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
  22. #include "graph/operator_reg.h"
  23. namespace ge {
  24. /**
  25. * @brief Backprop W of AttentionLnQKV + ReduceSumD \n
  26. * @par Inputs:
  27. * Four inputs, including:
  28. * @li x: A Tensor. Must be one of the following types: float16.
  29. * @li query_dx: A Tensor. Must be one of the following types: float16.
  30. * @li key_dw: A Tensor. Must be one of the following types: float16.
  31. * @li value_dw: A Tensor. Must be one of the following types: float16.
  32. * @par Attributes:
  33. * @li trans_a: A optional attribute, the type is bool. Defaults to True.
  34. * @li trans_b: A optional attribute, the type is bool. Defaults to False. \n
  35. * @par Outputs:
  36. * Six outputs, including:
  37. * @li dw_query: A Tensor. Must be one of the following types: float16.
  38. * @li dw_key: A Tensor. Must be one of the following types: float16.
  39. * @li dw_value: A Tensor. Must be one of the following types: float16.
  40. * @li dbias_query: A Tensor. Must be one of the following types: float16.
  41. * @li dbias_key: A Tensor. Must be one of the following types: float16.
  42. * @li dbias_value: A Tensor. Must be one of the following types: float16. \n
  43. * @par Restrictions:
  44. * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
  45. */
  46. REG_OP(AttentionQKVGradW)
  47. .INPUT(x, TensorType({DT_FLOAT16}))
  48. .INPUT(query_dx, TensorType({DT_FLOAT16}))
  49. .INPUT(key_dw, TensorType({DT_FLOAT16}))
  50. .INPUT(value_dw, TensorType({DT_FLOAT16}))
  51. .OUTPUT(dw_query, TensorType({DT_FLOAT16}))
  52. .OUTPUT(dw_key, TensorType({DT_FLOAT16}))
  53. .OUTPUT(dw_value, TensorType({DT_FLOAT16}))
  54. .OUTPUT(dbias_query, TensorType({DT_FLOAT16}))
  55. .OUTPUT(dbias_key, TensorType({DT_FLOAT16}))
  56. .OUTPUT(dbias_value, TensorType({DT_FLOAT16}))
  57. .ATTR(trans_a, Bool, true)
  58. .ATTR(trans_b, Bool, false)
  59. .OP_END_FACTORY_REG(AttentionQKVGradW)
  60. /**
  61. * @brief Backprop X of AttentionLnQKV + AddN \n
  62. * @par Inputs:
  63. * Seven inputs, including:
  64. * @li ln_dx: A Tensor. Must be one of the following types: float16.
  65. * @li query_dx: A Tensor. Must be one of the following types: float16.
  66. * @li key_dw: A Tensor. Must be one of the following types: float16.
  67. * @li value_dw: A Tensor. Must be one of the following types: float16.
  68. * @li kernel_query: A Tensor. Must be one of the following types: float16.
  69. * @li kernel_key: A Tensor. Must be one of the following types: float16.
  70. * @li kernel_value: A Tensor. Must be one of the following types: float16. \n
  71. * @par Attributes:
  72. * @li trans_a: A optional attribute, the type is bool. Defaults to False.
  73. * @li trans_b: A optional attribute, the type is bool. Defaults to True. \n
  74. * @par Outputs:
  75. * One outputs, including:
  76. * @li dx: A Tensor. Must be one of the following types: float16. \n
  77. * @par Restrictions:
  78. * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
  79. */
  80. REG_OP(AttentionQKVGradX)
  81. .INPUT(ln_dx, TensorType({DT_FLOAT16}))
  82. .INPUT(query_dx, TensorType({DT_FLOAT16}))
  83. .INPUT(key_dw, TensorType({DT_FLOAT16}))
  84. .INPUT(value_dw, TensorType({DT_FLOAT16}))
  85. .INPUT(kernel_query, TensorType({DT_FLOAT16}))
  86. .INPUT(kernel_key, TensorType({DT_FLOAT16}))
  87. .INPUT(kernel_value, TensorType({DT_FLOAT16}))
  88. .OUTPUT(dx, TensorType({DT_FLOAT16}))
  89. .ATTR(trans_a, Bool, false)
  90. .ATTR(trans_b, Bool, true)
  91. .OP_END_FACTORY_REG(AttentionQKVGradX)
  92. /**
  93. * @brief
  94. / (MatMul -> ConfusionTransposeD).
  95. LayerNorm - (MatMul -> ConfusionTransposeD).
  96. \ (MatMul -> ConfusionTransposeD). \n
  97. * @par Inputs:
  98. * Nine inputs, including:
  99. * @li x: A Tensor. Must be one of the following types: float16.
  100. * @li kernel_query: A Tensor. Must be one of the following types: float16.
  101. * @li kernel_key: A Tensor. Must be one of the following types: float16.
  102. * @li kernel_value: A Tensor. Must be one of the following types: float16.
  103. * @li gamma: A Tensor. Must be one of the following types: float16.
  104. * @li beta: A Tensor. Must be one of the following types: float16.
  105. * @li bias_query: A Tensor. Must be one of the following types: float16.
  106. * @li bias_key: A Tensor. Must be one of the following types: float16.
  107. * @li bias_value: A Tensor. Must be one of the following types: float16. \n
  108. * @par Attributes:
  109. * @li epsilon: A optional attribute, the type is float32. Defaults to 1e-7.
  110. * @li trans_a: A optional attribute, the type is bool. Defaults to False.
  111. * @li trans_b: A optional attribute, the type is bool. Defaults to False. \n
  112. * @par Outputs:
  113. * Six outputs, including:
  114. * @li norm: A Tensor. Must be one of the following types: float16.
  115. * @li query_output: A Tensor. Must be one of the following types: float16.
  116. * @li key_output: A Tensor. Must be one of the following types: float16.
  117. * @li value_output: A Tensor. Must be one of the following types: float16.
  118. * @li mean: A Tensor. Must be one of the following types: float16.
  119. * @li variance: A Tensor. Must be one of the following types: float16. \n
  120. * @par Restrictions:
  121. * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
  122. */
  123. REG_OP(AttentionLnQKV)
  124. .INPUT(x, TensorType({DT_FLOAT16}))
  125. .INPUT(kernel_query, TensorType({DT_FLOAT16}))
  126. .INPUT(kernel_key, TensorType({DT_FLOAT16}))
  127. .INPUT(kernel_value, TensorType({DT_FLOAT16}))
  128. .INPUT(gamma, TensorType({DT_FLOAT16}))
  129. .INPUT(beta, TensorType({DT_FLOAT16}))
  130. .OPTIONAL_INPUT(bias_query, TensorType({DT_FLOAT16}))
  131. .OPTIONAL_INPUT(bias_key, TensorType({DT_FLOAT16}))
  132. .OPTIONAL_INPUT(bias_value, TensorType({DT_FLOAT16}))
  133. .OUTPUT(norm, TensorType({DT_FLOAT16}))
  134. .OUTPUT(query_output, TensorType({DT_FLOAT16}))
  135. .OUTPUT(key_output, TensorType({DT_FLOAT16}))
  136. .OUTPUT(value_output, TensorType({DT_FLOAT16}))
  137. .OUTPUT(mean, TensorType({DT_FLOAT16}))
  138. .OUTPUT(variance, TensorType({DT_FLOAT16}))
  139. .ATTR(epsilon, Float, 0.0000001)
  140. .ATTR(trans_a, Bool, false)
  141. .ATTR(trans_b, Bool, false)
  142. .OP_END_FACTORY_REG(AttentionLnQKV)
  143. /**
  144. * @brief
  145. swin_transformer model specific structure.Operator only supports swin_transformer. \n
  146. * @par Inputs:
  147. * Five inputs, including:
  148. * @li x: A Tensor. Must be one of the following types: float16.
  149. * @li gamma: A Tensor. Must be one of the following types: float16.
  150. * @li beta: A Tensor. Must be one of the following types: float16.
  151. * @li weight: A Tensor. Must be one of the following types: float16.
  152. * @li bias: A Tensor. Must be one of the following types: float16. \n
  153. * @par Attributes:
  154. * @li head_num: A optional attribute, the type is int.
  155. * @li head_dim: A optional attribute, the type is int.
  156. * @li seq_length: A optional attribute, the type is int.
  157. * @li shifts: A optional attribute, the type is list int. Defaults to ().
  158. * @li epsilon: A optional attribute, the type is float. Defaults to 1e-7. \n
  159. * @par Outputs:
  160. * Three outputs, including:
  161. * @li query_output: A Tensor. Must be one of the following types: float16.
  162. * @li key_output: A Tensor. Must be one of the following types: float16.
  163. * @li value_output: A Tensor. Must be one of the following types: float16. \n
  164. * @par Restrictions:
  165. * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
  166. */
  167. REG_OP(SwinTransformerLnQKV)
  168. .INPUT(x, TensorType({DT_FLOAT16}))
  169. .INPUT(gamma, TensorType({DT_FLOAT16}))
  170. .INPUT(beta, TensorType({DT_FLOAT16}))
  171. .INPUT(weight, TensorType({DT_FLOAT16}))
  172. .INPUT(bias, TensorType({DT_FLOAT16}))
  173. .OUTPUT(query_output, TensorType({DT_FLOAT16}))
  174. .OUTPUT(key_output, TensorType({DT_FLOAT16}))
  175. .OUTPUT(value_output, TensorType({DT_FLOAT16}))
  176. .REQUIRED_ATTR(head_num, Int)
  177. .REQUIRED_ATTR(head_dim, Int)
  178. .REQUIRED_ATTR(seq_length, Int)
  179. .ATTR(shifts, ListInt, {})
  180. .ATTR(epsilon, Float, 0.0000001)
  181. .OP_END_FACTORY_REG(SwinTransformerLnQKV)
  182. /**
  183. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  184. *@par Inputs:
  185. *Three inputs, including:
  186. * @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
  187. * float32, int32. Has format [ND, NHWC].
  188. * @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
  189. * float32, int32. Has format [ND, NHWC].
  190. * @li bias: A optional 1D Tensor. Must be one of the following types: float16,
  191. * float32, int32. Has format [ND, NHWC] . \n
  192. *@par Attributes:
  193. *@li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to [K, M].
  194. *@li transpose_x2: A bool. If True, changes the shape of "x2" from [M, K] to [K, M] . \n
  195. *@par Outputs:
  196. *y: The result matrix Tensor. 2D. Must be one of the following types: float16,
  197. * float32, int32. Has format [ND, NHWC] . \n
  198. *@par Third-party framework compatibility
  199. * Compatible with the TensorFlow operator BatchMatmul.
  200. */
  201. REG_OP(MatMul)
  202. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  203. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  204. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  205. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  206. .ATTR(transpose_x1, Bool, false)
  207. .ATTR(transpose_x2, Bool, false)
  208. .OP_END_FACTORY_REG(MatMul)
  209. /**
  210. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  211. *@par Inputs:
  212. *Four inputs, including:
  213. * @li x1: A matrix Tensor. 2D. Must be one of the following types: float32,
  214. float16, int32, int8. Has format [ND, NHWC].
  215. * @li x2: A matrix Tensor. 2D. Must be one of the following types: float32,
  216. float16, int32, int8. Has format [ND, NHWC].
  217. * @li bias: A 1D Tensor. Must be one of the following types: float32,
  218. float16, int32. Has format [ND, NHWC].
  219. * @li offset_w: A Optional 1D Tensor for quantized inference. Type is int8.
  220. Reserved. \n
  221. *@par Attributes:
  222. * @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
  223. [M, K].
  224. * @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
  225. [K, N].
  226. * @li offset_x: An optional integer for quantized MatMulV2.
  227. * The negative offset added to the input x1 for int8 type. Ensure offset_x
  228. within the effective range of int8 [-128, 127]. Defaults to "0". \n
  229. *@par Outputs:
  230. *y: The result matrix Tensor. 2D. Must be one of the following types: float32,
  231. float16, int32. Has format [ND, NHWC]. \n
  232. *@attention Constraints:
  233. * if performances better in format NZ, please close
  234. "MatmulTransdataFusionPass" in fusion configuration. \n
  235. *@par Third-party framework compatibility
  236. * Compatible with the TensorFlow operator BatchMatmul.
  237. */
  238. REG_OP(MatMulV2)
  239. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
  240. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
  241. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  242. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  243. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
  244. .ATTR(transpose_x1, Bool, false)
  245. .ATTR(transpose_x2, Bool, false)
  246. .ATTR(offset_x, Int, 0)
  247. .OP_END_FACTORY_REG(MatMulV2)
  248. /**
  249. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  250. *@par Inputs:
  251. *Five inputs, including:
  252. * @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
  253. * @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
  254. * @li compress_index: A compress index matrix of type int8.
  255. * @li bias: An optional Tensor. 1D. Must be one of the following types: int32,
  256. float16.
  257. * @li offset_w: An optional matrix Tensor. 2D. Must be one of the following
  258. types: int8. \n
  259. *@par Attributes:
  260. *@li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
  261. [M, K].
  262. *@li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
  263. [K, N].
  264. *@li offset_x: An optional integer for quantized MatMulV2Compress.
  265. *The negative offset added to the input x1 for int8 type. Ensure offset_x
  266. within the effective range of int8 [-128, 127]. Defaults to "0". \n
  267. *@par Outputs:
  268. *y: The result matrix Tensor. 2D. Must be one of the following types: int32,
  269. * float16. \n
  270. *@attention Constraints:
  271. * if performances better in format NZ, please close
  272. "MatmulTransdataFusionPass" in fusion configuration.
  273. */
  274. REG_OP(MatMulV2Compress)
  275. .INPUT(x1, TensorType({DT_INT8}))
  276. .INPUT(x2, TensorType({DT_INT8}))
  277. .INPUT(compress_index, TensorType({DT_INT8}))
  278. .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
  279. .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
  280. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
  281. .ATTR(transpose_x1, Bool, false)
  282. .ATTR(transpose_x2, Bool, false)
  283. .ATTR(offset_x, Int, 0)
  284. .OP_END_FACTORY_REG(MatMulV2Compress)
  285. /**
  286. *@brief Performs Matrix-to-matrix Multiply, producing y=alpha[0]*a*b+beta[0]*c . \n
  287. *@attention Constraints:
  288. * For better performance, The k-axis must be aligned to 16 (input type
  289. * is float16) or 32 (input type is int8). \n
  290. *@par Inputs:
  291. *Five inputs, including:
  292. *@li a: A matrix Tensor. Must be one of the following types: float16, int8.
  293. * Has format [ND].
  294. *@li b: A matrix Tensor. Must be one of the following types: float16, int8.
  295. * Has format ND.
  296. *@li c: A matrix Tensor. Must be one of the following types: float16, int32,
  297. * float32. has format ND.
  298. *@li alpha: A 1D Tensor. The shape of alpha is [1].Must be one of the following
  299. * types: float16, int32, float32. Has format [ND].
  300. *@li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
  301. * types: float16, int32, float32. Has format [ND].
  302. * The format of a, b, c has restriction:\n
  303. * When type of a is int8 and type of c is int32, the format of a, b, c should
  304. * all be ND.\n
  305. * When type of a is int8 and type of c is float32, the format of a, b, c should
  306. * all be ND.\n
  307. * When type of a is float16 and type of c is float16, the format of a, b, c
  308. * should all be ND.\n
  309. * When type of a is float16 and type of c is float32, the format of a, b, c
  310. * should all be ND. \n
  311. *@par Attributes:
  312. *Two attributes, including:
  313. *@li transpose_a: Optional. A bool. If True, changes the shape of "a" from
  314. * [M, K] to [K, M].
  315. *@li transpose_b: Optional. A bool. If True, changes the shape of "b" from
  316. * [K, N] to [N, K] . \n
  317. *@par Outputs:
  318. *y: The result matrix Tensor. Must be one of the following types: float16,
  319. * float32, int32. Has format [ND], the format should be equal to a.
  320. */
  321. REG_OP(GEMM)
  322. .INPUT(a, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  323. .INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  324. .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  325. .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  326. .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  327. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
  328. .ATTR(transpose_a, Bool, false)
  329. .ATTR(transpose_b, Bool, false)
  330. .OP_END_FACTORY_REG(GEMM)
  331. /**
  332. *@brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  333. *@par Inputs:
  334. *Two inputs, including:
  335. * @li x1: A matrix Tensor. Must be one of the following types: float16,
  336. * float32, int32. 2D or higher. Has format [ND, NHWC].
  337. * @li x2: A matrix Tensor. Must be one of the following types: float16,
  338. * float32, int32. 2D or higher. Has format [ND, NHWC] . \n
  339. *@par Attributes:
  340. *@li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
  341. *@li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
  342. *@par Outputs:
  343. *y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
  344. * float32, int32. 2D or higher. Has format [ND, NHWC]. Has the same shape length as "x1" and "x2" . \n
  345. *@par Third-party framework compatibility
  346. * Compatible with the TensorFlow operator BatchMatmul.
  347. */
  348. REG_OP(BatchMatMul)
  349. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  350. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  351. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  352. .ATTR(adj_x1, Bool, false)
  353. .ATTR(adj_x2, Bool, false)
  354. .OP_END_FACTORY_REG(BatchMatMul)
  355. /**
  356. * @brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
  357. * @par Inputs:
  358. * Three inputs, including:
  359. * @li x1: A matrix Tensor. Must be one of the following types: float16,
  360. * float32, int32. 2D or higher. Has format [ND, NHWC].
  361. * @li x2: A matrix Tensor. Must be one of the following types: float16,
  362. * float32, int32. 2D or higher. Has format [ND, NHWC] . \n
  363. * @li bias: A matrix Tensor. Must be one of the following types: float16,
  364. * float32, int32. 2D or higher. Has format [ND, NHWC] . \n
  365. * @par Attributes:
  366. * @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to [B, K, M].
  367. * @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to [B, K, M] . \n
  368. * @par Outputs:
  369. * y: The result matrix Tensor. 2D or higher. Must be one of the following types: float16,
  370. * float32, int32. 2D or higher. Has format [ND, NHWC]. Has the same shape length as "x1" and "x2" . \n
  371. *@attention Constraints:
  372. * if performances better in format NZ, please close
  373. "MatmulTransdataFusionPass" in fusion configuration. \n
  374. * @par Third-party framework compatibility
  375. * Compatible with the TensorFlow operator BatchMatmul.
  376. */
  377. REG_OP(BatchMatMulV2)
  378. .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
  379. .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
  380. .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  381. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
  382. .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
  383. .ATTR(adj_x1, Bool, false)
  384. .ATTR(adj_x2, Bool, false)
  385. .ATTR(offset_x, Int, 0)
  386. .OP_END_FACTORY_REG(BatchMatMulV2)
  387. /**
  388. *@brief Computes half the L2 norm of a tensor without the sqrt . \n
  389. *@par Inputs:
  390. * x: A Tensor.
  391. * TensorType::FloatingDataType() . \n
  392. *@par Outputs:
  393. *y: A Tensor. Has the same type as "x". \n
  394. *@attention Constraints:
  395. * if performances better in format NZ, please close
  396. "MatmulTransdataFusionPass" in fusion configuration. \n
  397. *@par Third-party framework compatibility
  398. *Compatible with the TensorFlow operator L2Loss.
  399. */
  400. REG_OP(L2Loss)
  401. .INPUT(x, TensorType::FloatingDataType())
  402. .OUTPUT(y, TensorType::FloatingDataType())
  403. .OP_END_FACTORY_REG(L2Loss)
  404. /**
  405. *@brief: Returns a batched diagonal tensor with a given batched diagonal values . \n
  406. *@par Inputs:
  407. *x: A Tensor. Must be one of the following types:
  408. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  409. * qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
  410. *@par Outputs:
  411. *y: A Tensor. Has the same type as "x" . \n
  412. *@par Third-party framework compatibility
  413. * Compatible with the TensorFlow operator MatrixDiag.
  414. */
  415. REG_OP(MatrixDiag)
  416. .INPUT(x, TensorType::BasicType())
  417. .OUTPUT(y, TensorType::BasicType())
  418. .OP_END_FACTORY_REG(MatrixDiag)
  419. /**
  420. *@brief: Returns a batched diagonal tensor with a given batched diagonal values . \n
  421. *@par Inputs:
  422. * Two inputs, including:
  423. *@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
  424. *@li assist: A Tensor of the same type as "x" . \n
  425. *@par Outputs:
  426. *y: A Tensor. Has the same type as "x" . \n
  427. *@par Third-party framework compatibility
  428. * Compatible with the TensorFlow operator MatrixDiag.
  429. *
  430. * @par Restrictions:
  431. * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
  432. */
  433. REG_OP(MatrixDiagD)
  434. .INPUT(x, TensorType::BasicType())
  435. .INPUT(assist, TensorType::BasicType())
  436. .OUTPUT(y, TensorType::BasicType())
  437. .OP_END_FACTORY_REG(MatrixDiagD)
  438. /**
  439. *@brief: Returns the batched diagonal part of a batched tensor . \n
  440. *@par Inputs:
  441. *x: A Tensor. Must be one of the following types:
  442. * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
  443. * qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
  444. *@par Outputs:
  445. *y: A Tensor. Has the same type as "x" . \n
  446. *@par Third-party framework compatibility
  447. * Compatible with the TensorFlow operator MatrixDiagPart.
  448. */
  449. REG_OP(MatrixDiagPart)
  450. .INPUT(x, TensorType::BasicType())
  451. .OUTPUT(y, TensorType::BasicType())
  452. .OP_END_FACTORY_REG(MatrixDiagPart)
  453. /**
  454. *@brief: Returns the batched diagonal part of a batched tensor . \n
  455. *@par Inputs:
  456. * Two inputs, including:
  457. *@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
  458. *@li assist: A Tensor of the same type as "x" . \n
  459. *@par Outputs:
  460. *y: A Tensor. Has the same type as "x" . \n
  461. *@par Third-party framework compatibility
  462. * Compatible with the TensorFlow operator MatrixDiagPart.
  463. *
  464. * @par Restrictions:
  465. * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
  466. */
  467. REG_OP(MatrixDiagPartD)
  468. .INPUT(x, TensorType::BasicType())
  469. .INPUT(assist, TensorType::BasicType())
  470. .OUTPUT(y, TensorType::BasicType())
  471. .OP_END_FACTORY_REG(MatrixDiagPartD)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Two inputs, including:
*@li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
*@li diagonal: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Three inputs, including:
*@li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
*@li diagonal: A Tensor of the same type as "x".
*@li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)
/**
* @brief Function AttentionScore: fused scaled-dot-product attention
* (score = softmax(scale * query x key), output = score x value). \n

* @par Inputs:
* six inputs, including:
* @li query: A matrix Tensor. The type only support float16.
* @li key: A matrix Tensor. The type only support float16.
* @li value: A matrix Tensor. The type only support float16.
* @li padding_mask: A matrix Tensor. The type only support float16.
* @li scale: A scalar. The type only support float16.
* @li drop_mask: An optional matrix Tensor.
* NOTE(review): the original doc stated "uint8" here, but the registration
* below declares DT_INT8 — confirm which side is authoritative. \n

* @par Attributes:
* @li keep_prob: A float attribute (registered as Float, defaults to 1.0).
* Probability of keeping each element when "drop_mask" is applied.
* @li query_transpose: A bool. If True, changes the shape of "query" from [K, M] to
[M, K].
* @li key_transpose: A bool. If True, changes the shape of "key" from [N, K] to
[K, N].
* @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data" from [K, M] to
[M, K].
* @li bmm_score_transpose_b: A bool. If True, changes the shape of "value" from [N, K] to
[K, N].
* @li softmax_axes: A list of int. The dimension softmax would be performed on. Defaults
to "[-1]" . \n

* @par Outputs:
* attention_score: The result matrix Tensor. The type only support float16.
* softmax_output: The result matrix Tensor. The type only support float16.

* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(AttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax_output, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {-1})
    .OP_END_FACTORY_REG(AttentionScore)
/**
*@brief Applies sparse "updates" to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
* uint64
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
* uint64

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
/**
*@brief Applies sparse "updates" to individual values or slices of input "x",
* returning the result as a new tensor (functional form; "x" is not mutated) . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, bool, int8, uint8
* (NOTE(review): the registration below accepts the wider TensorType::BasicType
* — confirm the intended supported set).
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. Same type as "x". \n
*Must be one of the following types: float16, float32, bool, int8, uint8

*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)
/**
*@brief Uses "updates" to update tensor "data" by "indices" along a given axis. \n

*@par Inputs:
* Three inputs, including:
*@li data: An ND Tensor . \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64
*@li updates: A Tensor. Same shape as indices. format:NCHW, NHWC . \n
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Attributes:
*@li axis: An optional attribute. The axis along which to scatter. Defaults to 0.

*@par Outputs:
*y: A Tensor. Has the same type and format as input "data" . \n

*@par Third-party framework compatibility
* Compatible with the ONNX operator ScatterElements.
*/
REG_OP(ScatterElements)
    .INPUT(data, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(ScatterElements)
/**
*@brief Adds sparse "updates" to a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor .
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor . \n
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor . Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
/**
*@brief Adds sparse "updates" to a variable reference along a given axis . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor .
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor of type int32 or int64
*@li updates: An ND Tensor . Same type as "var".
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Attributes:
* axis: A required int. The axis along which to index. \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the pytorch operator ScatterAdd.
*/
REG_OP(ScatterAddWithAxis)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .REQUIRED_ATTR(axis, Int)
    .OP_END_FACTORY_REG(ScatterAddWithAxis)
/**
*@brief Divides a variable reference by sparse updates . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)
/**
*@brief Applies sparse addition to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)
/**
*@brief Applies sparse addition to individual values or slices of input "x",
* returning the result as a new tensor (functional form; "x" is not mutated) . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. Same type as "x". \n
* Must be one of the following types: float16, float32, int32, int8, uint8

*@par Outputs:
*y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)
/**
*@brief Applies sparse subtraction to individual values or slices in a Variable . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)
/**
*@brief Applies sparse subtraction to individual values or slices of input "x",
* returning the result as a new tensor (functional form; "x" is not mutated) . \n

*@par Inputs:
* Three inputs, including:
*@li x: An ND Tensor. \n
*Must be one of the following types: float16, float32, int32, int8, uint8
*@li indices: An ND Tensor. \n
*Must be one of the following types: int32
*@li updates: An ND Tensor. Same type as "x". \n
*Must be one of the following types: float16, float32, int32, int8, uint8

*@par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.

*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)
/**
*@brief Subtracts sparse updates from a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)
/**
*@brief Returns the batched diagonal part of a batched tensor with "assist" . \n

*@par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x" . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)
/**
*@brief Returns the batched diagonal part of a batched tensor . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128 . \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
        DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
        DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
*@brief Also known as a "fully-connected" layer, computes an inner product with a
set of learned weights, and (optionally) adds biases . \n

*@par Inputs:
* Four inputs, including:
*@li x: A Tensor of type float16, int8, int4, float32 or bfloat16
* (see the registration below).
*@li w: A weight matrix of type float16, int8, int4, float32 or bfloat16.
*@li b: An optional Tensor of type float16, int32, float32 or bfloat16.
*@li offset_w: An optional Tensor of type int8 or int4. Reserved. Only None Supported. \n

*@par Attributes:
*@li num_output: Required. An int, output neuron number. Reserved.
*@li transpose: A bool, specifying weight whether to transpose input w, either "true" or "false". Defaults to "false".
*@li axis: Optional. An int, 1 or 2, specifying which dimension the input "K" starts from. Defaults to 1.
* The product of the subsequent dimensions starting form first dimension or the second dimension is "K".
*@li offset_x: An optional integer for quantized FullyConnection.
*The negative offset added to the input image for int8 type. Ensure offset_x within the
*effective range of int8 [-128, 127]. Defaults to "0". \n

*@par Outputs:
*y: The result tensor of type float16, int32, float32 or bfloat16 . \n

*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct . \n

*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32, DT_BF16}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)
/**
*@brief Also known as a "fully-connected-compress" layer, computes an inner
product with a set of learned weights, and (optionally) adds biases . \n

*@par Inputs:
* Five inputs, including:
*@li x: A Tensor of type uint8, int8.
*@li w: A weight matrix of type int8.
*@li compress_index: A compress index matrix of type int8.
* NOTE: the input is registered below under the misspelled name
* "comress_index"; the spelling is part of the op's public interface, so it
* cannot be corrected here without breaking existing serialized graphs.
*@li b: An optional Tensor of type int32.
*@li offset_w: An optional Tensor of type int8.

*@par Attributes:
*@li num_output: A required int, specifying the number of outputs.
*@li transpose: A bool, specifying whether to transpose input w, either "true"
or "false". Defaults to "false".
*@li axis: Optional. A int, 1 or 2, specifying which dimension the input "K"
starts from. Defaults to "1".
* The product of the subsequent dimensions starting form first dimension or the
second dimension is "K".
*@li offset_x: An optional integer for quantized FullyConnectionCompress.
*The negative offset added to the input image for int8 type. Ensure offset_x
within the effective range of int8 [-128, 127]. Defaults to "0". \n

*@par Outputs:
*y: The result tensor of type int32. \n

*@par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct. \n

*@par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(comress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)
/**
*@brief Computes the confusion matrix from predictions and labels . \n

*@par Inputs:
* Three inputs, including:
*@li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
*@li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
*@li weights: An optional Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8 . \n

*@par Attributes:
*@li num_classes: A required integer for the shape of the output matrix.
* No default value.
*@li dtype: A required string. Data type of the confusion matrix. No default value . \n

*@par Outputs:
*y: A Tensor. Has the same type and format as input "labels"

*@attention Constraints:
*@li "weights", "labels", and "predictions" are 1D tensors.
*@li The output is with shape (num_classes, num_classes),
* where, 1 <= num_classes <= 4096 . \n

*@see Region()

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
/**
*@brief Multiplies sparse updates into a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var". \n
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)
/**
*@brief Reduces sparse updates into a variable reference using
* the "min" operation . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor.
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor.
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor. Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)
/**
*@brief Reduces sparse updates into a variable reference using the "max" operation . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor .
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An NCHW, NHWC, or ND Tensor . \n
*Must be one of the following types: int32 or int64
*@li updates: An NCHW, NHWC, or ND Tensor . Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)
/**
*@brief Applies sparse updates to a variable reference . \n

*@par Inputs:
* Three inputs, including:
*@li var: An ND Tensor .
*Must be one of the following types: float16, float, int32, int8, uint8
*@li indices: An ND Tensor . \n
*Must be one of the following types: int32 or int64
*@li updates: An ND Tensor . Same type as "var".
*Must be one of the following types: float16, float, int32, int8, uint8

*@par Attributes:
*use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n

*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
/**
*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input` . \n

*@par Inputs:
* Three inputs, including:
*@li input: Rank `r` tensor where `r >= 2`. \n
*@li k: \n
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li padding_value: The value to fill the area outside the specified diagonal band with. \n

*@par Outputs:
*diagonal: The extracted diagonal(s) . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1`, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n

*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape` . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)
/**
*@brief Returns a batched matrix tensor with new batched diagonal values . \n

*@par Inputs:
* Three inputs, including:
*@li input: Rank `r+1`, where `r >= 1`. \n
*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n

*@par Attributes:
*@li align: An optional string. Defaults to RIGHT_LEFT. It is a string specifying \n
*how superdiagonals and subdiagonals should be aligned, respectively. \n
*other optional: LEFT_RIGHT, LEFT_LEFT, and RIGHT_RIGHT.\n

*@par Outputs:
*output: Rank `r+1`, with `output.shape = input.shape` . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV3.
*/
REG_OP(MatrixSetDiagV3)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixSetDiagV3)
/**
*@brief Returns a batched diagonal tensor with given batched diagonal values . \n

*@par Inputs:
* Five inputs, including:
*@li diagonal: Rank `r`, where `r >= 1` \n
*@li k:
*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
*diagonal, and negative value means subdiagonals. `k` can be a single integer \n
*(for a single diagonal) or a pair of integers specifying the low and high ends \n
*of a matrix band. `k[0]` must not be larger than `k[1]`. \n
*@li num_rows:
*The number of rows of the output matrix. If it is not provided, the op assumes \n
*the output matrix is a square matrix and infers the matrix size from k and the \n
*innermost dimension of `diagonal`. \n
*@li num_cols: An NCHW, NHWC, or ND Tensor.
*The number of columns of the output matrix. If it is not provided, the op \n
*assumes the output matrix is a square matrix and infers the matrix size from \n
*k and the innermost dimension of `diagonal`. \n
*@li padding_value: The number to fill the area outside the specified diagonal band with. \n

*@par Outputs:
*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise . \n

*@par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
/**
* @brief Add updates to var_out according to axis and indices.

* @par Inputs:
* Three inputs, including:
* @li var: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8.
* @li indices: A Tensor of the indices, type should be int32.
* @li updates: A Tensor of the same type as "var".

* @par Attributes:
* @li axis: An optional int to specify the axis to perform indices add. Defaults to 0.

* @par Outputs:
* @li var_out: A Tensor. Same as input "var".

* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_add.

* @par Restrictions:
* Warning:THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexAdd)
    .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(IndexAdd)
/**
* @brief According to the index numbers in "indices", replace the values
* in "x1" with the values in "x2".

* @par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @li x2: A Tensor of the same type as "x1".

* @par Attributes:
* @li indices: A required list of ints, the indices at which to put values.
* @li accumulate: An optional int. If nonzero, accumulate into "x1" instead of
* replacing (self accumulation). Defaults to 0.

* @par Outputs:
* @li y: A Tensor. Same as input "x1".

* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_put.

* @par Restrictions:
* Warning:THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexPut)
    .INPUT(x1, TensorType::BasicType())
    .INPUT(x2, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .REQUIRED_ATTR(indices, ListInt)
    .ATTR(accumulate, Int, 0)
    .OP_END_FACTORY_REG(IndexPut)
/**
*@brief: Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n

*@par Attributes:
*diagonal: An optional int indicating the diagonal to consider. Defaults to 0 (the main diagonal). \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator Triu.
*/
REG_OP(Triu)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Triu)
/**
*@brief: Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
*float16, float32, double, int32, uint8, int16, int8, complex64, int64,
*qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n

*@par Attributes:
*diagonal: An optional int indicating the diagonal to consider. Defaults to 0 (the main diagonal). \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator Tril.
*/
REG_OP(Tril)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Tril)
/**
*@brief Sums the product of the elements of the input operands along dimensions
* specified using a notation based on the Einstein summation convention. \n

*@par Inputs:
*x: A dynamic list of N tensors. Must be one of the following types:
* int32, float16, float32. It's a dynamic input. \n

*@par Attributes:
*@li equation: The subscripts for the Einstein summation. \n
*@li N: tensor size of input. \n

*@par Outputs:
*@li y: Sums the product of the elements of the input operands along dimensions specified
* using a notation based on the Einstein summation convention. \n

*@attention Constraints:
*Input N must be Int. \n

*@par Third-party framework compatibility
*Compatible with Pytorch einsum operator.
*/
REG_OP(Einsum)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .REQUIRED_ATTR(equation, String)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(Einsum)
/**
*@brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n

*@par Inputs:
*No inputs

*@par Attributes:
*@li num_rows: A required int. \n
*@li num_columns: An optional int. Defaults to 0. \n
*@li batch_shape: An optional ListInt. Defaults to []. \n
*@li dtype: An optional int. Defaults to 0. \n

*@par Outputs:
*y: A Tensor with targeted type and shape. \n

*@par Third-party framework compatibility
*Compatible with the Pytorch operator Eye. \n
*/
REG_OP(Eye)
    .OUTPUT(y, TensorType::BasicType())    /* "Result, has targeted element type" */
    .REQUIRED_ATTR(num_rows, Int)
    .ATTR(num_columns, Int, 0)
    .ATTR(batch_shape, ListInt, {})
    .ATTR(dtype, Int, 0)
    .OP_END_FACTORY_REG(Eye)
/**
*@brief: Fill diagonal of at least 2 dimension tensors with value . \n

*@par Inputs:
*x: A Tensor. Must be one of the following types:
* float32, int32, int64 . \n

*@par Attributes:
*@li fill_value: The value to fill in.
*@li wrap: An optional bool. Defaults to "False". If "True", Use recursive fill. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator FillDiagonal.
*/
REG_OP(FillDiagonal)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .REQUIRED_ATTR(fill_value, Float)
    .ATTR(wrap, Bool, false)
    .OP_END_FACTORY_REG(FillDiagonal)
/**
*@brief: Returns the sum of the elements of the diagonal of the input 2-D matrix. \n

*@par Inputs:
*x: A 2-D matrix Tensor. Must be one of the following types:
* float16, float. \n

*@par Outputs:
*y: A Tensor. Has the same type as "x" . \n

*@par Third-party framework compatibility
* Compatible with the Pytorch operator Trace.
*/
REG_OP(Trace)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(Trace)
/**
*@brief Computes the generalized inverse of any matrix. \n

*@par Inputs:
* @li x: input matrix. Must be one of the following types:
* double, float. \n

*@par Attributes:
* @li rcond: An optional float >= 0 or inf. Defaults to 1e-15. \n

*@par Outputs:
* y: A Tensor with the same type and shape of x's transpose. \n
*/
REG_OP(Pinverse)
    .INPUT(x, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .ATTR(rcond, Float, 1e-15)
    .OP_END_FACTORY_REG(Pinverse)
/**
* @brief From the input tensor and updates tensor, select the maximum value according to indices to output. \n

* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: An index tensor. Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n

* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMax)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMax)
/**
* @brief From the input tensor and updates tensor, select the minimum value according to indices to output. \n

* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: An index tensor. Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n

* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMin)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMin)
  1390. /**
  1391. * @brief: Returns the batched diagonal part of a batched tensor. \n
  1392. * @par Inputs:
  1393. * @li x: A Tensor. Rank r tensor where r >= 2.
  1394. * @li k: A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal,
  1395. 0 refers to the main diagonal, and negative value means subdiagonals. k can be a
  1396. single integer (for a single diagonal) or a pair of integers specifying the low and
  1397. high ends of a matrix band. k[0] must not be larger than k[1].
  1398. * @li padding_value:A Tensor. Must have the same type as input. The value to fill the area
  1399. outside the specified diagonal band with. Default is 0. \n
  1400. * @par Outputs:
  1401. * @li y: A Tensor. Has the same type as "input". \n
  1402. * @par Attributes:
  1403. * @li align:An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT". Defaults to "RIGHT_LEFT".
  1404. * @par Third-party framework compatibility
  1405. * Compatible with the Tensorflow operator FillDiagonal.
  1406. */
  1407. REG_OP(MatrixDiagPartV3)
  1408. .INPUT(x, TensorType::BasicType())
  1409. .INPUT(k, TensorType({DT_INT32}))
  1410. .INPUT(padding_value, TensorType::BasicType())
  1411. .OUTPUT(y, TensorType::BasicType())
  1412. .ATTR(align,String ,"RIGHT_LEFT")
  1413. .OP_END_FACTORY_REG(MatrixDiagPartV3)
/**
* @brief Returns a batched diagonal tensor with given batched diagonal values . \n

* @par Inputs:
* Five inputs, including:
* @li x: The diagonal values. Rank `r`, where `r >= 1` \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
* diagonal, and negative value means subdiagonals. `k` can be a single integer
* (for a single diagonal) or a pair of integers specifying the low and high ends
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li num_rows:
* The number of rows of the output matrix. If it is not provided, the op assumes
* the output matrix is a square matrix and infers the matrix size from k and the
* innermost dimension of `x`. \n
* @li num_cols: An NCHW, NHWC, or ND Tensor.
* The number of columns of the output matrix. If it is not provided, the op
* assumes the output matrix is a square matrix and infers the matrix size from
* k and the innermost dimension of `x`. \n
* @li padding_value: The number to fill the area outside the specified diagonal band with. \n

* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT".
* Defaults to "RIGHT_LEFT" \n

* @par Outputs:
* @li y: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise . \n

* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV3.
*/
REG_OP(MatrixDiagV3)
    .INPUT(x, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixDiagV3)
  1450. } // namespace ge
  1451. #endif // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示