/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file matrix_calculation_ops.h
 * \brief
 */
#ifndef OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_

#include "graph/operator_reg.h"

namespace ge {
/**
 * @brief Backprop W of AttentionLnQKV + ReduceSumD \n
 * @par Inputs:
 * Four inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16.
 * @li query_dx: A Tensor. Must be one of the following types: float16.
 * @li key_dw: A Tensor. Must be one of the following types: float16.
 * @li value_dw: A Tensor. Must be one of the following types: float16.
 * @par Attributes:
 * @li trans_a: An optional attribute, the type is bool. Defaults to True.
 * @li trans_b: An optional attribute, the type is bool. Defaults to False. \n
 * @par Outputs:
 * Six outputs, including:
 * @li dw_query: A Tensor. Must be one of the following types: float16.
 * @li dw_key: A Tensor. Must be one of the following types: float16.
 * @li dw_value: A Tensor. Must be one of the following types: float16.
 * @li dbias_query: A Tensor. Must be one of the following types: float16.
 * @li dbias_key: A Tensor. Must be one of the following types: float16.
 * @li dbias_value: A Tensor. Must be one of the following types: float16. \n
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
 */
REG_OP(AttentionQKVGradW)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(query_dx, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(key_dw, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(value_dw, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_query, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_key, TensorType({DT_FLOAT16}))
    .OUTPUT(dw_value, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_query, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_key, TensorType({DT_FLOAT16}))
    .OUTPUT(dbias_value, TensorType({DT_FLOAT16}))
    .ATTR(trans_a, Bool, true)
    .ATTR(trans_b, Bool, false)
    .OP_END_FACTORY_REG(AttentionQKVGradW)

/**
 * @brief Backprop X of AttentionLnQKV + AddN \n
 * @par Inputs:
 * Seven inputs, including:
 * @li ln_dx: A Tensor. Must be one of the following types: float16.
 * @li query_dx: A Tensor. Must be one of the following types: float16.
 * @li key_dw: A Tensor. Must be one of the following types: float16.
 * @li value_dw: A Tensor. Must be one of the following types: float16.
 * @li kernel_query: A Tensor. Must be one of the following types: float16.
 * @li kernel_key: A Tensor. Must be one of the following types: float16.
 * @li kernel_value: A Tensor. Must be one of the following types: float16. \n
 * @par Attributes:
 * @li trans_a: An optional attribute, the type is bool. Defaults to False.
 * @li trans_b: An optional attribute, the type is bool. Defaults to True. \n
 * @par Outputs:
 * One output:
 * @li dx: A Tensor. Must be one of the following types: float16. \n
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
 */
REG_OP(AttentionQKVGradX)
    .INPUT(ln_dx, TensorType({DT_FLOAT16}))
    .INPUT(query_dx, TensorType({DT_FLOAT16}))
    .INPUT(key_dw, TensorType({DT_FLOAT16}))
    .INPUT(value_dw, TensorType({DT_FLOAT16}))
    .INPUT(kernel_query, TensorType({DT_FLOAT16}))
    .INPUT(kernel_key, TensorType({DT_FLOAT16}))
    .INPUT(kernel_value, TensorType({DT_FLOAT16}))
    .OUTPUT(dx, TensorType({DT_FLOAT16}))
    .ATTR(trans_a, Bool, false)
    .ATTR(trans_b, Bool, true)
    .OP_END_FACTORY_REG(AttentionQKVGradX)

/**
 * @brief
 *           / (MatMul -> ConfusionTransposeD).
 * LayerNorm - (MatMul -> ConfusionTransposeD).
 *           \ (MatMul -> ConfusionTransposeD). \n
 * @par Inputs:
 * Nine inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16.
 * @li kernel_query: A Tensor. Must be one of the following types: float16.
 * @li kernel_key: A Tensor. Must be one of the following types: float16.
 * @li kernel_value: A Tensor. Must be one of the following types: float16.
 * @li gamma: A Tensor. Must be one of the following types: float16.
 * @li beta: A Tensor. Must be one of the following types: float16.
 * @li bias_query: A Tensor. Must be one of the following types: float16.
 * @li bias_key: A Tensor. Must be one of the following types: float16.
 * @li bias_value: A Tensor. Must be one of the following types: float16. \n
 * @par Attributes:
 * @li epsilon: An optional attribute, the type is float32. Defaults to 1e-7.
 * @li trans_a: An optional attribute, the type is bool. Defaults to False.
 * @li trans_b: An optional attribute, the type is bool. Defaults to False. \n
 * @par Outputs:
 * Six outputs, including:
 * @li norm: A Tensor. Must be one of the following types: float16.
 * @li query_output: A Tensor. Must be one of the following types: float16.
 * @li key_output: A Tensor. Must be one of the following types: float16.
 * @li value_output: A Tensor. Must be one of the following types: float16.
 * @li mean: A Tensor. Must be one of the following types: float16.
 * @li variance: A Tensor. Must be one of the following types: float16. \n
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
 */
REG_OP(AttentionLnQKV)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(kernel_query, TensorType({DT_FLOAT16}))
    .INPUT(kernel_key, TensorType({DT_FLOAT16}))
    .INPUT(kernel_value, TensorType({DT_FLOAT16}))
    .INPUT(gamma, TensorType({DT_FLOAT16}))
    .INPUT(beta, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_query, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_key, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(bias_value, TensorType({DT_FLOAT16}))
    .OUTPUT(norm, TensorType({DT_FLOAT16}))
    .OUTPUT(query_output, TensorType({DT_FLOAT16}))
    .OUTPUT(key_output, TensorType({DT_FLOAT16}))
    .OUTPUT(value_output, TensorType({DT_FLOAT16}))
    .OUTPUT(mean, TensorType({DT_FLOAT16}))
    .OUTPUT(variance, TensorType({DT_FLOAT16}))
    .ATTR(epsilon, Float, 0.0000001)
    .ATTR(trans_a, Bool, false)
    .ATTR(trans_b, Bool, false)
    .OP_END_FACTORY_REG(AttentionLnQKV)

/**
 * @brief
 * A structure specific to the swin_transformer model. The operator only
 * supports swin_transformer. \n
 * @par Inputs:
 * Five inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16.
 * @li gamma: A Tensor. Must be one of the following types: float16.
 * @li beta: A Tensor. Must be one of the following types: float16.
 * @li weight: A Tensor. Must be one of the following types: float16.
 * @li bias: A Tensor. Must be one of the following types: float16. \n
 * @par Attributes:
 * @li head_num: A required attribute, the type is int.
 * @li head_dim: A required attribute, the type is int.
 * @li seq_length: A required attribute, the type is int.
 * @li shifts: An optional attribute, the type is list int. Defaults to ().
 * @li epsilon: An optional attribute, the type is float. Defaults to 1e-7. \n
 * @par Outputs:
 * Three outputs, including:
 * @li query_output: A Tensor. Must be one of the following types: float16.
 * @li key_output: A Tensor. Must be one of the following types: float16.
 * @li value_output: A Tensor. Must be one of the following types: float16. \n
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
 */
REG_OP(SwinTransformerLnQKV)
    .INPUT(x, TensorType({DT_FLOAT16}))
    .INPUT(gamma, TensorType({DT_FLOAT16}))
    .INPUT(beta, TensorType({DT_FLOAT16}))
    .INPUT(weight, TensorType({DT_FLOAT16}))
    .INPUT(bias, TensorType({DT_FLOAT16}))
    .OUTPUT(query_output, TensorType({DT_FLOAT16}))
    .OUTPUT(key_output, TensorType({DT_FLOAT16}))
    .OUTPUT(value_output, TensorType({DT_FLOAT16}))
    .REQUIRED_ATTR(head_num, Int)
    .REQUIRED_ATTR(head_dim, Int)
    .REQUIRED_ATTR(seq_length, Int)
    .ATTR(shifts, ListInt, {})
    .ATTR(epsilon, Float, 0.0000001)
    .OP_END_FACTORY_REG(SwinTransformerLnQKV)

/**
 * @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
 * @par Inputs:
 * Three inputs, including:
 * @li x1: A matrix Tensor. 2D. Must be one of the following types: float16,
 * float32, int32. Has format [ND, NHWC].
 * @li x2: A matrix Tensor. 2D. Must be one of the following types: float16,
 * float32, int32. Has format [ND, NHWC].
 * @li bias: An optional 1D Tensor. Must be one of the following types: float16,
 * float32, int32. Has format [ND, NHWC]. \n
 * @par Attributes:
 * @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
 * [M, K] before multiplication.
 * @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
 * [K, N] before multiplication. \n
 * @par Outputs:
 * y: The result matrix Tensor. 2D. Must be one of the following types: float16,
 * float32, int32. Has format [ND, NHWC]. \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatMul.
 */
REG_OP(MatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .OP_END_FACTORY_REG(MatMul)
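
/**
 * Editorial sketch (not part of the original registry): a naive host-side
 * reference for the MatMul semantics documented above. The helper name,
 * row-major layout, and float element type are illustrative assumptions.
 * @code
 * #include <cstddef>
 * #include <vector>
 *
 * // y = op(x1) * op(x2) + bias: x1 is stored as [K, M] when transpose_x1 is
 * // true ([M, K] otherwise), x2 as [N, K] when transpose_x2 is true
 * // ([K, N] otherwise); bias, if non-empty, has length N.
 * std::vector<float> MatMulRef(const std::vector<float> &x1,
 *                              const std::vector<float> &x2,
 *                              const std::vector<float> &bias,
 *                              size_t m, size_t k, size_t n,
 *                              bool transpose_x1, bool transpose_x2) {
 *   std::vector<float> y(m * n, 0.0f);
 *   for (size_t i = 0; i < m; ++i) {
 *     for (size_t j = 0; j < n; ++j) {
 *       float acc = bias.empty() ? 0.0f : bias[j];
 *       for (size_t p = 0; p < k; ++p) {
 *         const float a = transpose_x1 ? x1[p * m + i] : x1[i * k + p];
 *         const float b = transpose_x2 ? x2[j * k + p] : x2[p * n + j];
 *         acc += a * b;
 *       }
 *       y[i * n + j] = acc;
 *     }
 *   }
 *   return y;
 * }
 *
 * // e.g. MatMulRef(a, b, {}, 2, 3, 4, false, false) multiplies 2x3 by 3x4.
 * @endcode
 */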

/**
 * @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
 * @par Inputs:
 * Four inputs, including:
 * @li x1: A matrix Tensor. 2D. Must be one of the following types: float32,
 * float16, int32, int8, int4. Has format [ND, NHWC].
 * @li x2: A matrix Tensor. 2D. Must be one of the following types: float32,
 * float16, int32, int8, int4. Has format [ND, NHWC].
 * @li bias: A 1D Tensor. Must be one of the following types: float32,
 * float16, int32. Has format [ND, NHWC].
 * @li offset_w: An optional 1D Tensor for quantized inference. Type is int8
 * or int4. Reserved. \n
 * @par Attributes:
 * @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
 * [M, K] before multiplication.
 * @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
 * [K, N] before multiplication.
 * @li offset_x: An optional integer for quantized MatMulV2.
 * The negative offset added to the input x1 for int8 type. Ensure offset_x
 * within the effective range of int8 [-128, 127]. Defaults to "0". \n
 * @par Outputs:
 * y: The result matrix Tensor. 2D. Must be one of the following types: float32,
 * float16, int32. Has format [ND, NHWC]. \n
 * @attention Constraints:
 * If performance is better in format NZ, disable
 * "MatmulTransdataFusionPass" in the fusion configuration. \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatMul.
 */
REG_OP(MatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2)
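
/**
 * Editorial sketch of the documented offset_x behaviour on the int8 path:
 * each x1 element is shifted by offset_x before the multiply. The int32
 * accumulator width and the exact point where the offset is applied are
 * assumptions based only on the attribute description above.
 * @code
 * #include <cstddef>
 * #include <cstdint>
 * #include <vector>
 *
 * // x1 is [M, K] int8, x2 is [K, N] int8, both row-major, no transposes.
 * std::vector<int32_t> QuantMatMulRef(const std::vector<int8_t> &x1,
 *                                     const std::vector<int8_t> &x2,
 *                                     size_t m, size_t k, size_t n,
 *                                     int32_t offset_x) {
 *   std::vector<int32_t> y(m * n, 0);
 *   for (size_t i = 0; i < m; ++i)
 *     for (size_t j = 0; j < n; ++j)
 *       for (size_t p = 0; p < k; ++p)
 *         y[i * n + j] += (static_cast<int32_t>(x1[i * k + p]) + offset_x) *
 *                         static_cast<int32_t>(x2[p * n + j]);
 *   return y;
 * }
 * @endcode
 */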

/**
 * @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
 * @par Inputs:
 * Five inputs, including:
 * @li x1: A matrix Tensor. 2D. Must be one of the following types: int8.
 * @li x2: A matrix Tensor. 2D. Must be one of the following types: int8.
 * @li compress_index: A compress index matrix of type int8.
 * @li bias: An optional Tensor. 1D. Must be one of the following types: int32,
 * float16.
 * @li offset_w: An optional matrix Tensor. 2D. Must be one of the following
 * types: int8. \n
 * @par Attributes:
 * @li transpose_x1: A bool. If True, changes the shape of "x1" from [K, M] to
 * [M, K] before multiplication.
 * @li transpose_x2: A bool. If True, changes the shape of "x2" from [N, K] to
 * [K, N] before multiplication.
 * @li offset_x: An optional integer for quantized MatMulV2Compress.
 * The negative offset added to the input x1 for int8 type. Ensure offset_x
 * within the effective range of int8 [-128, 127]. Defaults to "0". \n
 * @par Outputs:
 * y: The result matrix Tensor. 2D. Must be one of the following types: int32,
 * float16. \n
 * @attention Constraints:
 * If performance is better in format NZ, disable
 * "MatmulTransdataFusionPass" in the fusion configuration.
 */
REG_OP(MatMulV2Compress)
    .INPUT(x1, TensorType({DT_INT8}))
    .INPUT(x2, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(bias, TensorType({DT_INT32, DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_INT32, DT_FLOAT16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .ATTR(transpose_x1, Bool, false)
    .ATTR(transpose_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(MatMulV2Compress)

/**
 * @brief Performs matrix-to-matrix multiply,
 * producing y = alpha[0] * a * b + beta[0] * c. \n
 * @attention Constraints:
 * For better performance, the k-axis must be aligned to 16 (input type
 * is float16) or 32 (input type is int8). \n
 * @par Inputs:
 * Five inputs, including:
 * @li a: A matrix Tensor. Must be one of the following types: float32, float16,
 * int8, int32. Has format ND.
 * @li b: A matrix Tensor. Must be one of the following types: float32, float16,
 * int8, int32. Has format ND.
 * @li c: A matrix Tensor. Must be one of the following types: float32, float16,
 * int8, int32. Has format ND.
 * @li alpha: A 1D Tensor. The shape of alpha is [1]. Must be one of the
 * following types: float32, float16, int8, int32. Has format ND.
 * @li beta: A 1D Tensor. The shape of beta is [1]. Must be one of the following
 * types: float32, float16, int8, int32. Has format ND. \n
 * @par Attributes:
 * Two attributes, including:
 * @li transpose_a: Optional. A bool. If True, changes the shape of "a" from
 * [K, M] to [M, K] before multiplication.
 * @li transpose_b: Optional. A bool. If True, changes the shape of "b" from
 * [N, K] to [K, N] before multiplication. \n
 * @par Outputs:
 * y: The result matrix Tensor. Must be one of the following types: float32,
 * float16, int8, int32. Has format ND; the format must match that of "a".
 */
REG_OP(GEMM)
    .INPUT(a, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(c, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(alpha, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .INPUT(beta, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32}))
    .ATTR(transpose_a, Bool, false)
    .ATTR(transpose_b, Bool, false)
    .OP_END_FACTORY_REG(GEMM)
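
/**
 * Editorial sketch of the GEMM formula above, y = alpha * a * b + beta * c,
 * for row-major float matrices with no transposes. The scalar alpha/beta
 * parameters stand in for the documented single-element tensors; names and
 * layout are illustrative assumptions.
 * @code
 * #include <cstddef>
 * #include <vector>
 *
 * // a is [M, K], b is [K, N], c is [M, N]; returns [M, N].
 * std::vector<float> GemmRef(const std::vector<float> &a,
 *                            const std::vector<float> &b,
 *                            const std::vector<float> &c,
 *                            float alpha, float beta,
 *                            size_t m, size_t k, size_t n) {
 *   std::vector<float> y(m * n);
 *   for (size_t i = 0; i < m; ++i) {
 *     for (size_t j = 0; j < n; ++j) {
 *       float acc = 0.0f;
 *       for (size_t p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
 *       y[i * n + j] = alpha * acc + beta * c[i * n + j];
 *     }
 *   }
 *   return y;
 * }
 * @endcode
 */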

/**
 * @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
 * @par Inputs:
 * Two inputs, including:
 * @li x1: A matrix Tensor. Must be one of the following types: float16,
 * float32, int32. 2D or higher. Has format [ND, NHWC].
 * @li x2: A matrix Tensor. Must be one of the following types: float16,
 * float32, int32. 2D or higher. Has format [ND, NHWC]. \n
 * @par Attributes:
 * @li adj_x1: A bool. If True, changes the shape of "x1" from [B, K, M]
 * to [B, M, K] before multiplication.
 * @li adj_x2: A bool. If True, changes the shape of "x2" from [B, N, K]
 * to [B, K, N] before multiplication. \n
 * @par Outputs:
 * y: The result matrix Tensor. 2D or higher. Must be one of the following
 * types: float16, float32, int32. Has format [ND, NHWC]. Has the same shape
 * length as "x1" and "x2". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator BatchMatMul.
 */
REG_OP(BatchMatMul)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .OP_END_FACTORY_REG(BatchMatMul)

/**
 * @brief Multiplies matrix "a" by matrix "b", producing "a * b". \n
 * @par Inputs:
 * Four inputs, including:
 * @li x1: A matrix Tensor. Must be one of the following types: float16,
 * float32, int32, int8, int4. 2D or higher. Has format [ND, NHWC].
 * @li x2: A matrix Tensor. Must be one of the following types: float16,
 * float32, int32, int8, int4. 2D or higher. Has format [ND, NHWC].
 * @li bias: An optional Tensor. Must be one of the following types:
 * float16, float32, int32. Has format [ND, NHWC].
 * @li offset_w: An optional Tensor. Must be one of the following types:
 * int8, int4. Has format [ND, NHWC]. \n
 * @par Attributes:
 * @li adj_x1: A bool. If True, changes the shape of "x1" from [B, K, M] to
 * [B, M, K] before multiplication.
 * @li adj_x2: A bool. If True, changes the shape of "x2" from [B, N, K] to
 * [B, K, N] before multiplication.
 * @li offset_x: An optional integer for quantized BatchMatMulV2.
 * The negative offset added to the input x1 for int8 type. Ensure offset_x
 * within the effective range of int8 [-128, 127]. Defaults to "0". \n
 * @par Outputs:
 * y: The result matrix Tensor. 2D or higher. Must be one of the following
 * types: float16, float32, int32. Has format [ND, NHWC]. Has the same shape
 * length as "x1" and "x2". \n
 * @attention Constraints:
 * If performance is better in format NZ, disable
 * "MatmulTransdataFusionPass" in the fusion configuration. \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator BatchMatMul.
 */
REG_OP(BatchMatMulV2)
    .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_INT4, DT_BF16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_BF16}))
    .ATTR(adj_x1, Bool, false)
    .ATTR(adj_x2, Bool, false)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(BatchMatMulV2)

/**
 * @brief Computes half the L2 norm of a tensor without the sqrt. \n
 * @par Inputs:
 * x: A Tensor.
 * TensorType::FloatingDataType(). \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @attention Constraints:
 * If performance is better in format NZ, disable
 * "MatmulTransdataFusionPass" in the fusion configuration. \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator L2Loss.
 */
REG_OP(L2Loss)
    .INPUT(x, TensorType::FloatingDataType())
    .OUTPUT(y, TensorType::FloatingDataType())
    .OP_END_FACTORY_REG(L2Loss)
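
/**
 * Editorial sketch of the L2Loss formula above: half the sum of squares,
 * i.e. sum(x[i]^2) / 2, with no square root. The helper name is illustrative.
 * @code
 * #include <vector>
 *
 * float L2LossRef(const std::vector<float> &x) {
 *   float acc = 0.0f;
 *   for (float v : x) acc += v * v;  // squared L2 norm
 *   return acc / 2.0f;               // halved, sqrt deliberately omitted
 * }
 * @endcode
 */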

/**
 * @brief Returns a batched diagonal tensor with given batched diagonal values. \n
 * @par Inputs:
 * x: A Tensor. Must be one of the following types:
 * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
 * qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixDiag.
 */
REG_OP(MatrixDiag)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiag)
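
/**
 * Editorial sketch of the MatrixDiag semantics above: each batch of diagonal
 * values [B, N] expands to a batch of diagonal matrices [B, N, N] with zeros
 * off the diagonal. Names and the float element type are illustrative.
 * @code
 * #include <cstddef>
 * #include <vector>
 *
 * std::vector<float> MatrixDiagRef(const std::vector<float> &diag,
 *                                  size_t batch, size_t n) {
 *   std::vector<float> y(batch * n * n, 0.0f);  // zero-initialised output
 *   for (size_t b = 0; b < batch; ++b)
 *     for (size_t i = 0; i < n; ++i)
 *       y[(b * n + i) * n + i] = diag[b * n + i];  // place the diagonal
 *   return y;
 * }
 * @endcode
 */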

/**
 * @brief Returns a batched diagonal tensor with given batched diagonal values. \n
 * @par Inputs:
 * Two inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
 * @li assist: A Tensor of the same type as "x". \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixDiag.
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiag instead.
 */
REG_OP(MatrixDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagD)

/**
 * @brief Returns the batched diagonal part of a batched tensor. \n
 * @par Inputs:
 * x: A Tensor. Must be one of the following types:
 * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
 * qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixDiagPart.
 */
REG_OP(MatrixDiagPart)
    .INPUT(x, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPart)

/**
 * @brief Returns the batched diagonal part of a batched tensor. \n
 * @par Inputs:
 * Two inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
 * @li assist: A Tensor of the same type as "x". \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixDiagPart.
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixDiagPart instead.
 */
REG_OP(MatrixDiagPartD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartD)

/**
 * @brief Returns a batched matrix tensor with new batched diagonal values. \n
 * @par Inputs:
 * Two inputs, including:
 * @li x: A Tensor. Must be one of the following types:
 * float16, float32, double, int32, uint8, int16, int8, complex64, int64,
 * qint8, quint8, qint32, uint16, complex128, uint32, uint64.
 * @li diagonal: A Tensor of the same type as "x". \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixSetDiag.
 */
REG_OP(MatrixSetDiag)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiag)

/**
 * @brief Returns a batched matrix tensor with new batched diagonal values. \n
 * @par Inputs:
 * Three inputs, including:
 * @li x: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8.
 * @li diagonal: A Tensor of the same type as "x".
 * @li assist: A Tensor of the same type as "x". \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator MatrixSetDiag.
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS DEPRECATED. Please use MatrixSetDiag instead.
 */
REG_OP(MatrixSetDiagD)
    .INPUT(x, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(assist, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagD)

/**
 * @brief Function AttentionScore. \n
 * @par Inputs:
 * Six inputs, including:
 * @li query: A matrix Tensor. The type only supports float16.
 * @li key: A matrix Tensor. The type only supports float16.
 * @li value: A matrix Tensor. The type only supports float16.
 * @li padding_mask: A matrix Tensor. The type only supports float16.
 * @li scale: A scalar. The type only supports float16.
 * @li drop_mask: A matrix Tensor. The type only supports uint8. \n
 * @par Attributes:
 * @li keep_prob: A mutable Tensor. Must meet all of the following rules:
 * the shape of "keep_prob" should be (1,) or [1,].
 * @li query_transpose: A bool. If True, changes the shape of "query" from
 * [K, M] to [M, K].
 * @li key_transpose: A bool. If True, changes the shape of "key" from
 * [N, K] to [K, N].
 * @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data"
 * from [K, M] to [M, K].
 * @li bmm_score_transpose_b: A bool. If True, changes the shape of "value"
 * from [N, K] to [K, N].
 * @li softmax_axes: A list of int. The dimension softmax would be performed
 * on. Defaults to "[-1]". \n
 * @par Outputs:
 * attention_score: The result matrix Tensor. The type only supports float16.
 * softmax_output: The result matrix Tensor. The type only supports float16.
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(AttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax_output, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {-1})
    .OP_END_FACTORY_REG(AttentionScore)

/**
 * @brief Function AttentionScoreGrad. \n
 * @par Inputs:
 * Seven inputs, including:
 * @li attention_score: A matrix Tensor. The type only supports float16.
 * @li dx: A matrix Tensor. The type only supports float16.
 * @li query: A matrix Tensor. The type only supports float16.
 * @li key: A matrix Tensor. The type only supports float16.
 * @li value: A matrix Tensor. The type only supports float16.
 * @li scale: A scalar. The type only supports float16.
 * @li drop_mask: A matrix Tensor. The type only supports uint8. \n
 * @par Attributes:
 * @li keep_prob: A mutable Tensor. Must meet all of the following rules:
 * the shape of "keep_prob" should be (1,) or [1,].
 * @li query_transpose: A bool. If True, changes the shape of "query" from
 * [K, M] to [M, K].
 * @li key_transpose: A bool. If True, changes the shape of "key" from
 * [N, K] to [K, N].
 * @li value_transpose: A bool. If True, changes the shape of "mid_data" from
 * [K, M] to [M, K].
 * @li dx_transpose: A bool. If True, changes the shape of "value" from
 * [N, K] to [K, N].
 * @li softmax_axes: An int. The dimension softmax would be performed on.
 * Defaults to "-1". \n
 * @par Outputs:
 * value_dw: The result matrix Tensor. The type only supports float16.
 * query_dx: The result matrix Tensor. The type only supports float16.
 * key_dw: The result matrix Tensor. The type only supports float16.
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(AttentionScoreGrad)
    .INPUT(attention_score, TensorType({DT_FLOAT16}))
    .INPUT(dx, TensorType({DT_FLOAT16}))
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(value_dw, TensorType({DT_FLOAT16}))
    .OUTPUT(query_dx, TensorType({DT_FLOAT16}))
    .OUTPUT(key_dw, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(value_transpose, Bool, false)
    .ATTR(dx_transpose, Bool, false)
    .ATTR(softmax_axes, Int, -1)
    .OP_END_FACTORY_REG(AttentionScoreGrad)

/**
 * @brief Applies sparse "updates" to individual values or slices in a Variable. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float32, int8, uint8, double,
 * int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
 * uint64
 * @li indices: An ND Tensor.
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float32, int8, uint8, double,
 * int64, complex64, qint8, quint8, qint32, uint16, complex128, half, uint32,
 * uint64
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterNdUpdate.
 */
REG_OP(ScatterNdUpdate)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdUpdate)
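
/**
 * Editorial sketch of element-level ScatterNdUpdate semantics: each row of
 * "indices" (length = rank of "var") addresses one element of "var", which is
 * overwritten by the matching entry of "updates". The row-major flattening
 * and the restriction to full-depth indices are illustrative assumptions.
 * @code
 * #include <cstddef>
 * #include <vector>
 *
 * void ScatterNdUpdateRef(std::vector<float> &var,
 *                         const std::vector<size_t> &shape,
 *                         const std::vector<std::vector<size_t>> &indices,
 *                         const std::vector<float> &updates) {
 *   for (size_t r = 0; r < indices.size(); ++r) {
 *     size_t flat = 0;  // row-major flat offset of indices[r] within shape
 *     for (size_t d = 0; d < shape.size(); ++d)
 *       flat = flat * shape[d] + indices[r][d];
 *     var[flat] = updates[r];  // later duplicate indices win
 *   }
 * }
 * @endcode
 */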

/**
 * @brief Applies sparse "updates" to individual values or slices in a tensor. \n
 * @par Inputs:
 * Three inputs, including:
 * @li x: An ND Tensor. \n
 * Must be one of the following types: float16, float32, bool, int8, uint8
 * @li indices: An ND Tensor. \n
 * Must be one of the following types: int32
 * @li updates: An ND Tensor. \n
 * Must be one of the following types: float16, float32, bool, int8, uint8
 * @par Outputs:
 * y: A Tensor. Has the same type and format as input "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator TensorScatterUpdate.
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(TensorScatterUpdate)
    .INPUT(x, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterUpdate)

/**
 * @brief Uses "updates" to update tensor "data" by "indices". \n
 * @par Inputs:
 * Three inputs, including:
 * @li data: An ND Tensor. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @li indices: An ND Tensor of type int32 or int64
 * @li updates: A Tensor. Same shape as indices. Format: NCHW, NHWC. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @par Attributes:
 * @li axis: An optional attribute. Defaults to 0.
 * @li reduction: An optional attribute. Defaults to "none"; can also be
 * "add" or "mul". \n
 * @par Outputs:
 * y: A Tensor. Has the same type and format as input "data". \n
 * @par Third-party framework compatibility
 * Compatible with the ONNX operator ScatterElements.
 */
REG_OP(ScatterElements)
    .INPUT(data, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(axis, Int, 0)
    .ATTR(reduction, String, "none")
    .OP_END_FACTORY_REG(ScatterElements)
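
/**
 * Editorial sketch of the ScatterElements semantics above, reduced to the
 * 2-D case for clarity: y starts as a copy of "data", then each updates
 * element is routed along "axis" by the matching index, under the documented
 * "none"/"add"/"mul" reduction. Non-negative indices are assumed here.
 * @code
 * #include <cstddef>
 * #include <cstdint>
 * #include <string>
 * #include <vector>
 *
 * std::vector<float> ScatterElementsRef(const std::vector<float> &data,
 *                                       const std::vector<int64_t> &indices,
 *                                       const std::vector<float> &updates,
 *                                       size_t rows, size_t cols, int axis,
 *                                       const std::string &reduction) {
 *   std::vector<float> y(data);  // output starts as a copy of data
 *   for (size_t i = 0; i < rows; ++i) {
 *     for (size_t j = 0; j < cols; ++j) {
 *       const size_t idx = static_cast<size_t>(indices[i * cols + j]);
 *       const size_t dst = (axis == 0) ? idx * cols + j : i * cols + idx;
 *       const float u = updates[i * cols + j];
 *       if (reduction == "add")      y[dst] += u;
 *       else if (reduction == "mul") y[dst] *= u;
 *       else                         y[dst] = u;  // "none": overwrite
 *     }
 *   }
 *   return y;
 * }
 * @endcode
 */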

/**
 * @brief Uses "updates" to update tensor "var" by "indices". \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: A Tensor of type BasicType.
 * @li indices: An ND Tensor of type int32 or int64.
 * @li updates: A Tensor with the same dtype as "var". Same shape as indices. \n
 * @par Attributes:
 * @li use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterNdMax.
 */
REG_OP(ScatterNdMax)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdMax)

/**
 * @brief Adds sparse "updates" to a variable reference. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @li indices: An ND Tensor. \n
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterAdd.
 */
REG_OP(ScatterAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterAdd)

/**
 * @brief Adds sparse "updates" to a variable reference. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @li indices: An ND Tensor of type int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @par Attributes:
 * axis: A required int. The axis along which to index. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the PyTorch operator ScatterAdd.
 */
REG_OP(ScatterAddWithAxis)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(axis, Int)
    .OP_END_FACTORY_REG(ScatterAddWithAxis)

/**
 * @brief Divides a variable reference by sparse updates. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @li indices: An ND Tensor.
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterDiv.
 */
REG_OP(ScatterDiv)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterDiv)

/**
 * @brief Applies sparse addition to individual values or slices in a Variable. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @li indices: An ND Tensor.
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterNdAdd.
 */
REG_OP(ScatterNdAdd)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdAdd)

/**
 * @brief Applies sparse addition to individual values or slices in a Variable. \n
 * @par Inputs:
 * Three inputs, including:
 * @li x: An ND Tensor. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @li indices: An ND Tensor. \n
 * Must be one of the following types: int32
 * @li updates: An ND Tensor. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @par Outputs:
 * y: A Tensor. Has the same type and format as input "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator TensorScatterAdd.
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(TensorScatterAdd)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterAdd)

/**
 * @brief Applies sparse subtraction to individual values or slices in a Variable. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @li indices: An ND Tensor.
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterNdSub.
 */
REG_OP(ScatterNdSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdSub)

/**
 * @brief Uses "updates" to update tensor "var" by "indices". \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: A Tensor of type BasicType.
 * @li indices: An ND Tensor of type int32 or int64.
 * @li updates: A Tensor with the same dtype as "var". Same shape as indices. \n
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterNdMin.
 */
REG_OP(ScatterNdMin)
    .INPUT(var, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(var, TensorType::BasicType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterNdMin)

/**
 * @brief Applies sparse subtraction to individual values or slices in a tensor. \n
 * @par Inputs:
 * Three inputs, including:
 * @li x: An ND Tensor. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @li indices: An ND Tensor. \n
 * Must be one of the following types: int32
 * @li updates: An ND Tensor. \n
 * Must be one of the following types: float16, float32, int32, int8, uint8
 * @par Outputs:
 * y: A Tensor. Has the same type and format as input "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator TensorScatterSub.
 * @par Restrictions:
 * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
 */
REG_OP(TensorScatterSub)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OP_END_FACTORY_REG(TensorScatterSub)

/**
 * @brief Subtracts sparse updates from a variable reference. \n
 * @par Inputs:
 * Three inputs, including:
 * @li var: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @li indices: An ND Tensor.
 * Must be one of the following types: int32 or int64
 * @li updates: An ND Tensor.
 * Must be one of the following types: float16, float, int32, int8, uint8
 * @par Attributes:
 * use_locking: An optional bool. Defaults to "False". If "True",
 * the operation will be protected by a lock. \n
 * @par Outputs:
 * var: A Tensor. Has the same type and format as input "var". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator ScatterSub.
 */
REG_OP(ScatterSub)
    .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterSub)

/**
 * @brief Returns the batched diagonal part of a batched tensor with "assist". \n
 * @par Inputs:
 * Two inputs, including:
 * @li x: A Tensor of type float16, float32, or int32.
 * @li assist: A Tensor of the same type as "x". \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator DiagPart.
 *
 * @par Restrictions:
 * Warning: THIS FUNCTION IS DEPRECATED. Please use DiagPart instead.
 */
REG_OP(DiagPartD)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .INPUT(assist, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32}))
    .OP_END_FACTORY_REG(DiagPartD)

/**
 * @brief Returns the batched diagonal part of a batched tensor. \n
 * @par Inputs:
 * x: A Tensor. Must be one of the following types:
 * float16, float32, int32, int64, double, complex64, complex128. \n
 * @par Outputs:
 * y: A Tensor. Has the same type as "x". \n
 * @par Third-party framework compatibility
 * Compatible with the TensorFlow operator DiagPart.
 */
REG_OP(DiagPart)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                          DT_COMPLEX64, DT_COMPLEX128}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT64, DT_DOUBLE,
                           DT_COMPLEX64, DT_COMPLEX128}))
    .OP_END_FACTORY_REG(DiagPart)
/**
* @brief Also known as a "fully-connected" layer; computes an inner product
* with a set of learned weights, and (optionally) adds biases. \n
* @par Inputs:
* Four inputs, including:
* @li x: A Tensor of type float16, int8, int4, float32, bfloat16.
* @li w: A weight matrix of type float16, int8, int4, float32, bfloat16.
* @li b: An optional Tensor of type float16, int32, float32, bfloat16.
* @li offset_w: An optional Tensor of type int8, int4.
* Reserved. Only None is supported. \n
* @par Attributes:
* @li num_output: Required. An int, the number of output neurons. Reserved.
* @li transpose: A bool specifying whether to transpose the input w,
* either "true" or "false". Defaults to "false".
* @li axis: Optional. An int, 1 or 2, specifying which dimension the input
* "K" starts from. Defaults to 1.
* The product of the subsequent dimensions starting from the first dimension
* or the second dimension is "K".
* @li offset_x: An optional integer for quantized FullyConnection.
* The negative offset added to the input image for the int8 type. Ensure that
* offset_x is within the effective range of int8 [-128, 127]. Defaults to "0". \n
* @par Outputs:
* y: The result tensor of type float16, int32, float32, bfloat16. \n
* @par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct. \n
* @par Quantization supported or not
* Yes
*/
REG_OP(FullyConnection)
    .INPUT(x, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT, DT_BF16}))
    .INPUT(w, TensorType({DT_FLOAT16, DT_INT8, DT_INT4, DT_FLOAT, DT_BF16}))
    .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT, DT_BF16}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8, DT_INT4}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32, DT_FLOAT, DT_BF16}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnection)

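/* Illustrative sketch (not part of this header): the non-quantized
 * FullyConnection semantics, y = x * w^T + b, on row-major buffers. The
 * shapes are an assumption for the sketch: x is M x K, w is N x K (i.e.
 * transpose == "false" in the sense described above), b has length N or
 * is empty when no bias is supplied:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> fully_connected(const std::vector<float>& x,
 *                                      const std::vector<float>& w,
 *                                      const std::vector<float>& b,
 *                                      std::size_t M, std::size_t K,
 *                                      std::size_t N) {
 *     std::vector<float> y(M * N, 0.0f);
 *     for (std::size_t m = 0; m < M; ++m)
 *       for (std::size_t n = 0; n < N; ++n) {
 *         float acc = b.empty() ? 0.0f : b[n];
 *         for (std::size_t k = 0; k < K; ++k)
 *           acc += x[m * K + k] * w[n * K + k];
 *         y[m * N + n] = acc;
 *       }
 *     return y;
 *   }
 */
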
/**
* @brief Also known as a "fully-connected-compress" layer; computes an inner
* product with a set of learned weights, and (optionally) adds biases. \n
* @par Inputs:
* Five inputs, including:
* @li x: A Tensor of type uint8, int8.
* @li w: A weight matrix of type int8.
* @li compress_index: A compress index matrix of type int8.
* @li b: An optional Tensor of type int32.
* @li offset_w: An optional Tensor of type int8.
* @par Attributes:
* @li num_output: An int, specifying the number of outputs.
* @li transpose: A bool, specifying whether to transpose input w, either "true"
* or "false". Defaults to "false".
* @li axis: Optional. An int, 1 or 2, specifying which dimension the input "K"
* starts from. Defaults to "1".
* The product of the subsequent dimensions starting from the first dimension or
* the second dimension is "K".
* @li offset_x: An optional integer for quantized FullyConnectionCompress.
* The negative offset added to the input image for the int8 type. Ensure that
* offset_x is within the effective range of int8 [-128, 127]. Defaults to "0". \n
* @par Outputs:
* y: The result tensor of type int32. \n
* @par Third-party framework compatibility
* Compatible with the Caffe operator InnerProduct. \n
* @par Quantization supported or not
* Yes
*/
REG_OP(FullyConnectionCompress)
    .INPUT(x, TensorType({DT_UINT8, DT_INT8}))
    .INPUT(w, TensorType({DT_INT8}))
    .INPUT(compress_index, TensorType({DT_INT8}))
    .OPTIONAL_INPUT(b, TensorType({DT_INT32}))
    .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
    .OUTPUT(y, TensorType({DT_INT32}))
    .REQUIRED_ATTR(num_output, Int)
    .ATTR(transpose, Bool, false)
    .ATTR(axis, Int, 1)
    .ATTR(offset_x, Int, 0)
    .OP_END_FACTORY_REG(FullyConnectionCompress)

/**
* @brief Computes the confusion matrix from predictions and labels. \n
* @par Inputs:
* Three inputs, including:
* @li labels: A Tensor. Must be one of the following types: float16, float32,
* int32, int8, uint8.
* @li predictions: A Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8.
* @li weights: An optional Tensor. Must be one of the following types: float16,
* float32, int32, int8, uint8. \n
* @par Attributes:
* @li num_classes: An integer for the shape of the output matrix.
* No default value.
* @li dtype: Data type of the confusion matrix. No default value. \n
* @par Outputs:
* y: A Tensor. Has the same type and format as input "labels".
* @attention Constraints:
* @li "weights", "labels", and "predictions" are 1D tensors.
* @li The output has shape (num_classes, num_classes),
* where 1 <= num_classes <= 4096. \n
* @see Region()
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ConfusionMatrix.
*/
REG_OP(ConfusionMatrix)
    .INPUT(labels, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .INPUT(predictions, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OPTIONAL_INPUT(weights, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16, DT_INT8, DT_UINT8}))
    .REQUIRED_ATTR(num_classes, Int)
    .REQUIRED_ATTR(dtype, String)
    .OP_END_FACTORY_REG(ConfusionMatrix)

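/* Illustrative sketch (not part of this header): the ConfusionMatrix
 * semantics for integer labels/predictions with optional weights. The
 * output is num_classes x num_classes; entry (l, p) accumulates the weight
 * of the samples with label l and prediction p. The helper name is
 * hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> confusion_matrix(const std::vector<int>& labels,
 *                                       const std::vector<int>& predictions,
 *                                       const std::vector<float>& weights,  // empty => all ones
 *                                       int num_classes) {
 *     std::vector<float> y(static_cast<std::size_t>(num_classes) * num_classes, 0.0f);
 *     for (std::size_t i = 0; i < labels.size(); ++i) {
 *       float w = weights.empty() ? 1.0f : weights[i];
 *       y[labels[i] * num_classes + predictions[i]] += w;
 *     }
 *     return y;
 *   }
 */
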
/**
* @brief Multiplies sparse updates into a variable reference. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMul.
*/
REG_OP(ScatterMul)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMul)

/**
* @brief Reduces sparse updates into a variable reference using
* the "min" operation. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True", the operation
* will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMin.
*/
REG_OP(ScatterMin)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMin)

/**
* @brief Reduces sparse updates into a variable reference using the "max" operation. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An NCHW, NHWC, or ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An NCHW, NHWC, or ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", the operation will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterMax.
*/
REG_OP(ScatterMax)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterMax)

/**
* @brief Applies sparse updates to a variable reference. \n
* @par Inputs:
* Three inputs, including:
* @li var: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @li indices: An ND Tensor.
* Must be one of the following types: int32 or int64
* @li updates: An ND Tensor.
* Must be one of the following types: float16, float, int32, int8, uint8
* @par Attributes:
* use_locking: An optional bool. Defaults to "False". If "True",
* the operation will be protected by a lock. \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var". \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator ScatterUpdate.
*/
REG_OP(ScatterUpdate)
    .INPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .OUTPUT(var, TensorType({DT_FLOAT16,DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ScatterUpdate)

/**
* @brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r` tensor where `r >= 2`. \n
* @li k: \n
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li padding_value: The value to fill the area outside the specified diagonal band with. \n
* @par Outputs:
* diagonal: The extracted diagonal(s). \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagPartV2.
*/
REG_OP(MatrixDiagPartV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(diagonal, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagPartV2)

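/* Illustrative sketch (not part of this header): extracting one diagonal
 * (a single integer k) from an m x n row-major matrix, padding the unused
 * tail with padding_value. The multi-diagonal band case stacks one such row
 * per offset; the fixed output width and the left-aligned short diagonals
 * are simplifying assumptions (the align attribute of the V3 variants is
 * omitted here), and the helper name is hypothetical:
 *
 *   #include <algorithm>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> diag_of(const std::vector<float>& a, long m, long n,
 *                              long k, float padding_value) {
 *     long len = std::max(0L, std::min(m + std::min(k, 0L), n - std::max(k, 0L)));
 *     long out_len = std::min(m, n);  // assumption: fixed width per diagonal
 *     std::vector<float> d(out_len, padding_value);
 *     for (long i = 0; i < len; ++i) {
 *       long r = k >= 0 ? i : i - k;
 *       long c = k >= 0 ? i + k : i;
 *       d[i] = a[r * n + c];
 *     }
 *     return d;
 *   }
 */
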
/**
* @brief Returns a batched matrix tensor with new batched diagonal values. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r+1`, where `r >= 1`. \n
* @li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @par Outputs:
* output: Rank `r+1`, with `output.shape = input.shape`. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV2.
*/
REG_OP(MatrixSetDiagV2)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixSetDiagV2)

/**
* @brief Returns a batched matrix tensor with new batched diagonal values. \n
* @par Inputs:
* Three inputs, including:
* @li input: Rank `r+1`, where `r >= 1`. \n
* @li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @par Attributes:
* @li align: An optional string. Defaults to "RIGHT_LEFT". It specifies how \n
* superdiagonals and subdiagonals should be aligned, respectively. \n
* Other options: "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". \n
* @par Outputs:
* output: Rank `r+1`, with `output.shape = input.shape`. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixSetDiagV3.
*/
REG_OP(MatrixSetDiagV3)
    .INPUT(input, TensorType::BasicType())
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .OUTPUT(output, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixSetDiagV3)

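/* Illustrative sketch (not part of this header): the core of MatrixSetDiag
 * for the simplest case, overwriting the main diagonal (k == 0) of one
 * matrix in place while keeping every other entry. The helper name is
 * hypothetical:
 *
 *   #include <algorithm>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   // a: m x n row-major matrix; d: its new main diagonal, length min(m, n).
 *   void set_main_diag(std::vector<float>& a, std::size_t m, std::size_t n,
 *                      const std::vector<float>& d) {
 *     std::size_t len = std::min(m, n);
 *     for (std::size_t i = 0; i < len; ++i) a[i * n + i] = d[i];
 *   }
 */
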
/**
* @brief Returns a batched diagonal tensor with given batched diagonal values. \n
* @par Inputs:
* Five inputs, including:
* @li diagonal: Rank `r`, where `r >= 1` \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n
* diagonal, and negative value means subdiagonals. `k` can be a single integer \n
* (for a single diagonal) or a pair of integers specifying the low and high ends \n
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li num_rows:
* The number of rows of the output matrix. If it is not provided, the op assumes \n
* the output matrix is a square matrix and infers the matrix size from k and the \n
* innermost dimension of `diagonal`. \n
* @li num_cols: An NCHW, NHWC, or ND Tensor.
* The number of columns of the output matrix. If it is not provided, the op \n
* assumes the output matrix is a square matrix and infers the matrix size from \n
* k and the innermost dimension of `diagonal`. \n
* @li padding_value: The number to fill the area outside the specified diagonal band with. \n
* @par Outputs:
* output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV2.
*/
REG_OP(MatrixDiagV2)
    .INPUT(diagonal, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(MatrixDiagV2)

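/* Illustrative sketch (not part of this header): building one num_rows x
 * num_cols matrix whose k-th diagonal is taken from "diagonal" and whose
 * remaining entries are padding_value, for a single integer offset k. The
 * helper name is hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> matrix_diag(const std::vector<float>& diagonal,
 *                                  long k, long num_rows, long num_cols,
 *                                  float padding_value) {
 *     std::vector<float> out(num_rows * num_cols, padding_value);
 *     for (std::size_t i = 0; i < diagonal.size(); ++i) {
 *       long r = k >= 0 ? static_cast<long>(i) : static_cast<long>(i) - k;
 *       long c = k >= 0 ? static_cast<long>(i) + k : static_cast<long>(i);
 *       if (r < num_rows && c < num_cols) out[r * num_cols + c] = diagonal[i];
 *     }
 *     return out;
 *   }
 */
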
/**
* @brief Adds updates to var_out according to axis and indices.
* @par Inputs:
* Three inputs, including:
* @li var: A Tensor. Must be one of the following types:
* float16, float32, int32, int8, uint8.
* @li indices: A Tensor of the indices; the type should be int32.
* @li updates: A Tensor of the same type as "var".
* @par Attributes:
* @li axis: An int specifying the axis along which to perform the index add.
* Defaults to 0.
* @par Outputs:
* @li var_out: A Tensor. Same as input "var".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_add.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexAdd)
    .INPUT(var, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT, DT_FLOAT16}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(updates, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT, DT_FLOAT16}))
    .OUTPUT(var_out, TensorType({DT_INT32, DT_INT8, DT_UINT8, DT_FLOAT, DT_FLOAT16}))
    .ATTR(axis, Int, 0)
    .OP_END_FACTORY_REG(IndexAdd)

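/* Illustrative sketch (not part of this header): IndexAdd with axis == 1 on
 * a 2-D row-major tensor, to contrast with the row-wise (axis == 0) scatter
 * family above. Within each row, column indices[u] of var accumulates
 * column u of updates. The helper name and shape parameters are
 * hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   // var: rows x cols; updates: rows x indices.size(), both row-major.
 *   void index_add_axis1(std::vector<float>& var, std::size_t rows,
 *                        std::size_t cols, const std::vector<int>& indices,
 *                        const std::vector<float>& updates) {
 *     std::size_t upd_cols = indices.size();
 *     for (std::size_t r = 0; r < rows; ++r)
 *       for (std::size_t u = 0; u < upd_cols; ++u)
 *         var[r * cols + indices[u]] += updates[r * upd_cols + u];
 *   }
 */
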
/**
* @brief Replaces the values of x1 at the positions given by the "indices"
* attribute with the values in x2.
* @par Inputs:
* Two inputs, including:
* @li x1: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @li x2: A Tensor of the same type as "x1".
* @par Attributes:
* @li indices: A required list of ints, the indices to update.
* @li accumulate: An optional int. If 1, the updates are accumulated into
* "x1" instead of overwriting. Defaults to 0.
* @par Outputs:
* @li y: A Tensor. Same as input "x1".
* @par Third-party framework compatibility
* Compatible with the Pytorch operator index_put.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(IndexPut)
    .INPUT(x1, TensorType::BasicType())
    .INPUT(x2, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .REQUIRED_ATTR(indices, ListInt)
    .ATTR(accumulate, Int, 0)
    .OP_END_FACTORY_REG(IndexPut)

/**
* @brief Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input. \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @par Attributes:
* diagonal: An optional attribute indicating the diagonal to consider. Defaults to 0. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Triu.
*/
REG_OP(Triu)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Triu)

/**
* @brief Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input. \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64. \n
* @par Attributes:
* diagonal: An optional attribute indicating the diagonal to consider. Defaults to 0. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Tril.
*/
REG_OP(Tril)
    .INPUT(x, TensorType::BasicType())
    .ATTR(diagonal, Int, 0)
    .OUTPUT(y, TensorType::BasicType())
    .OP_END_FACTORY_REG(Tril)

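/* Illustrative sketch (not part of this header): Triu/Tril on an m x n
 * row-major matrix. The "diagonal" attribute shifts the boundary: entry
 * (r, c) is kept by Triu when c - r >= diagonal, and by Tril when
 * c - r <= diagonal. The helper name is hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> triu(const std::vector<float>& x, long m, long n,
 *                           long diagonal) {
 *     std::vector<float> y(x.size(), 0.0f);
 *     for (long r = 0; r < m; ++r)
 *       for (long c = 0; c < n; ++c)
 *         if (c - r >= diagonal) y[r * n + c] = x[r * n + c];
 *     return y;
 *   }
 *   // Tril is identical with the test flipped to (c - r <= diagonal).
 */
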
/**
* @brief Sums the product of the elements of the input operands along
* dimensions specified using a notation based on the Einstein summation
* convention. \n
* @par Inputs:
* @li x: A list of Tensors, the operands of the Einstein summation. Must be
* one of the following types: int32, float16, float32. It's a dynamic input. \n
* @par Attributes:
* @li equation: The subscripts for the Einstein summation. \n
* @li N: The number of input tensors. \n
* @par Outputs:
* @li y: The result of the Einstein summation. Has the same type as the
* inputs. \n
* @attention Constraints:
* Input N must be Int. \n
* @par Third-party framework compatibility
* Compatible with Tensorflow 2.x einsum operator.
*/
REG_OP(Einsum)
    .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
    .REQUIRED_ATTR(equation, String)
    .REQUIRED_ATTR(N, Int)
    .OP_END_FACTORY_REG(Einsum)

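/* Illustrative sketch (not part of this header): the equation
 * "ij,jk->ik" (matrix multiplication) written as explicit loops, which is
 * the contraction over the shared index j that the Einstein notation
 * denotes. The helper name is hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> einsum_ij_jk_ik(const std::vector<float>& a,
 *                                      const std::vector<float>& b,
 *                                      std::size_t I, std::size_t J,
 *                                      std::size_t K) {
 *     std::vector<float> y(I * K, 0.0f);
 *     for (std::size_t i = 0; i < I; ++i)
 *       for (std::size_t j = 0; j < J; ++j)
 *         for (std::size_t k = 0; k < K; ++k)
 *           y[i * K + k] += a[i * J + j] * b[j * K + k];
 *     return y;
 *   }
 */
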
/**
* @brief Returns a 2-D tensor with ones on the diagonal and zeros elsewhere. \n
* @par Inputs:
* No inputs
* @par Attributes:
* @li num_rows: A required int. \n
* @li num_columns: An optional int. Defaults to 0. \n
* @li batch_shape: An optional ListInt. Defaults to []. \n
* @li dtype: An optional int. Defaults to 0. \n
* @par Outputs:
* y: A Tensor with the targeted type and shape. \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Eye. \n
*/
REG_OP(Eye)
    .OUTPUT(y, TensorType::BasicType())    /* "Result, has targeted element type" */
    .REQUIRED_ATTR(num_rows, Int)
    .ATTR(num_columns, Int, 0)
    .ATTR(batch_shape, ListInt, {})
    .ATTR(dtype, Int, 0)
    .OP_END_FACTORY_REG(Eye)

/**
* @brief Fills the diagonal of a tensor that has at least 2 dimensions with a given value. \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float32, int32, int64. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Attributes:
* @li fill_value: The value to fill in.
* @li wrap: An optional bool. Defaults to "False". If "True", the diagonal
* "wraps" and is continued below the last column for tall matrices. \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator FillDiagonal.
*/
REG_OP(FillDiagonal)
    .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64}))
    .REQUIRED_ATTR(fill_value, Float)
    .ATTR(wrap, Bool, false)
    .OP_END_FACTORY_REG(FillDiagonal)

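/* Illustrative sketch (not part of this header): FillDiagonal without
 * wrapping on an m x n row-major matrix; entries (i, i) are overwritten in
 * place. The wrapped variant for tall matrices is omitted here. The helper
 * name is hypothetical:
 *
 *   #include <algorithm>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   void fill_diagonal(std::vector<float>& x, std::size_t m, std::size_t n,
 *                      float fill_value) {
 *     for (std::size_t i = 0; i < std::min(m, n); ++i)
 *       x[i * n + i] = fill_value;
 *   }
 */
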
/**
* @brief Returns the sum of the elements of the diagonal of the input 2-D matrix. \n
* @par Inputs:
* x: A Tensor. Must be one of the following types:
* float16, float. \n
* @par Outputs:
* y: A Tensor. Has the same type as "x". \n
* @par Third-party framework compatibility
* Compatible with the Pytorch operator Trace.
*/
REG_OP(Trace)
    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
    .OP_END_FACTORY_REG(Trace)

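/* Illustrative sketch (not part of this header): the trace of an n x n
 * row-major matrix is the sum of its main-diagonal entries. The helper
 * name is hypothetical:
 *
 *   #include <cstddef>
 *   #include <vector>
 *
 *   float trace(const std::vector<float>& x, std::size_t n) {
 *     float s = 0.0f;
 *     for (std::size_t i = 0; i < n; ++i) s += x[i * n + i];
 *     return s;
 *   }
 */
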
/**
* @brief Computes the generalized inverse (pseudoinverse) of any matrix. \n
* @par Inputs:
* @li x: An input matrix. Must be one of the following types:
* double, float. \n
* @par Attributes:
* @li rcond: An optional float >= 0, or inf. Defaults to 1e-15. \n
* @par Outputs:
* y: A Tensor with the same type as "x" and the shape of the transpose of "x". \n
*/
REG_OP(Pinverse)
    .INPUT(x, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE }))
    .ATTR(rcond, Float, 1e-15)
    .OP_END_FACTORY_REG(Pinverse)

/**
* @brief Copies "input" to the output and, at the positions given by "indices",
* keeps the element-wise maximum of the existing value and "updates". \n
* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n
* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMax)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMax)

/**
* @brief Copies "input" to the output and, at the positions given by "indices",
* keeps the element-wise minimum of the existing value and "updates". \n
* @par Inputs:
* Three inputs, including:
* @li input: Must be one of the following types:
* float16, float32, double, int32, uint8, int16, int8, complex64, int64,
* qint8, quint8, qint32, uint16, complex128, uint32, uint64.
* @li indices: Must be one of the following types:
* int32, int64.
* @li updates: Must have the same type as input. \n
* @par Outputs:
* output: A Tensor with the same type as input. \n
*/
REG_OP(TensorScatterMin)
    .INPUT(input, TensorType::BasicType())
    .INPUT(indices, TensorType::IndexNumberType())
    .INPUT(updates, TensorType::BasicType())
    .OUTPUT(output, TensorType::BasicType())
    .OP_END_FACTORY_REG(TensorScatterMin)

/**
* @brief Returns the batched diagonal part of a batched tensor. \n
* @par Inputs:
* @li x: A Tensor. Rank r tensor where r >= 2.
* @li k: A Tensor of type int32. Diagonal offset(s). Positive value means
* superdiagonal, 0 refers to the main diagonal, and negative value means
* subdiagonals. k can be a single integer (for a single diagonal) or a pair
* of integers specifying the low and high ends of a matrix band. k[0] must
* not be larger than k[1].
* @li padding_value: A Tensor. Must have the same type as input. The value to
* fill the area outside the specified diagonal band with. Default is 0. \n
* @par Outputs:
* @li y: A Tensor. Has the same type as "x". \n
* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT". Defaults to "RIGHT_LEFT".
* @par Third-party framework compatibility
* Compatible with the Tensorflow operator MatrixDiagPartV3.
*/
REG_OP(MatrixDiagPartV3)
    .INPUT(x, TensorType::BasicType())
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType::BasicType())
    .OUTPUT(y, TensorType::BasicType())
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixDiagPartV3)

/**
* @brief Returns a batched diagonal tensor with given batched diagonal values. \n
* @par Inputs:
* Five inputs, including:
* @li x: Rank `r`, where `r >= 1` \n
* @li k:
* Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
* diagonal, and negative value means subdiagonals. `k` can be a single integer
* (for a single diagonal) or a pair of integers specifying the low and high ends
* of a matrix band. `k[0]` must not be larger than `k[1]`. \n
* @li num_rows:
* The number of rows of the output matrix. If it is not provided, the op assumes
* the output matrix is a square matrix and infers the matrix size from k and the
* innermost dimension of `x`. \n
* @li num_cols: An NCHW, NHWC, or ND Tensor.
* The number of columns of the output matrix. If it is not provided, the op
* assumes the output matrix is a square matrix and infers the matrix size from
* k and the innermost dimension of `x`. \n
* @li padding_value: The number to fill the area outside the specified diagonal band with. \n
* @par Attributes:
* @li align: An optional string from: "LEFT_RIGHT", "RIGHT_LEFT", "LEFT_LEFT", "RIGHT_RIGHT".
* Defaults to "RIGHT_LEFT" \n
* @par Outputs:
* @li y: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator MatrixDiagV3.
*/
REG_OP(MatrixDiagV3)
    .INPUT(x, TensorType({BasicType(), DT_BOOL}))
    .INPUT(k, TensorType({DT_INT32}))
    .INPUT(num_rows, TensorType({DT_INT32}))
    .INPUT(num_cols, TensorType({DT_INT32}))
    .INPUT(padding_value, TensorType({BasicType(), DT_BOOL}))
    .OUTPUT(y, TensorType({BasicType(), DT_BOOL}))
    .ATTR(align, String, "RIGHT_LEFT")
    .OP_END_FACTORY_REG(MatrixDiagV3)

/**
* @brief Function SwinAttentionScore. \n
* @par Inputs:
* Seven inputs, including:
* @li query: A matrix Tensor. The type only supports float16.
* @li key: A matrix Tensor. The type only supports float16.
* @li value: A matrix Tensor. The type only supports float16.
* @li padding_mask1: A matrix Tensor. The type only supports float16.
* @li padding_mask2: An optional matrix Tensor. The type only supports float16.
* @li scale: A scalar. The type only supports float16.
* @li drop_mask: An optional matrix Tensor. The type only supports int8. \n
* @par Attributes:
* @li keep_prob: An optional float, the probability that each element is kept
* when dropout is applied. Defaults to 1.0.
* @li query_transpose: A bool. If True, changes the shape of "query" from [K, M] to
* [M, K]. Defaults to "false".
* @li key_transpose: A bool. If True, changes the shape of "key" from [N, K] to
* [K, N]. Defaults to "false".
* @li bmm_score_transpose_a: A bool. If True, changes the shape of "mid_data" from [K, M] to
* [M, K]. Defaults to "false".
* @li bmm_score_transpose_b: A bool. If True, changes the shape of "value" from [N, K] to
* [K, N]. Defaults to "false".
* @li softmax_axes: A list of ints. The dimensions softmax would be performed on.
* Defaults to "[]". \n
* @par Outputs:
* @li attention_score: The result matrix Tensor. The type only supports float16.
* @li softmax: The result matrix Tensor. The type only supports float16.
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(SwinAttentionScore)
    .INPUT(query, TensorType({DT_FLOAT16}))
    .INPUT(key, TensorType({DT_FLOAT16}))
    .INPUT(value, TensorType({DT_FLOAT16}))
    .INPUT(padding_mask1, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(padding_mask2, TensorType({DT_FLOAT16}))
    .INPUT(scale, TensorType({DT_FLOAT16}))
    .OPTIONAL_INPUT(drop_mask, TensorType({DT_INT8}))
    .OUTPUT(attention_score, TensorType({DT_FLOAT16}))
    .OUTPUT(softmax, TensorType({DT_FLOAT16}))
    .ATTR(keep_prob, Float, 1.0)
    .ATTR(query_transpose, Bool, false)
    .ATTR(key_transpose, Bool, false)
    .ATTR(bmm_score_transpose_a, Bool, false)
    .ATTR(bmm_score_transpose_b, Bool, false)
    .ATTR(softmax_axes, ListInt, {})
    .OP_END_FACTORY_REG(SwinAttentionScore)

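/* Illustrative sketch (not part of this header): the core computation this
 * fused op covers, attention = softmax(query * key^T * scale) * value, for
 * single M x K query, N x K key, and N x V value matrices. The masks, the
 * dropout, and the transpose flags are omitted, and the helper name and
 * shapes are assumptions for the sketch:
 *
 *   #include <algorithm>
 *   #include <cmath>
 *   #include <cstddef>
 *   #include <vector>
 *
 *   std::vector<float> attention(const std::vector<float>& q,
 *                                const std::vector<float>& k,
 *                                const std::vector<float>& v,
 *                                std::size_t M, std::size_t K, std::size_t N,
 *                                std::size_t V, float scale) {
 *     std::vector<float> s(M * N, 0.0f), y(M * V, 0.0f);
 *     for (std::size_t m = 0; m < M; ++m) {
 *       for (std::size_t n = 0; n < N; ++n)
 *         for (std::size_t i = 0; i < K; ++i)
 *           s[m * N + n] += q[m * K + i] * k[n * K + i] * scale;
 *       float mx = s[m * N], sum = 0.0f;  // numerically stable softmax, row m
 *       for (std::size_t n = 1; n < N; ++n) mx = std::max(mx, s[m * N + n]);
 *       for (std::size_t n = 0; n < N; ++n) {
 *         s[m * N + n] = std::exp(s[m * N + n] - mx);
 *         sum += s[m * N + n];
 *       }
 *       for (std::size_t n = 0; n < N; ++n) s[m * N + n] /= sum;
 *       for (std::size_t n = 0; n < N; ++n)
 *         for (std::size_t j = 0; j < V; ++j)
 *           y[m * V + j] += s[m * N + n] * v[n * V + j];
 *     }
 *     return y;
 *   }
 */
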
/**
* @brief
* A structure specific to the swin_transformer model. This operator only
* supports swin_transformer. \n
* @par Inputs:
* Three inputs, including:
* @li x1: A Tensor. Must be one of the following types: float16.
* @li x2: A Tensor. Must be one of the following types: float16.
* @li bias: A Tensor. Must be one of the following types: float16. \n
* @par Attributes:
* @li shifts: An optional attribute, the type is list int. Defaults to (). \n
* @par Outputs:
* One output, including:
* @li y: A Tensor. Must be one of the following types: float16. \n
* @par Restrictions:
* Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. \n
*/
REG_OP(SwinAttentionFFN)
    .INPUT(x1, TensorType({DT_FLOAT16}))
    .INPUT(x2, TensorType({DT_FLOAT16}))
    .INPUT(bias, TensorType({DT_FLOAT16}))
    .OUTPUT(y, TensorType({DT_FLOAT16}))
    .ATTR(shifts, ListInt, {})
    .OP_END_FACTORY_REG(SwinAttentionFFN)

}  // namespace ge

#endif  // OPS_BUILT_IN_OP_PROTO_INC_MATRIX_CALCULATION_OPS_H_
