/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*!
* \file matrix_calculation_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

#include "graph/operator_reg.h"

namespace ge {

/**
* @brief Backprop W of AttentionLnQKV + ReduceSumD \n
* @par Inputs:
* Four inputs, including:
* @li x: A Tensor. Must be one of the following types: float16.
* @li query_dx: A Tensor. Must be one of the following types: float16.
* @li key_dw: A Tensor. Must be one of the following types: float16.
* @li value_dw: A Tensor. Must be one of the following types: float16.
* @par Attributes:
* @li trans_a: An optional attribute, the type is bool. Defaults to True.
* @li trans_b: An optional attribute, the type is bool. Defaults to False. \n
* @par Outputs:
* Six outputs, including:
* @li dw_query: A Tensor. Must be one of the following types: float16.
* @li dw_key: A Tensor. Must be one of the following types: float16.
* @li dw_value: A Tensor. Must be one of the following types: float16.
* @li dbias_query: A Tensor. Must be one of the following types: float16.
* @li dbias_key: A Tensor. Must be one of the following types: float16.
* @li dbias_value: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(AttentionQKVGradW)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(query_dx, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(key_dw, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(value_dw, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_query, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_key, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_value, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_query, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_key, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_value, TensorType({DT_FLOAT16}))
    .ATTR(trans_a, Bool, true)
    .ATTR(trans_b, Bool, false)
    .OP_END_FACTORY_REG(AttentionQKVGradW)

/**
* @brief Backprop X of AttentionLnQKV + AddN \n
* @par Inputs:
* Seven inputs, including:
* @li ln_dx: A Tensor. Must be one of the following types: float16.
* @li query_dx: A Tensor. Must be one of the following types: float16.
* @li key_dw: A Tensor. Must be one of the following types: float16.
* @li value_dw: A Tensor. Must be one of the following types: float16.
* @li kernel_query: A Tensor. Must be one of the following types: float16.
* @li kernel_key: A Tensor. Must be one of the following types: float16.
* @li kernel_value: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li trans_a: An optional attribute, the type is bool. Defaults to False.
* @li trans_b: An optional attribute, the type is bool. Defaults to True. \n
* @par Outputs:
* One output, including:
* @li dx: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(AttentionQKVGradX)
    .INPUT(ln_dx, TensorType({DT_FLOAT16}))
    .INPUT(query_dx, TensorType({DT_FLOAT16}))
    .INPUT(key_dw, TensorType({DT_FLOAT16}))
    .INPUT(value_dw, TensorType({DT_FLOAT16}))
    .INPUT(kernel_query, TensorType({DT_FLOAT16}))
    .INPUT(kernel_key, TensorType({DT_FLOAT16}))
    .INPUT(kernel_value, TensorType({DT_FLOAT16}))
    .OUTPUT(dx, TensorType({DT_FLOAT16}))
    .ATTR(trans_a, Bool, false)
    .ATTR(trans_b, Bool, true)
    .OP_END_FACTORY_REG(AttentionQKVGradX)

/**
* @brief
*           / (MatMul -> ConfusionTransposeD)
* LayerNorm - (MatMul -> ConfusionTransposeD)
*           \ (MatMul -> ConfusionTransposeD) \n
* @par Inputs:
* Nine inputs, including:
* @li x: A Tensor. Must be one of the following types: float16.
* @li kernel_query: A Tensor. Must be one of the following types: float16.
* @li kernel_key: A Tensor. Must be one of the following types: float16.
* @li kernel_value: A Tensor. Must be one of the following types: float16.
* @li gamma: A Tensor. Must be one of the following types: float16.
* @li beta: A Tensor. Must be one of the following types: float16.
* @li bias_query: A Tensor. Must be one of the following types: float16.
* @li bias_key: A Tensor. Must be one of the following types: float16.
* @li bias_value: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li epsilon: An optional attribute, the type is float32. Defaults to 1e-7.
* @li trans_a: An optional attribute, the type is bool. Defaults to False.
* @li trans_b: An optional attribute, the type is bool. Defaults to False. \n
* @par Outputs:
* Six outputs, including:
* @li norm: A Tensor. Must be one of the following types: float16.
* @li query_output: A Tensor. Must be one of the following types: float16.
* @li key_output: A Tensor. Must be one of the following types: float16.
* @li value_output: A Tensor. Must be one of the following types: float16.
* @li mean: A Tensor. Must be one of the following types: float16.
* @li variance: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(AttentionLnQKV)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(kernel_query, TensorType({DT_FLOAT16}))
    .INPUT(kernel_key, TensorType({DT_FLOAT16}))
    .INPUT(kernel_value, TensorType({DT_FLOAT16}))
    .INPUT(gamma, TensorType({DT_FLOAT16}))
    .INPUT(beta, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_query, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_key, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_value, TensorType({DT_FLOAT16}))
    .OUTPUT(norm, TensorType({DT_FLOAT16}))
    .OUTPUT(query_output, TensorType({DT_FLOAT16}))
    .OUTPUT(key_output, TensorType({DT_FLOAT16}))
    .OUTPUT(value_output, TensorType({DT_FLOAT16}))
    .OUTPUT(mean, TensorType({DT_FLOAT16}))
    .OUTPUT(variance, TensorType({DT_FLOAT16}))
    .ATTR(epsilon, Float, 0.0000001)
    .ATTR(trans_a, Bool, false)
    .ATTR(trans_b, Bool, false)
    .OP_END_FACTORY_REG(AttentionLnQKV)

/**
* @brief A structure specific to the swin_transformer model. This operator
* only supports swin_transformer. \n
* @par Inputs:
* Five inputs, including:
* @li x: A Tensor. Must be one of the following types: float16.
* @li gamma: A Tensor. Must be one of the following types: float16.
* @li beta: A Tensor. Must be one of the following types: float16.
* @li weight: A Tensor. Must be one of the following types: float16.
* @li bias: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li head_num: A required attribute, the type is int.
* @li head_dim: A required attribute, the type is int.
* @li seq_length: A required attribute, the type is int.
* @li shifts: An optional attribute, the type is list int. Defaults to ().
* @li epsilon: An optional attribute, the type is float. Defaults to 1e-7. \n
* @par Outputs:
* Three outputs, including:
* @li query_output: A Tensor. Must be one of the following types: float16.
* @li key_output: A Tensor. Must be one of the following types: float16.
* @li value_output: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(SwinTransformerLnQKV)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(gamma, TensorType({DT_FLOAT16}))
    .INPUT(beta, TensorType({DT_FLOAT16}))
    .INPUT(weight, TensorType({DT_FLOAT16}))
    .INPUT(bias, TensorType({DT_FLOAT16}))
    .OUTPUT(query_output, TensorType({DT_FLOAT16}))
    .OUTPUT(key_output, TensorType({DT_FLOAT16}))
    .OUTPUT(value_output, TensorType({DT_FLOAT16}))
    .REQUIRED_ATTR(head_num, Int)
    .REQUIRED_ATTR(head_dim, Int)
    .REQUIRED_ATTR(seq_length, Int)
    .ATTR(shifts, ListInt, {})
    .ATTR(epsilon, Float, 0.0000001)
    .OP_END_FACTORY_REG(SwinTransformerLnQKV)

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
* @par Inputs:
* Three inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, bfloat16. Has format [ND, NHWC].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, bfloat16. Has format [ND, NHWC].
* @li bias: An optional 1D Tensor. Must be one of the following types: float16,
* float32, int32, bfloat16. Has format [ND, NHWC]. \n
* @par Attributes:
* @li transpose_x1: A bool. If True, changes the shape of "x1" from [M, K] to
* [K, M].
* @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
* [K, N]. \n
* @par Outputs:
* y: The result matrix Tensor. 2D. Must be one of the following types: float16,
* float32, int32, bfloat16. Has format [ND, NHWC]. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
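
// Illustrative host-side reference sketch of what MatMul computes for 2-D
// row-major float inputs, including the transpose attributes and the optional
// bias broadcast along N. This is not the Ascend kernel; the helper name
// "MatMulReference" is hypothetical.
inline void MatMulReference(const float *x1, const float *x2,
                            const float *bias,  // may be nullptr
                            float *y, int m, int k, int n,
                            bool transpose_x1, bool transpose_x2) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        // transpose_x1 means x1 is stored as [K, M]; transpose_x2 as [N, K].
        const float a = transpose_x1 ? x1[p * m + i] : x1[i * k + p];
        const float b = transpose_x2 ? x2[j * k + p] : x2[p * n + j];
        acc += a * b;
      }
      y[i * n + j] = acc + (bias != nullptr ? bias[j] : 0.0f);
    }
  }
}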

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
* @par Inputs:
* Four inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32, int8, int4, bfloat16. Has format [ND, NHWC].
* @li x2: A matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32, int8, int4, bfloat16. Has format [ND, NHWC].
* @li bias: A 1D Tensor. Must be one of the following types: float32,
* float16, int32, bfloat16. Has format [ND, NHWC].
* @li offset_w: An optional 1D Tensor for quantized inference. Type is int8, int4.
* Reserved. \n
* @par Attributes:
* @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
* [M, K].
* @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
* [K, N].
* @li offset_x: An optional integer for quantized MatMulV2.
* The negative offset added to the input x1 for int8 type. Ensure offset_x
* is within the effective range of int8 [-128, 127]. Defaults to "0". \n
* @par Outputs:
* y: The result matrix Tensor. 2D. Must be one of the following types: float32,
* float16, int32, bfloat16. Has format [ND, NHWC]. \n
* @attention Constraints:
* If performance is better in format NZ, disable
* "MatmulTransdataFusionPass" in the fusion configuration. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatMul.
*/
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
* @par Inputs:
* Five inputs, including:
* @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
* @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
* @li compress_index: A compress index matrix of type int8.
* @li bias: An optional Tensor. 1D. Must be one of the following types: int32,
* float16.
* @li offset_w: An optional matrix Tensor. 2D. Must be one of the following
* types: int8. \n
* @par Attributes:
* @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
* [M, K].
* @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
* [K, N].
* @li offset_x: An optional integer for quantized MatMulV2Compress.
* The negative offset added to the input x1 for int8 type. Ensure offset_x
* is within the effective range of int8 [-128, 127]. Defaults to "0". \n
* @par Outputs:
* y: The result matrix Tensor. 2D. Must be one of the following types: int32,
* float16. \n
* @attention Constraints:
* If performance is better in format NZ, disable
* "MatmulTransdataFusionPass" in the fusion configuration.
*/
REG_OP(MatMulV2Compress)
    .INPUT(x1, TensorType({DT_INT8}))
    .INPUT(x2, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2Compress)

/**
* @brief Performs matrix-to-matrix multiply,
* producing y = alpha[0] * a * b + beta[0] * c. \n
* @attention Constraints:
* For better performance, the k-axis must be aligned to 16 (input type
* is float16) or 32 (input type is int8). \n
* @par Inputs:
* Five inputs, including:
* @li a: A matrix Tensor. Must be one of the following types: float32, float16,
* int8, int32. Has format ND.
* @li b: A matrix Tensor. Must be one of the following types: float32, float16,
* int8, int32. Has format ND.
* @li c: A matrix Tensor. Must be one of the following types: float32, float16,
* int8, int32. Has format ND.
* @li alpha: A 1D Tensor. The shape of alpha is [1]. Must be one of the
* following types: float16, int32, float32, int8. Has format ND.
* @li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
* types: float16, int32, float32, int8. Has format ND. \n
* The formats of a, b, c have the following restriction: \n
* When the type of a is int8 and the type of c is int32, the formats of a, b, c
* should all be ND. \n
* When the type of a is int8 and the type of c is float32, the formats of a, b, c
* should all be ND. \n
* When the type of a is float16 and the type of c is float16, the formats of a, b, c
* should all be ND. \n
* When the type of a is float16 and the type of c is float32, the formats of a, b, c
* should all be ND. \n
* @par Attributes:
* Two attributes, including:
* @li transpose_a: Optional. A bool. If True, changes the shape of "a" from
* [M, K] to [K, M].
* @li transpose_b: Optional. A bool. If True, changes the shape of "b" from
* [K, N] to [N, K]. \n
* @par Outputs:
* y: The result matrix Tensor. Must be one of the following types: float16,
* float32, int32, int8. Has format ND; the format should be the same as that of "a".
*/
REG_OP(GEMM)
    .INPUT(a, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(GEMM)
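
// Illustrative reference sketch of the GEMM contract
// y = alpha[0] * a * b + beta[0] * c for 2-D row-major float inputs
// (host code only, not the device kernel; "GemmReference" is a hypothetical
// helper name).
inline void GemmReference(const float *a, const float *b, const float *c,
                          float alpha, float beta, float *y,
                          int m, int k, int n) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += a[i * k + p] * b[p * n + j];
      }
      // alpha scales the matrix product; beta scales the additive input c.
      y[i * n + j] = alpha * acc + beta * c[i * n + j];
    }
  }
}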

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
* @par Inputs:
* Two inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32, bfloat16. 2D or higher. Has format [ND, NHWC].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32, bfloat16. 2D or higher. Has format [ND, NHWC]. \n
* @par Attributes:
* @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K]
* to [B, K, M].
* @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K]
* to [B, K, M]. \n
* @par Outputs:
* y: The result matrix Tensor. 2D or higher. Must be one of the following
* types: float16, float32, int32, bfloat16. Has format [ND, NHWC]. Has the
* same shape length as "x1" and "x2". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatmul.
*/
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)
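
// Illustrative reference sketch of BatchMatMul for 3-D inputs
// [B, M, K] x [B, K, N]: each batch is an independent matrix multiply, and
// adj_x1 / adj_x2 swap the last two dimensions of the corresponding input
// ("BatchMatMulReference" is a hypothetical helper, not the real kernel).
inline void BatchMatMulReference(const float *x1, const float *x2, float *y,
                                 int batch, int m, int k, int n,
                                 bool adj_x1, bool adj_x2) {
  for (int bi = 0; bi < batch; ++bi) {
    const float *a = x1 + bi * m * k;
    const float *b = x2 + bi * k * n;
    float *out = y + bi * m * n;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += (adj_x1 ? a[p * m + i] : a[i * k + p]) *
                 (adj_x2 ? b[j * k + p] : b[p * n + j]);
        }
        out[i * n + j] = acc;
      }
    }
  }
}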

/**
* @brief Multiplies matrix "a" by matrix "b", producing "a * b" . \n
* @par Inputs:
* Four inputs, including:
* @li x1: A matrix Tensor. Must be one of the following types: float16,
* float32, int32, int8, int4, bfloat16. 2D or higher. Has format [ND, NHWC].
* @li x2: A matrix Tensor. Must be one of the following types: float16,
* float32, int32, int8, int4, bfloat16. 2D or higher. Has format [ND, NHWC].
* @li bias: An optional Tensor. Must be one of the following types: float16,
* float32, int32, bfloat16. Has format [ND, NHWC].
* @li offset_w: An optional Tensor. Must be one of the following types:
* int8, int4. Has format [ND, NHWC]. \n
* @par Attributes:
* @li adj_x1: A bool. If True, changes the shape of "x1" from [B, M, K] to
* [B, K, M].
* @li adj_x2: A bool. If True, changes the shape of "x2" from [B, M, K] to
* [B, K, M].
* @li offset_x: An optional integer for quantized BatchMatMulV2.
* Defaults to "0". \n
* @par Outputs:
* y: The result matrix Tensor. 2D or higher. Must be one of the following
* types: float16, float32, int32, bfloat16. Has format [ND, NHWC]. Has the
* same shape length as "x1" and "x2". \n
* @attention Constraints:
* If performance is better in format NZ, disable
* "MatmulTransdataFusionPass" in the fusion configuration. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator BatchMatmul.
*/
REG_OP(BatchMatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(BatchMatMulV2)

/**
* @brief Computes half the L2 norm of a tensor without the sqrt . \n
* @par Inputs:
* x: A Tensor.
* TensorType::FloatingDataType() . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator L2Loss.
*/
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
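
// Illustrative reference sketch of L2Loss: half the squared L2 norm,
// i.e. sum(x * x) / 2, with no square root ("L2LossReference" is a
// hypothetical helper name, not the device kernel).
inline float L2LossReference(const float *x, int count) {
  float sum = 0.0f;
  for (int i = 0; i < count; ++i) {
    sum += x[i] * x[i];
  }
  return sum / 2.0f;
}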

/**
* @brief Returns a batched diagonal tensor with given batched diagonal values . \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*/
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
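
// Illustrative reference sketch of MatrixDiag for a single diagonal of
// length n: the output is an n x n matrix with "diag" on the main diagonal
// and zeros elsewhere; batched inputs apply this per batch
// ("MatrixDiagReference" is a hypothetical helper name).
inline void MatrixDiagReference(const float *diag, float *y, int n) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      y[i * n + j] = (i == j) ? diag[i] : 0.0f;
    }
  }
}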

/**
* @brief Returns a batched diagonal tensor with given batched diagonal values . \n
* @par Inputs:
* Two inputs, including:
* @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
* @li assist: A Tensor of the same type as "x" . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
*/
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
* @brief Returns the batched diagonal part of a batched tensor . \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64 . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*/
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)
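
// Illustrative reference sketch of MatrixDiagPart for one [rows, cols]
// matrix: it extracts the main diagonal, of length min(rows, cols); batched
// inputs apply this per batch ("MatrixDiagPartReference" is a hypothetical
// helper name).
inline void MatrixDiagPartReference(const float *x, float *diag,
                                    int rows, int cols) {
  const int n = (rows < cols) ? rows : cols;
  for (int i = 0; i < n; ++i) {
    diag[i] = x[i * cols + i];
  }
}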

/**
* @brief Returns the batched diagonal part of a batched tensor . \n
* @par Inputs:
* Two inputs, including:
* @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
* @li assist: A Tensor of the same type as "x" . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
*/
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
* @brief Returns a batched matrix tensor with new batched diagonal values . \n
* @par Inputs:
* Two inputs, including:
* @li x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li diagonal: A Tensor of the same type as "x" . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*/
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)
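
// Illustrative reference sketch of MatrixSetDiag for one [rows, cols]
// matrix: the output copies x except on the main diagonal, which is
// overwritten with "diagonal" ("MatrixSetDiagReference" is a hypothetical
// helper name).
inline void MatrixSetDiagReference(const float *x, const float *diagonal,
                                   float *y, int rows, int cols) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      y[i * cols + j] = (i == j) ? diagonal[i] : x[i * cols + j];
    }
  }
}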

/**
* @brief Returns a batched matrix tensor with new batched diagonal values . \n
* @par Inputs:
* Three inputs, including:
* @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
* @li diagonal: A Tensor of the same type as "x".
* @li assist: A Tensor of the same type as "x" . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiag.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
*/
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)

/**
* @brief Function AttentionScore. \n
* @par Inputs:
* Six inputs, including:
* @li query: A matrix Tensor. The type only supports float16.
* @li key: A matrix Tensor. The type only supports float16.
* @li value: A matrix Tensor. The type only supports float16.
* @li padding_mask: A matrix Tensor. The type only supports float16.
* @li scale: A scalar. The type only supports float16.
* @li drop_mask: A matrix Tensor. The type only supports int8. \n
* @par Attributes:
* @li keep_prob: An optional float. Defaults to 1.0.
* @li query_transpose: A bool. If True, changes the shape of "query" from
* [K, M] to [M, K].
* @li key_transpose: A bool. If True, changes the shape of "key" from
* [N, K] to [K, N].
* @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data"
* from [K, M] to [M, K].
* @li bmm_score_transpose_b: A bool. If True, changes the shape of "value"
* from [N, K] to [K, N].
* @li softmax_axes: A list of ints. The dimensions softmax is performed on.
* Defaults to [-1] . \n
* @par Outputs:
* attention_score: The result matrix Tensor. The type only supports float16.
* softmax_output: The result matrix Tensor. The type only supports float16.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(AttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax_output, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {-1})
    .OP_END_FACTORY_REG(AttentionScore)

/**
* @brief Function AttentionScoreGrad. \n
* @par Inputs:
* Seven inputs, including:
* @li attention_score: A matrix Tensor. The type only supports float16.
* @li dx: A matrix Tensor. The type only supports float16.
* @li query: A matrix Tensor. The type only supports float16.
* @li key: A matrix Tensor. The type only supports float16.
* @li value: A matrix Tensor. The type only supports float16.
* @li scale: A scalar. The type only supports float16.
* @li drop_mask: A matrix Tensor. The type only supports int8. \n
* @par Attributes:
* @li keep_prob: An optional float. Defaults to 1.0.
* @li query_transpose: A bool. If True, changes the shape of "query" from
* [K, M] to [M, K].
* @li key_transpose: A bool. If True, changes the shape of "key" from
* [N, K] to [K, N].
* @li value_transpose: A bool. If True, changes the shape of "mid_data" from
* [K, M] to [M, K].
* @li dx_transpose: A bool. If True, changes the shape of "value" from
* [N, K] to [K, N].
* @li softmax_axes: An int. The dimension softmax is performed on. Defaults
* to "-1" . \n
* @par Outputs:
* value_dw: The result matrix Tensor. The type only supports float16.
* query_dx: The result matrix Tensor. The type only supports float16.
* key_dw: The result matrix Tensor. The type only supports float16.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(AttentionScoreGrad)
    .INPUT(attention_score, TensorType({DT_FLOAT16}))
    .INPUT(dx, TensorType({DT_FLOAT16}))
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(value_dw, TensorType({DT_FLOAT16}))
    .OUTPUT(query_dx, TensorType({DT_FLOAT16}))
    .OUTPUT(key_dw, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(value_transpose, Bool, false)
    .ATTR(dx_transpose, Bool, false)
    .ATTR(softmax_axes, Int, -1)
    .OP_END_FACTORY_REG(AttentionScoreGrad)

/**
* @brief Applies sparse "updates" to individual values or slices in a Variable . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64.
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64.
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float32, int8, uint8, double,
* int64, complex64, qint8, quint8, qint32, uint16, complex128, uint32,
* uint64.
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdUpdate.
*/
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
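
// Illustrative reference sketch of ScatterNdUpdate semantics in the simplest
// case: rank-1 "var" and indices of shape [num_updates, 1], so each index row
// selects one element of var to overwrite. The real op also supports slice
// updates on higher-rank tensors; "ScatterNdUpdateReference" is a
// hypothetical helper name.
inline void ScatterNdUpdateReference(float *var, const int *indices,
                                     const float *updates, int num_updates) {
  for (int i = 0; i < num_updates; ++i) {
    var[indices[i]] = updates[i];  // on duplicate indices, the later write wins
  }
}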

/**
* @brief Applies sparse "updates" to individual values or slices in tensor "x" . \n
* @par Inputs:
* Three inputs, including:
* @li x: An ND Tensor. \n
* Must be one of the following types: float16, float32, bool, int8, uint8.
* @li indices: An ND Tensor. \n
* Must be one of the following types: int32.
* @li updates: An ND Tensor. \n
* Must be one of the following types: float16, float32, bool, int8, uint8.
* @par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterUpdate.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)

/**
* @brief Uses "updates" to update tensor "data" by "indices". \n
* @par Inputs:
* Three inputs, including:
* @li data: An ND Tensor . \n
* Must be one of the following types: float16, float32, int32, int8, uint8.
* @li indices: An ND Tensor of type int32 or int64.
* @li updates: A Tensor with the same shape as "indices". Format: NCHW, NHWC . \n
* Must be one of the following types: float16, float32, int32, int8, uint8.
* @par Attributes:
* @li axis: An optional attribute. Defaults to 0.
* @li reduction: An optional attribute. Defaults to "none"; can also be
* "add" or "mul".
* @par Outputs:
* y: A Tensor. Has the same type and format as input "data" . \n
* @par Third-party framework compatibility
* Compatible with the ONNX operator ScatterElements.
*/
REG_OP(ScatterElements)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(axis, Int, 0)
    .ATTR(reduction, String, "none")
    .OP_END_FACTORY_REG(ScatterElements)
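
// Illustrative reference sketch of ScatterElements on a 2-D tensor: for each
// element of "updates", the coordinate along "axis" is replaced by the
// matching value from "indices", and the update is written there, combined
// according to "reduction". The helper name and the reduction enum are
// hypothetical (the op takes reduction as a string).
enum class ScatterReduction { kNone, kAdd, kMul };

inline void ScatterElementsReference(float *data, int data_cols,
                                     const int *indices, const float *updates,
                                     int rows, int cols,  // shape of indices/updates
                                     int axis, ScatterReduction reduction) {
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      const int idx = indices[i * cols + j];
      const int r = (axis == 0) ? idx : i;
      const int c = (axis == 0) ? j : idx;
      float &dst = data[r * data_cols + c];
      const float v = updates[i * cols + j];
      switch (reduction) {
        case ScatterReduction::kAdd: dst += v; break;
        case ScatterReduction::kMul: dst *= v; break;
        default: dst = v; break;  // "none": plain overwrite
      }
    }
  }
}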

/**
* @brief Uses "updates" to update tensor "var" by "indices". \n
* @par Inputs:
* Three inputs, including:
* @li var: A Tensor of type BasicType.
* @li indices: An ND Tensor of type int32 or int64.
* @li updates: A Tensor with the same dtype as "var" and the same shape as "indices". \n
* @par Attributes:
* @li use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdMax.
*/
REG_OP(ScatterNdMax)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdMax)

/**
* @brief Adds sparse "updates" to a variable reference . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor. \n
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterAdd.
*/
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)
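
// Illustrative reference sketch of ScatterAdd for a 2-D "var" of shape
// [rows, cols] and rank-1 indices: row indices[i] of var accumulates row i of
// updates, and repeated indices accumulate rather than overwrite
// ("ScatterAddReference" is a hypothetical helper name).
inline void ScatterAddReference(float *var, int cols, const int *indices,
                                const float *updates, int num_indices) {
  for (int i = 0; i < num_indices; ++i) {
    const int row = indices[i];
    for (int j = 0; j < cols; ++j) {
      var[row * cols + j] += updates[i * cols + j];
    }
  }
}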

/**
* @brief Adds sparse "updates" to a variable reference . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float32, int32, int8, uint8
* @li indices: An ND Tensor of type int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float32, int32, int8, uint8
* @par Attributes:
* axis: A required int. The axis along which to index. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the PyTorch operator ScatterAdd.
*/
REG_OP(ScatterAddWithAxis)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(axis, Int)
    .OP_END_FACTORY_REG(ScatterAddWithAxis)

/**
* @brief Divides a variable reference by sparse updates . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterDiv.
*/
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)

/**
* @brief Applies sparse addition to individual values or slices in a Variable . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdAdd.
*/
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

/**
* @brief Applies sparse addition to individual values or slices in a Variable . \n
* @par Inputs:
* Three inputs, including:
* @li x: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8
* @li indices: An ND Tensor. \n
* Must be one of the following types: int32
* @li updates: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8
* @par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterAdd.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)

/**
* @brief Applies sparse subtraction to individual values or slices in a Variable . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdSub.
*/
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)

/**
* @brief Uses "updates" to update tensor "var" by "indices". \n
* @par Inputs:
* Three inputs, including:
* @li var: A Tensor of type BasicType.
* @li indices: An ND Tensor of type int32 or int64.
* @li updates: A Tensor with the same dtype as "var" and the same shape as "indices". \n
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterNdMin.
*/
REG_OP(ScatterNdMin)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdMin)

/**
* @brief Applies sparse subtraction to individual values or slices in a Variable . \n
* @par Inputs:
* Three inputs, including:
* @li x: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8
* @li indices: An ND Tensor. \n
* Must be one of the following types: int32
* @li updates: An ND Tensor. \n
* Must be one of the following types: float16, float32, int32, int8, uint8
* @par Outputs:
* y: A Tensor. Has the same type and format as input "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator TensorScatterSub.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)

/**
* @brief Subtracts sparse updates from a variable reference . \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterSub.
*/
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)

/**
* @brief Returns the batched diagonal part of a batched tensor with "assist" . \n
* @par Inputs:
* Two inputs, including:
* @li x: A Tensor of type float16, float32, or int32.
* @li assist: A Tensor of the same type as "x" . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
*/
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
* @brief Returns the batched diagonal part of a batched tensor . \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, int32, int64, double, complex64, complex128 . \n
* @par Outputs:
* y: A Tensor. Has the same type as "x" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator DiagPart.
*/
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
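
// Illustrative reference sketch of DiagPart in the rank-2 case: for a square
// [n, n] input it returns the main diagonal of length n; inputs of even rank
// 2k generalize this by pairing the first and last k dimensions
// ("DiagPartReference" is a hypothetical helper name).
inline void DiagPartReference(const float *x, float *diag, int n) {
  for (int i = 0; i < n; ++i) {
    diag[i] = x[i * n + i];
  }
}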
  969. /**
  970. * @brief Also known as a "fully-connected" layer, computes an inner product
  971. * with a set of learned weights, and (optionally) adds biases. \n
  972. * @par Inputs:
  973. * Four inputs, including:
  974. * @li x: A Tensor of type float16, int8, int4, float32.
  975. * @li w: A weight matrix of type float16, int8, int4, float32.
  976. * @li b: An optional Tensor of type float16, int32, float32.
  977. * @li offset_w: An optional Tensor of type int8, int4.
  978. * Reserved. Only None Supported. \n
  979. * @par Attributes:
  980. * @li num_output: Required. An int, output neuron number. Reserved.
  981. * @li transpose: A bool, specifying weight whether to transpose input w,
  982. * either "true" or "false". Defaults to "false".
  983. * @li axis: Optional. An int, 1 or 2, specifying which dimension the input
  984. * "K" starts from. Defaults to 1.
  985. * The product of the subsequent dimensions starting form first dimension
  986. * or the second dimension is "K".
  987. * @li offset_x: An optional integer for quantized FullyConnection.
  988. * The negative offset added to the input image for int8 type. Ensure offset_x
  989. * within the effective range of int8 [-128, 127]. Defaults to "0". \n
  990. * @par Outputs:
  991. * y: The result tensor of type float16, int32, float32. \n
  992. * @par Third-party framework compatibility
  993. * Compatible with the Caffe operator InnerProduct. \n
  994. * @par Quantization supported or not
  995. * Yes
  996. */
  997. REG_OP(FullyConnection)
  998. .INPUT(x, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
  999. .INPUT(w, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT32, DT_BF16}))
  1000. .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32, DT_BF16}))
  1001. .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
  1002. .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32, DT_BF16}))
  1003. .REQUIRED_ATTR(num_output, Int)
  1004. .ATTR(transpose, Bool, false)
  1005. .ATTR(axis, Int, 1)
  1006. .ATTR(offset_x, Int, 0)
  1007. .OP_END_FACTORY_REG(FullyConnection)
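
/*
 * Illustration only (not part of the op registry): a hedged NumPy sketch of
 * the float path of FullyConnection, assuming transpose=false with w stored
 * as [num_output, K] and axis=1, so K is the product of all dimensions of x
 * after the first:
 *
 *   import numpy as np
 *   x = np.random.rand(8, 4, 5).astype(np.float32)  # batch 8, K = 4 * 5 = 20
 *   w = np.random.rand(16, 20).astype(np.float32)   # num_output = 16
 *   b = np.zeros(16, dtype=np.float32)
 *   y = x.reshape(8, -1) @ w.T + b                  # y: [8, 16]
 */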

/**
* @brief Also known as a "fully-connected-compress" layer, computes an inner
* product with a set of learned weights, and (optionally) adds biases. \n
* @par Inputs:
* Five inputs, including:
* @li x: A Tensor of type uint8, int8.
* @li w: A weight matrix of type int8.
* @li compress_index: A compress index matrix of type int8.
* @li b: An optional Tensor of type int32.
* @li offset_w: An optional Tensor of type int8.
* @par Attributes:
* @li num_output: An int, specifying the number of outputs.
* @li transpose: A bool, specifying whether to transpose the input w, either "true"
* or "false". Defaults to "false".
* @li axis: Optional. An int, 1 or 2, specifying which dimension the input "K"
* starts from. Defaults to "1".
* The product of the subsequent dimensions starting from the first dimension or the
* second dimension is "K".
* @li offset_x: An optional integer for quantized FullyConnectionCompress.
* The negative offset added to the input image for int8 type. Ensure that
* offset_x is within the effective int8 range [-128, 127]. Defaults to "0". \n
* @par Outputs:
* y: The result tensor of type int32. \n
* @par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct. \n
* @par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)

/**
* @brief Computes the confusion matrix from predictions and labels. \n
* @par Inputs:
* Three inputs, including:
* @li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
* @li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
* @li weights: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8. \n
* @par Attributes:
* @li num_classes: An integer for the shape of the output matrix.
* No default value.
* @li dtype: Data type of the confusion matrix. No default value. \n
* @par Outputs:
* y: A Tensor. Has the same type and format as input "labels".
* @attention Constraints:
* @li "weights", "labels", and "predictions" are 1D tensors.
* @li The output is with shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096. \n
* @see Region()
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)
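
/*
 * Illustration only (not part of the op registry): a NumPy sketch of the
 * ConfusionMatrix semantics, assuming integer-valued 1-D labels/predictions
 * and num_classes = 3. The row index is the true label, the column index the
 * prediction; "weights" (all ones here) is each pair's contribution.
 *
 *   import numpy as np
 *   labels      = np.array([0, 1, 2, 1], dtype=np.int32)
 *   predictions = np.array([0, 2, 2, 1], dtype=np.int32)
 *   weights     = np.ones(4, dtype=np.int32)
 *   y = np.zeros((3, 3), dtype=np.int32)
 *   np.add.at(y, (labels, predictions), weights)
 *   # diagonal entries count correct predictions; y[1, 2] == 1 records the
 *   # sample with label 1 predicted as 2
 */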

/**
* @brief Multiplies sparse updates into a variable reference. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)

/**
* @brief Reduces sparse updates into a variable reference using
* the "min" operation. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
* @brief Reduces sparse updates into a variable reference using the "max" operation. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An NCHW, NHWC, or ND Tensor. \n
* Must be one of the following types: int32 or int64
* @li updates: An NCHW, NHWC, or ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)

/**
* @brief Applies sparse updates to a variable reference. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor. \n
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)
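
/*
 * Illustration only (not part of the op registry): unlike the reducing
 * Scatter* ops above, ScatterUpdate replaces the indexed elements. A NumPy
 * sketch with 1-D indices:
 *
 *   import numpy as np
 *   var = np.array([10., 20., 30.], dtype=np.float32)
 *   var[np.array([2, 0])] = np.array([7., 8.], dtype=np.float32)
 *   # var -> [8., 20., 7.]
 */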

/**
* @brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r` tensor where `r >= 2`. \n
* @li k: \n
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li padding_value: The value to fill the area outside the specified diagonal band with. \n
* @par Outputs:
* diagonal: The extracted diagonal(s). \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)
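
/*
 * Illustration only (not part of the op registry): for a scalar k,
 * MatrixDiagPartV2 reduces to extracting a single diagonal, as in this NumPy
 * sketch (k = 1 selects the first superdiagonal; a [k0, k1] pair would stack
 * the selected band instead, padded with padding_value):
 *
 *   import numpy as np
 *   m = np.arange(9, dtype=np.float32).reshape(3, 3)
 *   diagonal = np.diagonal(m, offset=1)   # -> [1., 5.]
 */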

/**
* @brief Returns a batched matrix tensor with new batched diagonal values. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r+1` tensor, where `r >= 1`. \n
* @li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @par Outputs:
* output: Rank `r+1`, with `output.shape = input.shape`. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)

/**
* @brief Returns a batched matrix tensor with new batched diagonal values. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r+1` tensor, where `r >= 1`. \n
* @li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", \n
* "RIGHT_RIGHT". Defaults to "RIGHT_LEFT". Specifies how superdiagonals and \n
* subdiagonals should be aligned, respectively. \n
* @par Outputs:
* output: Rank `r+1`, with `output.shape = input.shape`. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV3.
*/
REG_OP(MatrixSetDiagV3)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixSetDiagV3)

/**
* @brief Returns a batched diagonal tensor with given batched diagonal values. \n
* @par Inputs:
* Five inputs, including:
* @li diagonal: Rank `r`, where `r >= 1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li num_rows:
* The number of rows of the output matrix. If it is not provided, the op assumes \n
* the output matrix is a square matrix and infers the matrix size from k and the \n
* innermost dimension of `diagonal`. \n
* @li num_cols:
* The number of columns of the output matrix. If it is not provided, the op \n
* assumes the output matrix is a square matrix and infers the matrix size from \n
* k and the innermost dimension of `diagonal`. \n
* @li padding_value: The number to fill the area outside the specified diagonal band with. \n
* @par Outputs:
* output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)
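
/*
 * Illustration only (not part of the op registry): with a scalar k and
 * explicit num_rows/num_cols, MatrixDiagV2 places "diagonal" on offset k and
 * fills the rest with padding_value. A NumPy sketch for k = -1, a 3x3 output
 * and padding_value = 0:
 *
 *   import numpy as np
 *   output = np.diag(np.array([1., 2.], dtype=np.float32), k=-1)
 *   # -> [[0., 0., 0.],
 *   #     [1., 0., 0.],
 *   #     [0., 2., 0.]]
 */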

/**
* @brief Adds updates to var_out according to axis and indices.
* @par Inputs:
* Three inputs, including:
* @li var: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8.
* @li indices: A Tensor of the indices, type should be int32.
* @li updates: A Tensor of the same type as "var".
* @par Attributes:
* @li axis: An optional int specifying the axis to perform the index add on. Defaults to 0.
* @par Outputs:
* @li var_out: A Tensor. Same as input "var".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_add.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexAdd)
    .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT32, DT_FLOAT16}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(IndexAdd)
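
/*
 * Illustration only (not part of the op registry): a NumPy sketch of IndexAdd
 * along axis = 0, the analogue of PyTorch's index_add. The rows of "updates"
 * selected by "indices" are added into "var":
 *
 *   import numpy as np
 *   var = np.zeros((3, 2), dtype=np.float32)
 *   indices = np.array([2, 0], dtype=np.int32)
 *   updates = np.array([[1., 1.], [5., 5.]], dtype=np.float32)
 *   np.add.at(var, indices, updates)
 *   # var -> [[5., 5.], [0., 0.], [1., 1.]]
 */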

/**
* @brief Replaces the values of "x1" at the positions specified by the
* "indices" attribute with the values in "x2".
* @par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @li x2: A Tensor of the same type as "x1".
* @par Attributes:
* @li indices: A required list of int, the indices at which to update "x1".
* @li accumulate: An optional int. If 1, the values are accumulated into "x1"
* instead of overwriting it. Defaults to 0.
* @par Outputs:
* @li y: A Tensor. Same as input "x1".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_put.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexPut)
    .INPUT(x1, TensorType::BasicType())
    .INPUT(x2, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .REQUIRED_ATTR(indices, ListInt)
    .ATTR(accumulate, Int, 0)
    .OP_END_FACTORY_REG(IndexPut)
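
/*
 * Illustration only (not part of the op registry): a NumPy sketch of the
 * IndexPut semantics for a 1-D "x1", assuming "indices" lists element
 * positions. With accumulate = 0 the indexed elements are overwritten by x2;
 * with accumulate = 1 they are added:
 *
 *   import numpy as np
 *   x1 = np.array([10., 20., 30.], dtype=np.float32)
 *   x2 = np.array([1., 2.], dtype=np.float32)
 *   idx = np.array([0, 2])
 *   x1[idx] = x2            # accumulate = 0: x1 -> [1., 20., 2.]
 *   np.add.at(x1, idx, x2)  # accumulate = 1: x1 -> [2., 20., 4.]
 */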

/**
* @brief: Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices "x". \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @par Attributes:
* diagonal: An optional int indicating the diagonal to consider. Defaults to 0. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Triu.
*/
REG_OP(Triu)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Triu)

/**
* @brief: Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices "x". \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @par Attributes:
* diagonal: An optional int indicating the diagonal to consider. Defaults to 0. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Tril.
*/
REG_OP(Tril)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Tril)
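
/*
 * Illustration only (not part of the op registry): NumPy equivalents of the
 * two ops above. The "diagonal" attribute shifts the boundary: positive
 * values move it above the main diagonal, negative values below.
 *
 *   import numpy as np
 *   m = np.ones((3, 3), dtype=np.float32)
 *   np.triu(m, k=1)   # strictly upper triangular part
 *   np.tril(m, k=0)   # lower triangular part including the main diagonal
 */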

/**
* @brief Sums the product of the elements of the input operands along dimensions
* specified using a notation based on the Einstein summation convention. \n
* @par Inputs:
* @li x: A list of input tensors. Must be one of the following types: int32,
* float16, float32. It's a dynamic input. \n
* @par Attributes:
* @li equation: The subscripts for the Einstein summation. \n
* @li N: An int, the number of input tensors. \n
* @par Outputs:
* @li y: The result tensor of the Einstein summation. \n
* @attention Constraints:
* "N" must match the number of tensors in "x". \n
* @par Third-party framework compatibility
* Compatible with Pytorch einsum operator.
*/
REG_OP(Einsum)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .REQUIRED_ATTR(equation, String)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(Einsum)
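
/*
 * Illustration only (not part of the op registry): the "equation" attribute
 * follows the Einstein summation convention, as in NumPy. For N = 2 inputs,
 * a batched matrix product:
 *
 *   import numpy as np
 *   a = np.random.rand(4, 2, 3).astype(np.float32)
 *   b = np.random.rand(4, 3, 5).astype(np.float32)
 *   y = np.einsum("bij,bjk->bik", a, b)   # y: [4, 2, 5]
 */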

/**
* @brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n
* @par Inputs:
* No inputs
* @par Attributes:
* @li num_rows: A required int, the number of rows. \n
* @li num_columns: An optional int. Defaults to 0. \n
* @li batch_shape: An optional ListInt. Defaults to []. \n
* @li dtype: An optional int. Defaults to 0. \n
* @par Outputs:
* y: A Tensor with the targeted type and shape. \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Eye. \n
*/
REG_OP(Eye)
    .OUTPUT(y, TensorType::BasicType())  /* "Result, has targeted element type" */
    .REQUIRED_ATTR(num_rows, Int)
    .ATTR(num_columns, Int, 0)
    .ATTR(batch_shape, ListInt, {})
    .ATTR(dtype, Int, 0)
    .OP_END_FACTORY_REG(Eye)
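
/*
 * Illustration only (not part of the op registry): a hedged NumPy sketch of
 * Eye, assuming num_rows = 2, num_columns = 3 and batch_shape = [4], where
 * the identity pattern is repeated over the batch dimensions:
 *
 *   import numpy as np
 *   eye = np.eye(2, 3, dtype=np.float32)        # [2, 3]
 *   y = np.broadcast_to(eye, (4, 2, 3)).copy()  # [4, 2, 3]
 */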

/**
* @brief: Fills the diagonal of a tensor that has at least 2 dimensions with "fill_value". \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float32, int32, int64. \n
* @par Attributes:
* @li fill_value: The value to fill with.
* @li wrap: An optional bool. Defaults to "False". If "True", use recursive
* (wrapped) fill for tall matrices. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator FillDiagonal.
*/
REG_OP(FillDiagonal)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .REQUIRED_ATTR(fill_value, Float)
    .ATTR(wrap, Bool, false)
    .OP_END_FACTORY_REG(FillDiagonal)
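
/*
 * Illustration only (not part of the op registry): NumPy's fill_diagonal
 * mirrors this op, including the wrapped fill for tall matrices:
 *
 *   import numpy as np
 *   m = np.zeros((5, 3), dtype=np.float32)
 *   np.fill_diagonal(m, 7.0, wrap=True)
 *   # fills (0,0), (1,1), (2,2), then the fill resumes at (4, 0)
 */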

/**
* @brief: Returns the sum of the elements of the diagonal of the input 2-D matrix. \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Trace.
*/
REG_OP(Trace)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(Trace)
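
/*
 * Illustration only (not part of the op registry): the NumPy one-liner for
 * the op above:
 *
 *   import numpy as np
 *   m = np.arange(4, dtype=np.float32).reshape(2, 2)
 *   y = np.trace(m)   # 0 + 3 = 3.0
 */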

/**
* @brief Computes the generalized inverse of any matrix. \n
* @par Inputs:
* @li x: The input matrix. Must be one of the following types:
* double, float. \n
* @par Attributes:
* @li rcond: An optional float >= 0 or inf. Defaults to 1e-15. \n
* @par Outputs:
* y: A Tensor with the same type as "x" and the shape of the transpose of "x". \n
*/
REG_OP(Pinverse)
    .INPUT(x, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .ATTR(rcond, Float, 1e-15)
    .OP_END_FACTORY_REG(Pinverse)
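
/*
 * Illustration only (not part of the op registry): NumPy's pinv exposes the
 * same rcond cutoff; singular values below rcond times the largest singular
 * value are treated as zero. Note the output shape is that of x transposed:
 *
 *   import numpy as np
 *   x = np.random.rand(4, 3).astype(np.float32)
 *   y = np.linalg.pinv(x, rcond=1e-15)   # y: [3, 4]
 */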

/**
* @brief From the input tensor and the updates tensor, selects the maximum value according to indices and writes it to the output. \n
* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n
* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMax)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMax)

/**
* @brief From the input tensor and the updates tensor, selects the minimum value according to indices and writes it to the output. \n
* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n
* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMin)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMin)
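
/*
 * Illustration only (not part of the op registry): a NumPy sketch of the two
 * ops above for 1-D data. Unlike the Scatter* ops, the input is not modified
 * in place; a fresh output tensor is produced:
 *
 *   import numpy as np
 *   inp = np.array([1., 9., 3.], dtype=np.float32)
 *   idx = np.array([0, 2])
 *   upd = np.array([5., 2.], dtype=np.float32)
 *   out = inp.copy()
 *   np.maximum.at(out, idx, upd)   # TensorScatterMax: out  -> [5., 9., 3.]
 *   out2 = inp.copy()
 *   np.minimum.at(out2, idx, upd)  # TensorScatterMin: out2 -> [1., 9., 2.]
 */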

/**
* @brief: Returns the batched diagonal part of a batched tensor. \n
* @par Inputs:
* @li x: A Tensor. Rank r tensor where r >= 2.
* @li k: A Tensor of type int32. Diagonal offset(s). Positive value means superdiagonal,
* 0 refers to the main diagonal, and negative value means subdiagonals. k can be a
* single integer (for a single diagonal) or a pair of integers specifying the low and
* high ends of a matrix band. k[0] must not be larger than k[1].
* @li padding_value: A Tensor. Must have the same type as input. The value to fill the area
* outside the specified diagonal band with. Default is 0. \n
* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT". Defaults to "RIGHT_LEFT".
* @par Outputs:
* @li y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV3.
*/
REG_OP(MatrixDiagPartV3)
    .INPUT(x, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixDiagPartV3)

/**
* @brief Returns a batched diagonal tensor with given batched diagonal values. \n
* @par Inputs:
* Five inputs, including:
* @li x: Rank `r`, where `r >= 1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
* diagonal, and negative value means subdiagonals. `k` can be a single integer
* (for a single diagonal) or a pair of integers specifying the low and high ends
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li num_rows:
* The number of rows of the output matrix. If it is not provided, the op assumes
* the output matrix is a square matrix and infers the matrix size from k and the
* innermost dimension of `x`. \n
* @li num_cols:
* The number of columns of the output matrix. If it is not provided, the op
* assumes the output matrix is a square matrix and infers the matrix size from
* k and the innermost dimension of `x`. \n
* @li padding_value: The number to fill the area outside the specified diagonal band with. \n
* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT".
* Defaults to "RIGHT_LEFT". \n
* @par Outputs:
* @li y: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV3.
*/
REG_OP(MatrixDiagV3)
    .INPUT(x, TensorType({BasicType(), DT_BOOL}))
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType({BasicType(), DT_BOOL}))
    .OUTPUT(y, TensorType({BasicType(), DT_BOOL}))
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixDiagV3)

/**
* @brief Function SwinAttentionScore. \n
* @par Inputs:
* Seven inputs, including:
* @li query: A matrix Tensor. The type only supports float16.
* @li key: A matrix Tensor. The type only supports float16.
* @li value: A matrix Tensor. The type only supports float16.
* @li padding_mask1: A matrix Tensor. The type only supports float16.
* @li padding_mask2: An optional matrix Tensor. The type only supports float16.
* @li scale: A scalar. The type only supports float16.
* @li drop_mask: An optional matrix Tensor. The type only supports uint8. \n
* @par Attributes:
* @li keep_prob: An optional float, the keep probability for dropout. Defaults to 1.0.
* @li query_transpose: A bool. If True, changes the shape of "query" from [K, M] to
* [M, K].
* @li key_transpose: A bool. If True, changes the shape of "key" from [N, K] to
* [K, N].
* @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data" from [K, M] to
* [M, K].
* @li bmm_score_transpose_b: A bool. If True, changes the shape of "value" from [N, K] to
* [K, N].
* @li softmax_axes: A list of int. The dimensions softmax would be performed on. Defaults
* to "[]". \n
* @par Outputs:
* attention_score: The result matrix Tensor. The type only supports float16.
* softmax: The result matrix Tensor. The type only supports float16.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(SwinAttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask1, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(padding_mask2, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {})
    .OP_END_FACTORY_REG(SwinAttentionScore)
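
/*
 * Illustration only (not part of the op registry): a hedged NumPy sketch of
 * the kind of computation this op fuses, assuming 2-D query/key/value, a
 * single additive mask, and no dropout. The function name and signature are
 * illustrative, not the device implementation:
 *
 *   import numpy as np
 *   def swin_attention_score(q, k, v, mask, scale):
 *       logits = q @ k.T * scale + mask               # scaled scores + mask
 *       logits -= logits.max(axis=-1, keepdims=True)  # numerically stable
 *       softmax = np.exp(logits)
 *       softmax /= softmax.sum(axis=-1, keepdims=True)
 *       return softmax @ v, softmax                   # attention_score, softmax
 *
 *   q = np.random.rand(4, 8).astype(np.float32)
 *   out, sm = swin_attention_score(q, q, q, 0.0, 1.0 / np.sqrt(8))
 */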

/**
* @brief A structure specific to the swin_transformer model. This operator only
* supports swin_transformer. \n
* @par Inputs:
* Three inputs, including:
* @li x1: A Tensor. Must be one of the following types: float16.
* @li x2: A Tensor. Must be one of the following types: float16.
* @li bias: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li shifts: An optional attribute, a list of int. Defaults to (). \n
* @par Outputs:
* One output, including:
* @li y: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(SwinAttentionFFN)
    .INPUT(x1, TensorType({DT_FLOAT16}))
    .INPUT(x2, TensorType({DT_FLOAT16}))
    .INPUT(bias, TensorType({DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT16}))
    .ATTR(shifts, ListInt, {})
    .OP_END_FACTORY_REG(SwinAttentionFFN)
}  // namespace ge
#endif  // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

The Graph Engine (GE) is a submodule of MindSpore implemented in C++. It sits between the front-end module ME and the underlying hardware, serving as the link between the two: it takes the graph delivered by ME as input, performs a series of deep graph-optimization passes, and finally outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.