/**
* Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file nn_training_ops.h
* \brief
*/
#ifndef OPS_BUILT_IN_OP_PROTO_INC_NN_TRAINING_OPS_H_
#define OPS_BUILT_IN_OP_PROTO_INC_NN_TRAINING_OPS_H_

#include "graph/operator_reg.h"

namespace ge {
/**
*@brief Updates "var" according to the AdaMax algorithm.
* t-1 means the previous period.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
* v_t <- max(beta2 * v_{t-1}, abs(grad))
* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in TensorType::NumberType().
* Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li v: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li beta1_power: A scalar. Has the same type as "var".
*@li lr: learning_rate. A scalar. Has the same type as "var".
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "m", and "v" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdaMax.
*
*/
REG_OP(ApplyAdaMax)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdaMax)
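
// Illustrative reference only: a minimal host-side sketch of the AdaMax update
// documented above, applied element-wise over a flat float buffer of length n.
// The helper name and signature are hypothetical (not part of this header's API),
// and <cmath> is assumed to be available in the translation unit.
inline void ApplyAdaMaxSketch(float *var, float *m, float *v, float beta1_power,
                              float lr, float beta1, float beta2, float epsilon,
                              const float *grad, int n) {
  for (int i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];                   // m_t
    v[i] = std::fmax(beta2 * v[i], std::fabs(grad[i]));               // v_t
    var[i] -= lr / (1.0f - beta1_power) * m[i] / (v[i] + epsilon);    // var update
  }
}
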
/**
*@brief Updates "var" according to the AdaMax algorithm.
* t-1 means the previous period.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
* v_t <- max(beta2 * v_{t-1}, abs(grad))
* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in TensorType::NumberType().
* Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li v: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li beta1_power: A scalar. Has the same type as "var".
*@li lr: learning_rate. A scalar. Has the same type as "var".
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "m", and "v" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "m".
*@li v: A mutable tensor. Has the same type as input "v".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyAdaMax.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdaMax instead.
*/
REG_OP(ApplyAdaMaxD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .OUTPUT(v, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdaMaxD)

/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme . \n
*@par Inputs:
* Five inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li lr: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32 . \n
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "False", the accumulator "accum" is not updated . \n
*@par Outputs:
*@li var: A Tensor. Has the same type and format as input "var".
*@li accum: A Tensor. Has the same type and format as input "accum" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdagrad.
*/
REG_OP(SparseApplyAdagrad)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagrad)
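
// Illustrative reference only: a minimal host-side sketch of the sparse Adagrad
// update registered above. "indices" selects which rows of "var"/"accum" are
// updated by the matching rows of "grad"; row_len is the flattened row size.
// The helper name, signature, and row-major layout are assumptions, and <cmath>
// is assumed to be available in the translation unit.
inline void SparseApplyAdagradSketch(float *var, float *accum, float lr,
                                     const float *grad, const int *indices,
                                     int num_indices, int row_len,
                                     bool update_slots = true) {
  for (int i = 0; i < num_indices; ++i) {
    float *v_row = var + indices[i] * row_len;
    float *a_row = accum + indices[i] * row_len;
    const float *g_row = grad + i * row_len;
    for (int j = 0; j < row_len; ++j) {
      if (update_slots) a_row[j] += g_row[j] * g_row[j];    // accum += grad^2
      v_row[j] -= lr * g_row[j] / std::sqrt(a_row[j]);      // var -= lr * grad / sqrt(accum)
    }
  }
}
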
/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme . \n
*@par Inputs:
* Four inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32 . \n
*@par Attributes:
*@li lr: A required float attribute. The learning rate used for the computation.
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "False", the accumulator "accum" is not updated . \n
*@par Outputs:
*@li var: A Tensor. Has the same type and format as input "var".
*@li accum: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdagrad. \n
*
*@par Restrictions:
*Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyAdagrad instead.
*/
REG_OP(SparseApplyAdagradD)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradD)

/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme . \n
*@par Inputs:
*Six inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li lr: An NCHW, NHWC, or ND Tensor of type float32.
*@li epsilon: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32 . \n
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "False", the accumulator "accum" is not updated . \n
*@par Outputs:
*var: A Tensor. Has the same type and format as input "var" . \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator SparseApplyAdagradV2.
*/
REG_OP(SparseApplyAdagradV2)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(epsilon, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradV2)

/**
*@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme . \n
*@par Inputs:
*Four inputs, including:
*@li var: An NCHW, NHWC, or ND Tensor of type float32.
*@li accum: An NCHW, NHWC, or ND Tensor of type float32.
*@li grad: An NCHW, NHWC, or ND Tensor of type float32.
*@li indices: An NCHW, NHWC, or ND Tensor of type int32 . \n
*@par Attributes:
*@li lr: A required float attribute. The learning rate used for the computation.
*@li epsilon: A required float attribute. A small value added for numerical stability.
*@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
*@li update_slots: An optional bool. Defaults to "True". If "False", the accumulator "accum" is not updated . \n
*@par Outputs:
*@li var: A Tensor. Has the same type and format as input "var".
*@li accum: A Tensor. Has the same type and format as input "accum" . \n
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator SparseApplyAdagradV2. \n
*
*@par Restrictions:
*Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyAdagradV2 instead.
*/
REG_OP(SparseApplyAdagradV2D)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .ATTR(update_slots, Bool, true)
    .OP_END_FACTORY_REG(SparseApplyAdagradV2D)

/**
*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
* want to use Nesterov momentum.
* Computing process:
* accum = accum * momentum + grad
* var -= lr * accum
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*@li momentum: Momentum. Must be a scalar.
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", the tensor passed to compute grad will be
* var - lr * momentum * accum, so in the end, the var you get is actually
* var - lr * momentum * accum.
*
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyMomentum.
*
*/
REG_OP(ApplyMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentum)
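
// Illustrative reference only: a minimal host-side sketch of the momentum update
// documented above. The Nesterov branch is an assumption modeled on the behavior
// described for TensorFlow's ApplyMomentum; the helper name and signature are
// hypothetical and not part of this header's API.
inline void ApplyMomentumSketch(float *var, float *accum, float lr,
                                const float *grad, float momentum, int n,
                                bool use_nesterov = false) {
  for (int i = 0; i < n; ++i) {
    accum[i] = accum[i] * momentum + grad[i];               // accum = accum * momentum + grad
    if (use_nesterov) {
      var[i] -= grad[i] * lr + accum[i] * momentum * lr;    // Nesterov-style correction
    } else {
      var[i] -= lr * accum[i];                              // var -= lr * accum
    }
  }
}
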
/**
*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
* want to use Nesterov momentum.
* Computing process:
* accum = accum * momentum + grad
* var -= lr * accum
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*@li momentum: Momentum. Must be a scalar.
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", the tensor passed to compute grad will be
* var - lr * momentum * accum, so in the end, the var you get is actually
* var - lr * momentum * accum.
*
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li accum: A mutable tensor. Has the same type as input "accum".
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyMomentum.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyMomentum instead.
*/
REG_OP(ApplyMomentumD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .ATTR(use_nesterov, Bool, false)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyMomentumD)

/**
*@brief Updates '*var' according to the momentum scheme.
* accum = accum * momentum - grad * lr
* if use_nesterov is True:
* var += accum * momentum - grad * lr
* else:
* var += accum
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*@li momentum: A scalar. Has the same type as "var".
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", var will be updated by using Nesterov momentum.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
*
*/
REG_OP(ApplyKerasMomentum)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .ATTR(use_nesterov, Bool, false)
    .OP_END_FACTORY_REG(ApplyKerasMomentum)
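
// Illustrative reference only: a minimal host-side sketch of the Keras-style
// momentum update documented above. The helper name and signature are
// hypothetical and not part of this header's API.
inline void ApplyKerasMomentumSketch(float *var, float *accum, float lr,
                                     const float *grad, float momentum, int n,
                                     bool use_nesterov = false) {
  for (int i = 0; i < n; ++i) {
    accum[i] = accum[i] * momentum - grad[i] * lr;          // accum = accum * momentum - grad * lr
    if (use_nesterov) {
      var[i] += accum[i] * momentum - grad[i] * lr;         // Nesterov variant
    } else {
      var[i] += accum[i];                                   // var += accum
    }
  }
}
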
/**
*@brief Updates '*var' according to the momentum scheme.
* accum = accum * momentum - grad * lr
* if use_nesterov is True:
* var += accum * momentum - grad * lr
* else:
* var += accum
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li accum: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*@li momentum: A scalar. Has the same type as "var". Should be from a
* Variable().
*
*@par Attributes:
*@li use_nesterov: An optional bool. Defaults to "False".
* If "True", var will be updated by using Nesterov momentum.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li accum: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
*
*@par Restrictions:
*Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyKerasMomentum instead.
*/
REG_OP(ApplyKerasMomentumD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .ATTR(use_nesterov, Bool, false)
    .OP_END_FACTORY_REG(ApplyKerasMomentumD)

/**
*@brief Updates '*var' according to the Adam algorithm.
* lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
* vhat_t := max{vhat_{t-1}, v_t}
* variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li v: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li vhat: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*
*@par Attributes:
*@li beta1: A required float scalar.
*@li beta2: A required float scalar.
*@li epsilon: A required float scalar.
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "var".
*@li v: A mutable tensor. Has the same type as input "var".
*@li vhat: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
*
*@par Restrictions:
*Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdamWithAmsgrad instead.
*
*/
REG_OP(ApplyAdamWithAmsgradD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(vhat, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(beta2_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .OUTPUT(v, TensorType::NumberType())
    .OUTPUT(vhat, TensorType::NumberType())
    .REQUIRED_ATTR(beta1, Float)
    .REQUIRED_ATTR(beta2, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdamWithAmsgradD)

/**
*@brief Updates '*var' according to the Adam algorithm.
* lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
* vhat_t := max{vhat_{t-1}, v_t}
* variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
*
*@par Inputs:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li v: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li vhat: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li beta1: A scalar. Has the same type as "var".
*@li beta2: A scalar. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*
*@par Attributes:
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
*
*/
REG_OP(ApplyAdamWithAmsgrad)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(v, TensorType::NumberType())
    .INPUT(vhat, TensorType::NumberType())
    .INPUT(beta1_power, TensorType::NumberType())
    .INPUT(beta2_power, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(beta1, TensorType::NumberType())
    .INPUT(beta2, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdamWithAmsgrad)
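
// Illustrative reference only: a minimal host-side sketch of the Adam/AMSGrad
// update documented above, applied element-wise over a flat buffer of length n.
// The helper name and signature are hypothetical, and <cmath> is assumed to be
// available in the translation unit.
inline void ApplyAdamWithAmsgradSketch(float *var, float *m, float *v, float *vhat,
                                       float beta1_power, float beta2_power,
                                       float lr, float beta1, float beta2,
                                       float epsilon, const float *grad, int n) {
  const float lr_t = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (int i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];              // m_t
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];    // v_t
    vhat[i] = std::fmax(vhat[i], v[i]);                          // vhat_t = max(vhat_{t-1}, v_t)
    var[i] -= lr_t * m[i] / (std::sqrt(vhat[i]) + epsilon);      // var update
  }
}
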
/**
*@brief Updates '*var' according to the Adam algorithm.
* lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)
* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g
* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g
* vhat_t := max{vhat_{t-1}, v_t}
* variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon)
*
*@par Inputs:
*Eleven inputs, including:
*@li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li v: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li vhat: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a
* Variable().
*@li lr: A tensor for the learning rate. Has the same type as "var". Should be
* from a Variable().
*@li beta1: A mutable tensor. Has the same type as "var". Should be
* from a Variable().
*@li beta2: A mutable tensor. Has the same type as "var". Should be
* from a Variable().
*@li epsilon: A mutable tensor. Has the same type as "var". Should be
* from a Variable().
*@li grad: A tensor for the gradient. Has the same type as "var". Should be
* from a Variable().
*
*@par Attributes:
*One attribute, including:
*@li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
*
*@par Outputs:
*Four outputs, including:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "var".
*@li v: A mutable tensor. Has the same type as input "var".
*@li vhat: A mutable tensor. Has the same type as input "var".
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ResourceApplyAdamWithAmsgrad.
*
*/
REG_OP(ApplyAdamWithAmsgradV2)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(m, TensorType({DT_FLOAT}))
    .INPUT(v, TensorType({DT_FLOAT}))
    .INPUT(vhat, TensorType({DT_FLOAT}))
    .INPUT(beta1_power, TensorType({DT_FLOAT}))
    .INPUT(beta2_power, TensorType({DT_FLOAT}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(beta1, TensorType({DT_FLOAT}))
    .INPUT(beta2, TensorType({DT_FLOAT}))
    .INPUT(epsilon, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(m, TensorType({DT_FLOAT}))
    .OUTPUT(v, TensorType({DT_FLOAT}))
    .OUTPUT(vhat, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAdamWithAmsgradV2)

/**
*@brief Updates "var" according to the PowerSign update.
* t-1 means the previous period.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad
* var <- var - lr * update
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li logbase: A scalar. Has the same type as "var".
*@li sign_decay: A scalar. Has the same type as "var".
*@li beta: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyPowerSign.
*
*/
REG_OP(ApplyPowerSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(logbase, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyPowerSign)
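
// Illustrative reference only: a minimal host-side sketch of the PowerSign update
// documented above (the "beta" input plays the role of beta1 in the formula).
// The helper name and signature are hypothetical, and <cmath> is assumed to be
// available in the translation unit.
inline void ApplyPowerSignSketch(float *var, float *m, float lr, float logbase,
                                 float sign_decay, float beta, const float *grad, int n) {
  for (int i = 0; i < n; ++i) {
    m[i] = beta * m[i] + (1.0f - beta) * grad[i];                               // m_t
    const float sg = (grad[i] > 0.0f) ? 1.0f : ((grad[i] < 0.0f) ? -1.0f : 0.0f);  // sign(grad)
    const float sm = (m[i] > 0.0f) ? 1.0f : ((m[i] < 0.0f) ? -1.0f : 0.0f);        // sign(m_t)
    const float update = std::exp(logbase * sign_decay * sg * sm) * grad[i];
    var[i] -= lr * update;                                                      // var -= lr * update
  }
}
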
/**
*@brief Updates "var" according to the PowerSign update.
* t-1 means the previous period.
* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad
* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad
* var <- var - lr * update
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li m: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li logbase: A scalar. Has the same type as "var".
*@li sign_decay: A scalar. Has the same type as "var".
*@li beta: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
*@li var: A mutable tensor. Has the same type as input "var".
*@li m: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyPowerSign.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyPowerSign instead.
*/
REG_OP(ApplyPowerSignD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(logbase, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyPowerSignD)

/**
*@brief Updates "var" according to the FOBOS algorithm with a fixed learning rate.
* prox_v = var - alpha * delta
* var = sign(prox_v)/(1+alpha * l2) * max{|prox_v|-alpha * l1,0}
*
*@attention Constraints:
* The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li alpha: A scalar. Has the same type as "var".
*@li l1: A scalar. Has the same type as "var".
*@li l2: A scalar. Has the same type as "var".
*@li delta: A tensor. Has the same type as "var". The change.
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" tensor is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyProximalGradientDescent.
*
*/
REG_OP(ApplyProximalGradientDescent)
    .INPUT(var, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(l1, TensorType::NumberType())
    .INPUT(l2, TensorType::NumberType())
    .INPUT(delta, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
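
// Illustrative reference only: a minimal host-side sketch of the fixed-rate FOBOS
// (proximal gradient descent) update documented above. The helper name and
// signature are hypothetical, and <cmath> is assumed to be available.
inline void ApplyProximalGradientDescentSketch(float *var, float alpha, float l1,
                                               float l2, const float *delta, int n) {
  for (int i = 0; i < n; ++i) {
    const float prox_v = var[i] - alpha * delta[i];                              // gradient step
    const float sign = (prox_v > 0.0f) ? 1.0f : ((prox_v < 0.0f) ? -1.0f : 0.0f);
    const float shrink = std::fmax(std::fabs(prox_v) - alpha * l1, 0.0f);        // L1 soft-threshold
    var[i] = sign / (1.0f + alpha * l2) * shrink;                                // L2 shrinkage
  }
}
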
/**
*@brief Updates "var" according to the AddSign update . \n
*@par Inputs:
*Seven inputs, including:
* @li var: A mutable Tensor of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li m: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li alpha: A Tensor of the same type as "var". Must be a scalar.
* @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
* @li beta: A Tensor of the same type as "var". Must be a scalar.
* @li grad: A Tensor of the same type as "var", for the gradient.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
*@par Outputs:
*var: A mutable Tensor. Has the same type as "var" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ApplyAddSign.
*/
REG_OP(ApplyAddSign)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAddSign)
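
// Illustrative reference only: a minimal host-side sketch of the AddSign rule as
// usually stated (m_t = beta*m + (1-beta)*grad; update = (alpha +
// sign_decay*sign(grad)*sign(m_t))*grad; var -= lr*update). The exact formula is
// not spelled out above, so treat this as an assumption; the helper name and
// signature are hypothetical and not part of this header's API.
inline void ApplyAddSignSketch(float *var, float *m, float lr, float alpha,
                               float sign_decay, float beta, const float *grad, int n) {
  for (int i = 0; i < n; ++i) {
    m[i] = beta * m[i] + (1.0f - beta) * grad[i];                                  // m_t
    const float sg = (grad[i] > 0.0f) ? 1.0f : ((grad[i] < 0.0f) ? -1.0f : 0.0f);  // sign(grad)
    const float sm = (m[i] > 0.0f) ? 1.0f : ((m[i] < 0.0f) ? -1.0f : 0.0f);        // sign(m_t)
    var[i] -= lr * (alpha + sign_decay * sg * sm) * grad[i];                       // var -= lr * update
  }
}
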
/**
*@brief Updates "var" according to the AddSign update . \n
*@par Inputs:
*Seven inputs, including:
* @li var: A mutable Tensor of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li m: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li alpha: A Tensor of the same type as "var". Must be a scalar.
* @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
* @li beta: A Tensor of the same type as "var". Must be a scalar.
* @li grad: A Tensor of the same type as "var", for the gradient.
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "m" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
*@par Outputs:
*@li var: A mutable Tensor. Has the same type as "var".
*@li m: A mutable Tensor. Has the same type as "m" . \n
*@par Third-party framework compatibility
* Compatible with the TensorFlow operator ApplyAddSign.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAddSign instead.
*/
REG_OP(ApplyAddSignD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(m, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(alpha, TensorType::NumberType())
    .INPUT(sign_decay, TensorType::NumberType())
    .INPUT(beta, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(m, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyAddSignD)

/**
*@brief Updates "var" according to the centered RMSProp algorithm.
* The centered RMSProp algorithm uses an estimate of the centered second moment
* (i.e., the variance) for normalization, as opposed to regular RMSProp, which
* uses the (uncentered) second moment. This often helps with training, but is
* slightly more expensive in terms of computation and memory.
*
* t-1 means the previous period.
* mg <- rho * mg{t-1} + (1-rho) * grad
* ms <- rho * ms{t-1} + (1-rho) * grad * grad
* mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
* var <- var - mom
*
*@attention Constraints:
*@li In the dense implementation of this algorithm, mg, ms, and mom will
* update even if the grad is zero, but in the sparse implementation, mg, ms,
* and mom will not update in iterations during which the grad is zero.
*@li The input tensors must have the same shape.
*
*@par Inputs:
*@li var: A mutable tensor. Should be from a Variable().
*@li mg: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li ms: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li mom: A mutable tensor. Has the same type as "var".
* Should be from a Variable().
*@li lr: A scalar. Has the same type as "var".
*@li rho: A scalar. Has the same type as "var".
*@li momentum: A tensor. Has the same type as "var".
*@li epsilon: A scalar. Has the same type as "var".
*@li grad: A tensor for the gradient. Has the same type as "var".
*
*@par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
* by a lock; otherwise the behavior is undefined, but may exhibit less
* contention.
*
*@par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
*@par Third-party framework compatibility
*Compatible with the TensorFlow operator ApplyCenteredRMSProp.
*
*/
REG_OP(ApplyCenteredRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(mg, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
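
// Illustrative reference only: a minimal host-side sketch of the centered RMSProp
// update documented above, applied element-wise over a flat buffer of length n.
// The helper name and signature are hypothetical, and <cmath> is assumed.
inline void ApplyCenteredRMSPropSketch(float *var, float *mg, float *ms, float *mom,
                                       float lr, float rho, float momentum,
                                       float epsilon, const float *grad, int n) {
  for (int i = 0; i < n; ++i) {
    mg[i] = rho * mg[i] + (1.0f - rho) * grad[i];                           // centered first moment
    ms[i] = rho * ms[i] + (1.0f - rho) * grad[i] * grad[i];                 // second moment
    mom[i] = momentum * mom[i] +
             lr * grad[i] / std::sqrt(ms[i] - mg[i] * mg[i] + epsilon);     // momentum buffer
    var[i] -= mom[i];                                                       // var update
  }
}
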
  903. /**
  904. *@brief Updates "var" according to the centered RMSProp algorithm.
  905. * The centered RMSProp algorithm uses an estimate of the centered second moment
  906. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  907. * uses the (uncentered) second moment. This often helps with training, but is
  908. * slightly more expensive in terms of computation and memory.
  909. *
  910. * t-1 mean previous period.
  911. * mg <- rho * mg{t-1} + (1-rho) * grad
  912. * ms <- rho * ms{t-1} + (1-rho) * grad * grad
  913. * mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
  914. * var <- var - mom
  915. *
  916. *@attention Constraints:
  917. *@li in dense implementation of this algorithm, mg, ms, and mom will
  918. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  919. * and mom will not update in iterations during which the grad is zero.
  920. *@li the input tensors must have the same shape.
  921. *
  922. *@par Inputs:
  923. *@li var: A mutable tensor. Should be from a Variable().
  924. *@li mg: A mutable tensor. Has the same type as "var".
  925. * Should be from a Variable().
  926. *@li ms: A mutable tensor. Has the same type as "var".
  927. * Should be from a Variable().
  928. *@li mom: A mutable tensor. Has the same type as "var".
  929. * Should be from a Variable().
  930. *@li lr: A scalar. Has the same type as "var".
  931. *@li rho: A scalar. Has the same type as "var".
  932. *@li momentum: A tensor. Has the same type as "var".
  933. *@li epsilon: A scalar. Has the same type as "var".
  934. *@li grad: A tensor for the gradient. Has the same type as "var".
  935. *
  936. *@par Attributes:
  937. * use_locking: An optional bool. Defaults to "False".
  938. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  939. * by a lock; otherwise the behavior is undefined, but may exhibit less
  940. * contention.
  941. *
  942. *@par Outputs:
  943. *@li var: A mutable Tensor. Has the same type as "var".
  944. *@li mg: A mutable Tensor. Has the same type as "mg".
  945. *@li ms: A mutable Tensor. Has the same type as "ms".
  946. *@li mom: A mutable Tensor. Has the same type as "mom" . \n
  947. *@par Third-party framework compatibility
  948. *Compatible with the TensorFlow operator ApplyCenteredRMSPropD.
  949. *
  950. * @par Restrictions:
  951. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyCenteredRMSProp instead.
  952. */
  953. REG_OP(ApplyCenteredRMSPropD)
  954. .INPUT(var, TensorType::NumberType())
  955. .INPUT(mg, TensorType::NumberType())
  956. .INPUT(ms, TensorType::NumberType())
  957. .INPUT(mom, TensorType::NumberType())
  958. .INPUT(lr, TensorType::NumberType())
  959. .INPUT(rho, TensorType::NumberType())
  960. .INPUT(momentum, TensorType::NumberType())
  961. .INPUT(epsilon, TensorType::NumberType())
  962. .INPUT(grad, TensorType::NumberType())
  963. .OUTPUT(var, TensorType::NumberType())
  964. .OUTPUT(mg, TensorType::NumberType())
  965. .OUTPUT(ms, TensorType::NumberType())
  966. .OUTPUT(mom, TensorType::NumberType())
  967. .ATTR(use_locking, Bool, false)
  968. .OP_END_FACTORY_REG(ApplyCenteredRMSPropD)
  969. /**
  970. *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.
  971. * var -= delta * alpha
  972. *
  973. *@attention Constraints:
  974. * the input tensors must have the same shape.
  975. *
  976. *@par Inputs:
  977. *@li var: A mutable tensor. Should be from a Variable().
  978. *@li alpha: A scalar. Has the same type as "var".
  979. *@li delta: A tensor for the change. Has the same type as "var".
  980. *
  981. *@par Attributes:
  982. * use_locking: An optional bool. Defaults to "False".
  983. * If "True", updating of the "var" tensors is protected
  984. * by a lock; otherwise the behavior is undefined, but may exhibit less
  985. * contention.
  986. *
  987. *@par Outputs:
  988. * var: A mutable tensor. Has the same type as input "var".
  989. *
  990. *@par Third-party framework compatibility
  991. *Compatible with the TensorFlow operator ApplyGradientDescent.
  992. *
  993. */
  994. REG_OP(ApplyGradientDescent)
  995. .INPUT(var, TensorType::NumberType())
  996. .INPUT(alpha, TensorType::NumberType())
  997. .INPUT(delta, TensorType::NumberType())
  998. .OUTPUT(var, TensorType::NumberType())
  999. .ATTR(use_locking, Bool, false)
  1000. .OP_END_FACTORY_REG(ApplyGradientDescent)
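/*
 * Illustrative sketch (editor's addition): the ApplyGradientDescent update
 * "var -= alpha * delta" written as a minimal element-wise C++ loop, assuming
 * contiguous float buffers of equal length; the real kernel handles arbitrary
 * shapes and dtypes.
 *
 *   #include <cstddef>
 *   inline void apply_gradient_descent(float *var, std::size_t n,
 *                                      float alpha, const float *delta) {
 *     for (std::size_t i = 0; i < n; ++i) {
 *       var[i] -= alpha * delta[i];   // var -= alpha * delta
 *     }
 *   }
 */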
  1001. /**
  1002. *@brief Updates "var" according to the adagrad scheme.
  1003. * accum += grad * grad
  1004. * var -= lr * grad * (1 / sqrt(accum))
  1005. *
  1006. *@attention Constraints:
  1007. * the input tensors must have the same shape.
  1008. *
  1009. *@par Inputs:
  1010. *@li var: A mutable tensor. Should be from a Variable().
  1011. *@li accum: A mutable tensor. Has the same type as "var".
  1012. * Should be from a Variable().
  1013. *@li lr: A scalar. Has the same type as "var".
  1014. *@li grad: A tensor for the gradient. Has the same type as "var".
  1015. *
  1016. *@par Attributes:
1017. *@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
  1018. *@li use_locking: An optional bool. Defaults to "False".
  1019. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  1020. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1021. * contention.
  1022. *
  1023. *@par Outputs:
  1024. * var: A mutable tensor. Has the same type as input "var".
  1025. *
  1026. *@par Third-party framework compatibility
  1027. *Compatible with the TensorFlow operator ApplyAdagrad.
  1028. *
  1029. */
  1030. REG_OP(ApplyAdagrad)
  1031. .INPUT(var, TensorType::NumberType())
  1032. .INPUT(accum, TensorType::NumberType())
  1033. .INPUT(lr, TensorType::NumberType())
  1034. .INPUT(grad, TensorType::NumberType())
  1035. .OUTPUT(var, TensorType::NumberType())
  1036. .ATTR(update_slots, Bool, true)
  1037. .ATTR(use_locking, Bool, false)
  1038. .OP_END_FACTORY_REG(ApplyAdagrad)
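/*
 * Illustrative sketch (editor's addition): a scalar C++ version of the Adagrad
 * rule documented above; "update_slots == false" is modelled by skipping the
 * accumulator update. Assumes float data; this is not the device kernel.
 *
 *   #include <cmath>
 *   inline void apply_adagrad_step(float &var, float &accum, float lr, float grad,
 *                                  bool update_slots = true) {
 *     if (update_slots) {
 *       accum += grad * grad;                 // accum += grad * grad
 *     }
 *     var -= lr * grad / std::sqrt(accum);    // var -= lr * grad * (1 / sqrt(accum))
 *   }
 */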
  1039. /**
  1040. *@brief Updates "var" according to the adagrad scheme.
  1041. * accum += grad * grad
  1042. * var -= lr * grad * (1 / sqrt(accum))
  1043. *
  1044. *@attention Constraints:
  1045. * the input tensors must have the same shape.
  1046. *
  1047. *@par Inputs:
  1048. *@li var: A mutable tensor. Should be from a Variable().
  1049. *@li accum: A mutable tensor. Has the same type as "var".
  1050. * Should be from a Variable().
  1051. *@li lr: A scalar. Has the same type as "var".
  1052. *@li grad: A tensor for the gradient. Has the same type as "var".
  1053. *
  1054. *@par Attributes:
1055. *@li update_slots: An optional bool. Defaults to "True". If "True", "accum" will be updated.
  1056. *@li use_locking: An optional bool. Defaults to "False".
  1057. * If "True", updating of the "var", "ms", and "mom" tensors is protected
  1058. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1059. * contention.
  1060. *
  1061. *@par Outputs:
  1062. *@li var: A mutable tensor. Has the same type as input "var".
  1063. *@li accum: A mutable tensor. Has the same type as input "var".
  1064. *
  1065. *@par Third-party framework compatibility
  1066. *Compatible with the TensorFlow operator ApplyAdagrad.
  1067. *
  1068. * @par Restrictions:
  1069. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdagrad instead.
  1070. */
  1071. REG_OP(ApplyAdagradD)
  1072. .INPUT(var, TensorType::NumberType())
  1073. .INPUT(accum, TensorType::NumberType())
  1074. .INPUT(lr, TensorType::NumberType())
  1075. .INPUT(grad, TensorType::NumberType())
  1076. .OUTPUT(var, TensorType::NumberType())
  1077. .OUTPUT(accum, TensorType::NumberType())
  1078. .ATTR(update_slots, Bool, true)
  1079. .ATTR(use_locking, Bool, false)
  1080. .OP_END_FACTORY_REG(ApplyAdagradD)
  1081. /**
  1082. * @brief Updates "var" according to the adagradv2 scheme.
  1083. * accum += grad * grad
1084. * var -= lr * grad / (sqrt(accum) + epsilon)
  1085. *
  1086. * @par Inputs:
  1087. * @li var: A mutable tensor. Must be one of the data types defined in
  1088. * TensorType::NumberType(). Should be from a Variable().
  1089. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  1090. * Variable().
  1091. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  1092. * from a Variable().
  1093. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  1094. * from a Variable().
  1095. * @li epsilon: A scalar. Has the same type as "var".
  1096. *
  1097. * @par Attributes:
  1098. * @li update_slots: An optional bool. Defaults to "True".
  1099. * If "True", "accum" will be updated
  1100. * @li use_locking: An optional bool. Defaults to "False".
  1101. * If "True", updating of the "var" tensor is protected by a lock;
  1102. * otherwise the behavior is undefined, but may exhibit less contention.
  1103. *
  1104. * @par Outputs:
  1105. * var: A mutable tensor. Has the same type as input "var".
  1106. *
  1107. * @attention Constraints:
  1108. * The input tensors must have the same shape.
  1109. *
  1110. * @par Third-party framework compatibility
1111. * Compatible with the TensorFlow operator ApplyAdagradV2.
  1112. *
  1113. */
  1114. REG_OP(ApplyAdagradV2)
  1115. .INPUT(var, TensorType::NumberType())
  1116. .INPUT(accum, TensorType::NumberType())
  1117. .INPUT(lr, TensorType::NumberType())
  1118. .INPUT(epsilon, TensorType::NumberType())
  1119. .INPUT(grad, TensorType::NumberType())
  1120. .OUTPUT(var, TensorType::NumberType())
  1121. .ATTR(update_slots, Bool, true)
  1122. .ATTR(use_locking, Bool, false)
  1123. .OP_END_FACTORY_REG(ApplyAdagradV2)
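/*
 * Illustrative sketch (editor's addition): the AdagradV2 variant differs from
 * ApplyAdagrad only by the epsilon term added to the denominator. Scalar C++
 * form, assuming float data.
 *
 *   #include <cmath>
 *   inline void apply_adagrad_v2_step(float &var, float &accum, float lr,
 *                                     float epsilon, float grad,
 *                                     bool update_slots = true) {
 *     if (update_slots) accum += grad * grad;            // accum += grad * grad
 *     var -= lr * grad / (std::sqrt(accum) + epsilon);   // var -= lr * grad / (sqrt(accum) + epsilon)
 *   }
 */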
  1124. /**
  1125. * @brief Updates "var" according to the adagradv2 scheme.
  1126. * accum += grad * grad
1127. * var -= lr * grad / (sqrt(accum) + epsilon)
  1128. *
  1129. * @par Inputs:
  1130. * @li var: A mutable tensor. Must be one of the data types defined in
  1131. * TensorType::NumberType(). Should be from a Variable().
  1132. * @li accum: A mutable tensor. Has the same type as "var". Should be from a
  1133. * Variable().
  1134. * @li lr: A tensor for the learning rate. Has the same type as "var". Should be
  1135. * from a Variable().
  1136. * @li grad: A tensor for the gradient. Has the same type as "var". Should be
  1137. * from a Variable().
  1138. *
  1139. * @par Attributes:
1140. * @li epsilon: A required float.
  1141. * @li update_slots: An optional bool. Defaults to "True".
  1142. * If "True", "accum" will be updated
  1143. * @li use_locking: An optional bool. Defaults to "False".
  1144. * If "True", updating of the "var" tensor is protected by a lock;
  1145. * otherwise the behavior is undefined, but may exhibit less contention.
  1146. *
  1147. * @par Outputs:
  1148. * var: A mutable tensor. Has the same type as input "var".
  1149. *
  1150. * @attention Constraints:
  1151. * The input tensors must have the same shape.
  1152. *
  1153. * @par Third-party framework compatibility
1154. * Compatible with the TensorFlow operator ApplyAdagradV2.
  1155. *
  1156. *@par Restrictions:
  1157. *Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdagradV2 instead.
  1158. */
  1159. REG_OP(ApplyAdagradV2D)
  1160. .INPUT(var, TensorType::NumberType())
  1161. .INPUT(accum, TensorType::NumberType())
  1162. .INPUT(lr, TensorType::NumberType())
  1163. .INPUT(grad, TensorType::NumberType())
  1164. .OUTPUT(var, TensorType::NumberType())
  1165. .OUTPUT(accum, TensorType::NumberType())
  1166. .REQUIRED_ATTR(epsilon, Float)
  1167. .ATTR(update_slots, Bool, true)
  1168. .ATTR(use_locking, Bool, false)
  1169. .OP_END_FACTORY_REG(ApplyAdagradV2D)
  1170. /**
  1171. *@brief Updates "var" according to the proximal adagrad scheme . \n
  1172. *@par Inputs:
  1173. *Eight inputs, including:
  1174. * @li var: A mutable Tensor. Must be one of the following types:
  1175. * TensorType::NumberType(). Should be a Variable Tensor.
  1176. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1177. * type as "var". Should be a Variable Tensor.
  1178. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1179. * Should be a Variable Tensor.
  1180. * @li grad: A Tensor of the same type as "var", for the gradient.
  1181. * @li lr: A Tensor of the same type as "var".
  1182. * Scaling factor. Must be a scalar.
  1183. * @li l1: A Tensor of the same type as "var".
1184. * L1 regularization. Must be a scalar.
1185. * @li l2: A Tensor of the same type as "var".
1186. * L2 regularization. Must be a scalar.
  1187. * @li global_step: A Tensor of type int32 or int64.
  1188. * Training step number. Must be a scalar . \n
  1189. *@par Attributes:
  1190. *use_locking: An optional bool. Defaults to "False".
  1191. * If "True", updating of the var and accum tensors will be
  1192. * protected by a lock; otherwise the behavior is undefined,
  1193. * but may exhibit less contention . \n
  1194. *@par Outputs:
  1195. *var: A mutable Tensor. Has the same type as "var" . \n
  1196. *@par Third-party framework compatibility
  1197. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1198. */
  1199. REG_OP(ApplyAdagradDA)
  1200. .INPUT(var, TensorType::NumberType())
  1201. .INPUT(gradient_accumulator, TensorType::NumberType())
  1202. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1203. .INPUT(grad, TensorType::NumberType())
  1204. .INPUT(lr, TensorType::NumberType())
  1205. .INPUT(l1, TensorType::NumberType())
  1206. .INPUT(l2, TensorType::NumberType())
  1207. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1208. .OUTPUT(var, TensorType::NumberType())
  1209. .ATTR(use_locking, Bool, false)
  1210. .OP_END_FACTORY_REG(ApplyAdagradDA)
  1211. /**
  1212. *@brief Updates "var" according to the proximal adagrad scheme . \n
  1213. *@par Inputs:
  1214. *Eight inputs, including:
  1215. * @li var: A mutable Tensor. Must be one of the following types:
  1216. * TensorType::NumberType(). Should be a Variable Tensor.
  1217. * @li gradient_accumulator: A mutable Tensor. Must have the same
  1218. * type as "var". Should be a Variable Tensor.
  1219. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  1220. * Should be a Variable Tensor.
  1221. * @li grad: A Tensor of the same type as "var", for the gradient.
  1222. * @li lr: A Tensor of the same type as "var".
  1223. * Scaling factor. Must be a scalar.
  1224. * @li l1: A Tensor of the same type as "var".
1225. * L1 regularization. Must be a scalar.
1226. * @li l2: A Tensor of the same type as "var".
1227. * L2 regularization. Must be a scalar.
  1228. * @li global_step: A Tensor of type int32 or int64.
  1229. * Training step number. Must be a scalar . \n
  1230. *@par Attributes:
  1231. *use_locking: An optional bool. Defaults to "False".
  1232. * If "True", updating of the var and accum tensors will be
  1233. * protected by a lock; otherwise the behavior is undefined,
  1234. * but may exhibit less contention . \n
  1235. *@par Outputs:
1236. *@li var: A mutable Tensor. Has the same type as "var".
1237. *@li gradient_accumulator: A mutable Tensor. Has the same type as "var".
1238. *@li gradient_squared_accumulator: A mutable Tensor. Has the same type as "var" . \n
  1239. *@par Third-party framework compatibility
  1240. *Compatible with the TensorFlow operator ApplyAdagradDA.
  1241. *
  1242. * @par Restrictions:
  1243. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdagradDA instead.
  1244. */
  1245. REG_OP(ApplyAdagradDAD)
  1246. .INPUT(var, TensorType::NumberType())
  1247. .INPUT(gradient_accumulator, TensorType::NumberType())
  1248. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  1249. .INPUT(grad, TensorType::NumberType())
  1250. .INPUT(lr, TensorType::NumberType())
  1251. .INPUT(l1, TensorType::NumberType())
  1252. .INPUT(l2, TensorType::NumberType())
  1253. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  1254. .OUTPUT(var, TensorType::NumberType())
  1255. .OUTPUT(gradient_accumulator, TensorType::NumberType())
  1256. .OUTPUT(gradient_squared_accumulator, TensorType::NumberType())
  1257. .ATTR(use_locking, Bool, false)
  1258. .OP_END_FACTORY_REG(ApplyAdagradDAD)
  1259. /**
  1260. *@brief Returns the dimension index in the destination data format given the one in
  1261. * the source data format.
  1262. *
  1263. *@par Inputs:
  1264. * x: A tensor of type int32 or int64.
  1265. * A Tensor with each element as a dimension index in source data format.
  1266. * Must be in the range [-4, 4).
  1267. *
  1268. *@par Attributes:
1269. *@li src_format: An optional string. Defaults to "NHWC".
1270. * Source data format. Must be of length 4.
1271. *@li dst_format: An optional string. Defaults to "NCHW".
1272. * Destination data format. Must be of length 4.
  1273. *
  1274. *@par Outputs:
  1275. * y: A tensor. Has the same type as "x". Must be in the range [0, 4).
  1276. *
  1277. *@par Third-party framework compatibility
  1278. *Compatible with the TensorFlow operator DataFormatDimMap.
  1279. *
  1280. */
  1281. REG_OP(DataFormatDimMap)
  1282. .INPUT(x, TensorType::IndexNumberType())
  1283. .ATTR(src_format, String, "NHWC")
  1284. .ATTR(dst_format, String, "NCHW")
  1285. .OUTPUT(y, TensorType::IndexNumberType())
  1286. .OP_END_FACTORY_REG(DataFormatDimMap)
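/*
 * Illustrative sketch (editor's addition): how a single dimension index is
 * remapped between two 4-character layout strings, e.g. index 1 ("H" in NHWC)
 * maps to index 2 ("H" in NCHW); negative indices wrap as documented
 * (range [-4, 4)). Plain C++ helper for intuition only.
 *
 *   #include <string>
 *   inline int data_format_dim_map(int x, const std::string &src = "NHWC",
 *                                  const std::string &dst = "NCHW") {
 *     int idx = x < 0 ? x + 4 : x;               // wrap negative indices
 *     char axis = src[idx];                      // axis letter in the source format
 *     return static_cast<int>(dst.find(axis));   // its position in the destination format
 *   }
 */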
  1287. /**
  1288. * @brief Implements stochastic gradient descent (optionally with momentum).
  1289. * Nesterov momentum is based on the formula from
  1290. * On the importance of initialization and momentum in deep learning.
  1291. * @par Inputs:
  1292. * @li parameters: A mutable tensor of type float16 or float32.
  1293. * Specifies the iterable of parameters to optimize or dicts defining parameter
  1294. * groups.
  1295. * @li gradient: A tensor of type float16 or float32.
  1296. * Specifies the gradient of training step.
  1297. * @li learning_rate: A tensor of type float16 or float32.
1298. * Specifies the learning rate of the training step.
  1299. * @li accum: A tensor of type float16 or float32.
  1300. * Specifies the velocity of training step.
  1301. * @li momentum: A tensor of type float16 or float32.
  1302. * Specifies the momentum factor.
  1303. * @li stat: A tensor of type float16 or float32.
1304. * Specifies the status indicating whether this is the first training step . \n
  1305. * @par Attributes:
  1306. * @li dampening: An optional float, specifying the dampening for momentum.
  1307. * Defaults to "0.0".
  1308. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  1309. * "0.0".
  1310. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  1311. * momentum. Defaults to "False" . \n
  1312. * @par Outputs:
  1313. * parameters: A mutable tensor same as input "parameters" . \n
  1314. * @see ApplyMomentum()
  1315. * @par Third-party framework compatibility
  1316. * @li Compatible with the PyTorch operator SGD.
  1317. */
  1318. REG_OP(SGD)
  1319. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1320. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  1321. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  1322. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  1323. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  1324. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  1325. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  1326. .ATTR(dampening, Float, 0.0)
  1327. .ATTR(weight_decay, Float, 0.0)
  1328. .ATTR(nesterov, Bool, false)
  1329. .OP_END_FACTORY_REG(SGD)
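/*
 * Illustrative sketch (editor's addition): a scalar C++ reading of the SGD
 * update as PyTorch documents it (weight decay, dampening, optional Nesterov
 * momentum); "stat" is taken to mark the first step, where the accumulator is
 * seeded with the gradient instead of being damped. This is an interpretation
 * for intuition, not the kernel's exact formula.
 *
 *   inline void sgd_step(float &param, float &accum, float &stat, float grad,
 *                        float lr, float momentum, float dampening,
 *                        float weight_decay, bool nesterov) {
 *     float g = grad + weight_decay * param;      // fold the L2 penalty into the gradient
 *     if (stat != 0.0f) {                         // first step: seed the velocity
 *       accum = g;
 *       stat = 0.0f;
 *     } else {
 *       accum = momentum * accum + (1.0f - dampening) * g;
 *     }
 *     float step = nesterov ? g + momentum * accum : accum;
 *     param -= lr * step;
 *   }
 */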
  1330. /**
  1331. * @brief Updates "var" according to the RMSProp algorithm.
  1332. * mean_square = decay * mean_square + (1-decay) * gradient ** 2
  1333. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
  1334. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad
  1335. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
  1336. * var <- var - mom
  1337. *
  1338. * @par Inputs:
  1339. * @li var: A mutable tensor. Must be one of the data types defined in
  1340. * TensorType::NumberType(). Should be from a Variable().
  1341. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1342. * Variable().
  1343. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1344. * Variable().
  1345. * @li lr: A scalar. Must have the same type as "var".
  1346. * @li rho: A scalar. Must have the same type as "var".
  1347. * @li momentum: A scalar. Must have the same type as "var".
  1348. * @li epsilon: A scalar. Must have the same type as "var".
  1349. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1350. *
  1351. * @par Attributes:
  1352. * use_locking: An optional "bool". Defaults to "False". If "True", updating of
  1353. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
  1354. * behavior is undefined, but may exhibit less contention.
  1355. *
  1356. * @par Outputs:
  1357. * var: A mutable tensor. Has the same type as input "var".
  1358. *
  1359. * @attention Constraints:
  1360. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1361. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1362. * will not update in iterations during which "grad" is 0.
  1363. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1364. *
  1365. * @par Third-party framework compatibility
  1366. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1367. */
  1368. REG_OP(ApplyRMSProp)
  1369. .INPUT(var, TensorType::NumberType())
  1370. .INPUT(ms, TensorType::NumberType())
  1371. .INPUT(mom, TensorType::NumberType())
  1372. .INPUT(lr, TensorType::NumberType())
  1373. .INPUT(rho, TensorType::NumberType())
  1374. .INPUT(momentum, TensorType::NumberType())
  1375. .INPUT(epsilon, TensorType::NumberType())
  1376. .INPUT(grad, TensorType::NumberType())
  1377. .OUTPUT(var, TensorType::NumberType())
  1378. .ATTR(use_locking, Bool, false)
  1379. .OP_END_FACTORY_REG(ApplyRMSProp)
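/*
 * Illustrative sketch (editor's addition): scalar C++ form of the RMSProp
 * recurrence listed above, assuming float data.
 *
 *   #include <cmath>
 *   inline void apply_rms_prop_step(float &var, float &ms, float &mom,
 *                                   float lr, float rho, float momentum,
 *                                   float epsilon, float grad) {
 *     ms  = rho * ms + (1.0f - rho) * grad * grad;              // ms <- rho*ms + (1-rho)*grad^2
 *     mom = momentum * mom + lr * grad / std::sqrt(ms + epsilon);
 *     var -= mom;
 *   }
 */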
  1380. /**
  1381. * @brief Updates "var" according to the RMSProp algorithm, a const input will be
  1382. * considered as an attribute.
  1383. * mean_square = decay * mean_square + (1-decay) * gradient ** 2
  1384. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
  1385. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad
  1386. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
  1387. * var <- var - mom
  1388. *
  1389. * @par Inputs:
  1390. * @li var: A mutable tensor. Must be one of the data types defined in
  1391. * TensorType::NumberType(). Should be from a Variable().
  1392. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  1393. * Variable().
  1394. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  1395. * Variable().
  1396. * @li lr: A scalar. Must have the same type as "var".
  1397. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  1398. *
  1399. * @par Attributes:
  1400. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating
  1401. * of the "var", "ms", and "mom" tensors will be protected by a lock;
  1402. * otherwise the behavior is undefined, but may exhibit less contention.
1403. * @li rho: A required float.
1404. * @li momentum: A required float.
1405. * @li epsilon: A required float.
  1406. *
  1407. * @par Outputs:
  1408. * var: A mutable tensor. Must have the same type as input "var".
  1409. *
  1410. * @attention Constraints:
  1411. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will
  1412. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"
  1413. * will not update in iterations during which "grad" is 0.
  1414. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  1415. *
  1416. * @par Third-party framework compatibility
  1417. * @li Compatible with the TensorFlow operator ApplyRMSProp.
  1418. *
  1419. *@par Restrictions:
  1420. *Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyRMSProp instead.
  1421. */
  1422. REG_OP(ApplyRMSPropD)
  1423. .INPUT(var, TensorType::NumberType())
  1424. .INPUT(ms, TensorType::NumberType())
  1425. .INPUT(mom, TensorType::NumberType())
  1426. .INPUT(lr, TensorType::NumberType())
  1427. .INPUT(grad, TensorType::NumberType())
  1428. .OUTPUT(var, TensorType::NumberType())
  1429. .OUTPUT(ms, TensorType::NumberType())
  1430. .OUTPUT(mom, TensorType::NumberType())
  1431. .REQUIRED_ATTR(rho, Float)
  1432. .REQUIRED_ATTR(momentum, Float)
  1433. .REQUIRED_ATTR(epsilon, Float)
  1434. .ATTR(use_locking, Bool, false)
  1435. .OP_END_FACTORY_REG(ApplyRMSPropD)
  1436. /**
  1437. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate . \n
  1438. *@par Inputs:
  1439. *Six inputs, including:
  1440. * @li var: A mutable Tensor of type TensorType::NumberType().
  1441. * Should be from a Variable().
  1442. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1443. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1444. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1445. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1446. * @li grad: A Tensor of the same type as "var", for the gradient . \n
  1447. *@par Attributes:
1448. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention . \n
  1449. *@par Outputs:
  1450. *var: A mutable tensor. Must have the same type as input "var" . \n
  1451. *@par Third-party framework compatibility
  1452. *Compatible with the TensorFlow operator ApplyProximalAdagrad.
  1453. */
  1454. REG_OP(ApplyProximalAdagrad)
  1455. .INPUT(var, TensorType::NumberType())
  1456. .INPUT(accum, TensorType::NumberType())
  1457. .INPUT(lr, TensorType::NumberType())
  1458. .INPUT(l1, TensorType::NumberType())
  1459. .INPUT(l2, TensorType::NumberType())
  1460. .INPUT(grad, TensorType::NumberType())
  1461. .OUTPUT(var, TensorType::NumberType())
  1462. .ATTR(use_locking, Bool, false)
  1463. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
  1464. /**
  1465. *@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate . \n
  1466. *@par Inputs:
  1467. *Six inputs, including:
  1468. * @li var: A mutable Tensor of type TensorType::NumberType().
  1469. * Should be from a Variable().
  1470. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  1471. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1472. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1473. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1474. * @li grad: A Tensor of the same type as "var", for the gradient . \n
  1475. *@par Attributes:
1476. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention . \n
  1477. *@par Outputs:
  1478. * @li var: A mutable Tensor. Has the same type as "var".
  1479. * @li accum: A mutable Tensor. Has the same type as "var" . \n
  1480. *@par Third-party framework compatibility
  1481. *Compatible with the TensorFlow operator ApplyProximalAdagradD.
  1482. *
  1483. * @par Restrictions:
  1484. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyProximalAdagrad instead.
  1485. */
  1486. REG_OP(ApplyProximalAdagradD)
  1487. .INPUT(var, TensorType::NumberType())
  1488. .INPUT(accum, TensorType::NumberType())
  1489. .INPUT(lr, TensorType::NumberType())
  1490. .INPUT(l1, TensorType::NumberType())
  1491. .INPUT(l2, TensorType::NumberType())
  1492. .INPUT(grad, TensorType::NumberType())
  1493. .OUTPUT(var, TensorType::NumberType())
  1494. .OUTPUT(accum, TensorType::NumberType())
  1495. .ATTR(use_locking, Bool, false)
  1496. .OP_END_FACTORY_REG(ApplyProximalAdagradD)
  1497. /**
  1498. *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.
  1499. * Compared with op ApplyProximalAdagrad, an additional index tensor is input,
  1500. * Only the indices into the first dimensions of "var" and "accum" are updated . \n
  1501. *@par Inputs:
  1502. * Seven inputs, including:
  1503. * @li var: A mutable Tensor.
  1504. * TensorType::NumberType(). Should be a Variable Tensor.
  1505. * @li accum: A mutable Tensor of the same type as "var".
  1506. * Should be a Variable Tensor. Should be greater than or equal to zero.
  1507. * Accum and grad cannot be equal to zero at the same time.
  1508. * @li lr: A Tensor of the same type as "var".
  1509. * Scaling factor. Must be a scalar. Should be greater than zero.
  1510. * @li l1: A Tensor of the same type as "var".
1511. * L1 regularization. Must be a scalar. Should be greater than or equal to zero.
1512. * @li l2: A Tensor of the same type as "var".
1513. * L2 regularization. Must be a scalar. Should be greater than or equal to zero.
  1514. * @li grad: A Tensor. Has the same type as "var".
  1515. * The gradient.
  1516. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  1517. * TensorType::IndexNumberType(). Can contain duplicate values . \n
  1518. *@par Attributes:
  1519. *use_locking: An optional bool. Defaults to "False".
  1520. * If "True", updating of the var and accum tensors will be protected by a lock;
  1521. * If "False", the behavior is undefined, but may exhibit less contention.
  1522. *@par Outputs:
  1523. *var: A mutable Tensor. Has the same type as "var" . \n
  1524. *@par Third-party framework compatibility
  1525. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1526. */
  1527. REG_OP(SparseApplyProximalAdagrad)
  1528. .INPUT(var, TensorType::NumberType())
  1529. .INPUT(accum, TensorType::NumberType())
  1530. .INPUT(lr, TensorType::NumberType())
  1531. .INPUT(l1, TensorType::NumberType())
  1532. .INPUT(l2, TensorType::NumberType())
  1533. .INPUT(grad, TensorType::NumberType())
  1534. .INPUT(indices, TensorType::IndexNumberType())
  1535. .OUTPUT(var, TensorType::NumberType())
  1536. .ATTR(use_locking, Bool, false)
  1537. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  1538. /**
1539. *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.
  1540. * Compared with op ApplyProximalAdagrad, an additional index tensor is input,
  1541. * Only the indices into the first dimensions of "var" and "accum" are updated . \n
  1542. *@par Inputs:
  1543. * Seven inputs, including:
  1544. * @li var: A mutable Tensor.
  1545. * TensorType::NumberType(). Should be a Variable Tensor.
  1546. * @li accum: A mutable Tensor of the same type as "var".
  1547. * Should be a Variable Tensor. Should be greater than or equal to zero.
  1548. * Accum and grad cannot be equal to zero at the same time.
  1549. * @li lr: A Tensor of the same type as "var".
  1550. * Scaling factor. Must be a scalar. Should be greater than zero.
  1551. * @li l1: A Tensor of the same type as "var".
  1552. * L1 regulariation. Must be a scalar. Should be greater than or equal to zero.
  1553. * @li l2: A Tensor of the same type as "var".
  1554. * L2 regulariation. Must be a scalar. Should be greater than or equal to zero.
  1555. * @li grad: A Tensor. Has the same type as "var".
  1556. * The gradient.
  1557. * @li indices: A vector of indices into the first dimension of "var" and "accum".
  1558. * TensorType::IndexNumberType(). Can contain duplicate values . \n
  1559. *@par Attributes:
  1560. *use_locking: An optional bool. Defaults to "False".
  1561. * If "True", updating of the var and accum tensors will be protected by a lock;
  1562. * If "False", the behavior is undefined, but may exhibit less contention . \n
  1563. *@par Outputs:
  1564. *@li var: A mutable Tensor. Has the same type as "var".
  1565. *@li accum: A mutable Tensor. Has the same type as "var" . \n
  1566. *@par Third-party framework compatibility
  1567. *Compatible with the TensorFlow operator SparseApplyProximalAdagrad.
  1568. * @par Restrictions:
  1569. * Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyProximalAdagrad instead.
  1570. */
  1571. REG_OP(SparseApplyProximalAdagradD)
  1572. .INPUT(var, TensorType::NumberType())
  1573. .INPUT(accum, TensorType::NumberType())
  1574. .INPUT(lr, TensorType::NumberType())
  1575. .INPUT(l1, TensorType::NumberType())
  1576. .INPUT(l2, TensorType::NumberType())
  1577. .INPUT(grad, TensorType::NumberType())
  1578. .INPUT(indices, TensorType::IndexNumberType())
  1579. .OUTPUT(var, TensorType::NumberType())
  1580. .OUTPUT(accum, TensorType::NumberType())
  1581. .ATTR(use_locking, Bool, false)
  1582. .OP_END_FACTORY_REG(SparseApplyProximalAdagradD)
  1583. /**
  1584. *@brief Updates "var" according to the Ftrl-proximal scheme . \n
  1585. *@par Inputs:
  1586. *Eight inputs, including:
  1587. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1588. * Should be a Variable Tensor.
  1589. * @li accum: A mutable Tensor of the same type as "var".
  1590. * Should be a Variable Tensor.
  1591. * @li linear: A mutable Tensor of the same type as "var".
  1592. * Should be a Variable Tensor.
  1593. * @li grad: A Tensor of the same type as "var", for the gradient.
  1594. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1595. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1596. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1597. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
  1598. *@par Attributes:
  1599. *use_locking: An optional bool. Defaults to "False".
  1600. * If "True", updating of the "var" and "accum" tensors will be
  1601. * protected by a lock; otherwise the behavior is undefined,
  1602. * but may exhibit less contention . \n
  1603. *@par Outputs:
  1604. *var: A mutable Tensor. Has the same type as "var" . \n
  1605. *@par Third-party framework compatibility
  1606. *Compatible with the TensorFlow operator ApplyFtrl.
  1607. */
  1608. REG_OP(ApplyFtrl)
  1609. .INPUT(var, TensorType::NumberType())
  1610. .INPUT(accum, TensorType::NumberType())
  1611. .INPUT(linear, TensorType::NumberType())
  1612. .INPUT(grad, TensorType::NumberType())
  1613. .INPUT(lr, TensorType::NumberType())
  1614. .INPUT(l1, TensorType::NumberType())
  1615. .INPUT(l2, TensorType::NumberType())
  1616. .INPUT(lr_power, TensorType::NumberType())
  1617. .OUTPUT(var, TensorType::NumberType())
  1618. .ATTR(use_locking, Bool, false)
  1619. .OP_END_FACTORY_REG(ApplyFtrl)
  1620. /**
  1621. *@brief Updates "var" according to the Ftrl-proximal scheme . \n
  1622. *@par Inputs:
  1623. *Eight inputs, including:
  1624. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1625. * Should be a Variable Tensor.
  1626. * @li accum: A mutable Tensor of the same type as "var".
  1627. * Should be a Variable Tensor.
  1628. * @li linear: A mutable Tensor of the same type as "var".
  1629. * Should be a Variable Tensor.
  1630. * @li grad: A Tensor of the same type as "var", for the gradient.
  1631. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1632. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1633. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1634. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
  1635. *@par Attributes:
  1636. *use_locking: An optional bool. Defaults to "False".
  1637. * If "True", updating of the "var" and "accum" tensors will be
  1638. * protected by a lock; otherwise the behavior is undefined,
  1639. * but may exhibit less contention . \n
  1640. *@par Outputs:
  1641. *@li var: A mutable Tensor. Has the same type as "var".
  1642. *@li accum: A mutable Tensor. Has the same type as "accum".
  1643. *@li linear: A mutable Tensor. Has the same type as "linear" . \n
  1644. *@par Third-party framework compatibility
  1645. *Compatible with the TensorFlow operator ApplyFtrl.
  1646. *
  1647. * @par Restrictions:
  1648. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyFtrl instead.
  1649. */
  1650. REG_OP(ApplyFtrlD)
  1651. .INPUT(var, TensorType::NumberType())
  1652. .INPUT(accum, TensorType::NumberType())
  1653. .INPUT(linear, TensorType::NumberType())
  1654. .INPUT(grad, TensorType::NumberType())
  1655. .INPUT(lr, TensorType::NumberType())
  1656. .INPUT(l1, TensorType::NumberType())
  1657. .INPUT(l2, TensorType::NumberType())
  1658. .INPUT(lr_power, TensorType::NumberType())
  1659. .OUTPUT(var, TensorType::NumberType())
  1660. .OUTPUT(accum, TensorType::NumberType())
  1661. .OUTPUT(linear, TensorType::NumberType())
  1662. .ATTR(use_locking, Bool, false)
  1663. .OP_END_FACTORY_REG(ApplyFtrlD)
  1664. /**
  1665. *@brief Update "var" according to the Ftrl-proximal scheme . \n
  1666. *@par Inputs:
  1667. *Nine inputs, including:
  1668. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1669. * Should be a Variable Tensor.
  1670. * @li accum: A mutable Tensor of the same type as "var".
  1671. * Should be a Variable Tensor.
  1672. * @li linear: A mutable Tensor of the same type as "var".
  1673. * Should be a Variable Tensor.
  1674. * @li grad: A Tensor of the same type as "var", for the gradient.
  1675. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1676. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1677. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1678. * @li l2_shrinkage: A Tensor of the same type as "var".
  1679. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
  1680. *@par Attributes:
  1681. *use_locking: An optional bool. Defaults to "False".
  1682. * If "True", updating of the "var" and "accum" tensors will be
  1683. * protected by a lock; otherwise the behavior is undefined,
  1684. * but may exhibit less contention . \n
  1685. *@par Outputs:
  1686. *var: A mutable Tensor. Has the same type as "var" . \n
  1687. *@par Third-party framework compatibility
  1688. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1689. */
  1690. REG_OP(ApplyFtrlV2)
  1691. .INPUT(var, TensorType::NumberType())
  1692. .INPUT(accum, TensorType::NumberType())
  1693. .INPUT(linear, TensorType::NumberType())
  1694. .INPUT(grad, TensorType::NumberType())
  1695. .INPUT(lr, TensorType::NumberType())
  1696. .INPUT(l1, TensorType::NumberType())
  1697. .INPUT(l2, TensorType::NumberType())
  1698. .INPUT(l2_shrinkage, TensorType::NumberType())
  1699. .INPUT(lr_power, TensorType::NumberType())
  1700. .OUTPUT(var, TensorType::NumberType())
  1701. .ATTR(use_locking, Bool, false)
  1702. .OP_END_FACTORY_REG(ApplyFtrlV2)
  1703. /**
  1704. *@brief Update "var" according to the Ftrl-proximal scheme . \n
  1705. *@par Inputs:
  1706. *Nine inputs, including:
  1707. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  1708. * Should be a Variable Tensor.
  1709. * @li accum: A mutable Tensor of the same type as "var".
  1710. * Should be a Variable Tensor.
  1711. * @li linear: A mutable Tensor of the same type as "var".
  1712. * Should be a Variable Tensor.
  1713. * @li grad: A Tensor of the same type as "var", for the gradient.
  1714. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
1715. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
1716. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1717. * @li l2_shrinkage: A Tensor of the same type as "var".
  1718. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
  1719. *@par Attributes:
  1720. *use_locking: An optional bool. Defaults to "False".
  1721. * If "True", updating of the "var" and "accum" tensors will be
  1722. * protected by a lock; otherwise the behavior is undefined,
  1723. * but may exhibit less contention . \n
  1724. *@par Outputs:
1725. *@li var: A mutable Tensor. Has the same type as "var".
1726. *@li accum: A mutable Tensor. Has the same type as "accum".
1727. *@li linear: A mutable Tensor. Has the same type as "linear" . \n
  1728. *@par Third-party framework compatibility
  1729. *Compatible with the TensorFlow operator ApplyFtrlV2.
  1730. *
  1731. * @par Restrictions:
  1732. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyFtrlV2 instead.
  1733. */
  1734. REG_OP(ApplyFtrlV2D)
  1735. .INPUT(var, TensorType::NumberType())
  1736. .INPUT(accum, TensorType::NumberType())
  1737. .INPUT(linear, TensorType::NumberType())
  1738. .INPUT(grad, TensorType::NumberType())
  1739. .INPUT(lr, TensorType::NumberType())
  1740. .INPUT(l1, TensorType::NumberType())
  1741. .INPUT(l2, TensorType::NumberType())
  1742. .INPUT(l2_shrinkage, TensorType::NumberType())
  1743. .INPUT(lr_power, TensorType::NumberType())
  1744. .OUTPUT(var, TensorType::NumberType())
  1745. .OUTPUT(accum, TensorType::NumberType())
  1746. .OUTPUT(linear, TensorType::NumberType())
  1747. .ATTR(use_locking, Bool, false)
  1748. .OP_END_FACTORY_REG(ApplyFtrlV2D)
  1749. /**
  1750. *@brief Updates "var" according to the Adam algorithm.
1751. * lr_t <- learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
1752. * m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g
1753. * v_t <- beta_2 * v_{t-1} + (1 - beta_2) * g * g
1754. * variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
  1755. *
  1756. *@attention Constraints:
  1757. * *The input tensors must have the same shape.*
  1758. *
  1759. *@par Inputs:
  1760. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1761. * Should be from a Variable().
  1762. *@li m: A mutable Tensor of the same type as "var".
  1763. * Should be from a Variable().
  1764. *@li v: A mutable Tensor of the same type as "var".
  1765. * Should be from a Variable().
  1766. *@li beta1_power: A scalar of the same type as "var".
  1767. *@li beta2_power: A scalar of the same type as "var".
  1768. *@li lr: learning_rate. A scalar of the same type as "var".
  1769. *@li beta1: A scalar of the same type as "var".
  1770. *@li beta2: A scalar of the same type as "var".
  1771. *@li epsilon: A scalar of the same type as "var".
  1772. *@li grad: A Tensor of the same type as "var", for the gradient.
  1773. *
  1774. *@par Attributes:
  1775. *@li use_locking: An optional bool. Defaults to "False".
  1776. * If "True", updating of the "var", m", and "v" tensors will be protected
  1777. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1778. * contention.
  1779. *@li use_nesterov: An optional bool. Defaults to "False".
  1780. If "True", uses the nesterov update.
  1781. *
  1782. *@par Outputs:
1783. * var: A mutable Tensor. Has the same type as input "var" . \n
  1784. *@par Third-party framework compatibility
  1785. *Compatible with the TensorFlow operator ApplyAdam.
  1786. */
  1787. REG_OP(ApplyAdam)
  1788. .INPUT(var, TensorType::NumberType())
  1789. .INPUT(m, TensorType::NumberType())
  1790. .INPUT(v, TensorType::NumberType())
  1791. .INPUT(beta1_power, TensorType::NumberType())
  1792. .INPUT(beta2_power, TensorType::NumberType())
  1793. .INPUT(lr, TensorType::NumberType())
  1794. .INPUT(beta1, TensorType::NumberType())
  1795. .INPUT(beta2, TensorType::NumberType())
  1796. .INPUT(epsilon, TensorType::NumberType())
  1797. .INPUT(grad, TensorType::NumberType())
  1798. .OUTPUT(var, TensorType::NumberType())
  1799. .ATTR(use_locking, Bool, false)
  1800. .ATTR(use_nesterov, Bool, false)
  1801. .OP_END_FACTORY_REG(ApplyAdam)
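/*
 * Illustrative sketch (editor's addition): one scalar Adam step following the
 * formulas above; "use_nesterov" applies the Nesterov-style correction to the
 * numerator as in TensorFlow's ApplyAdam. Assumes float data; for intuition only.
 *
 *   #include <cmath>
 *   inline void apply_adam_step(float &var, float &m, float &v,
 *                               float beta1_power, float beta2_power, float lr,
 *                               float beta1, float beta2, float epsilon,
 *                               float grad, bool use_nesterov = false) {
 *     float lr_t = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
 *     m = beta1 * m + (1.0f - beta1) * grad;
 *     v = beta2 * v + (1.0f - beta2) * grad * grad;
 *     float numer = use_nesterov ? m * beta1 + (1.0f - beta1) * grad : m;
 *     var -= lr_t * numer / (std::sqrt(v) + epsilon);
 *   }
 */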
  1802. /**
  1803. *@brief Updates "var" according to the Adam algorithm.
1804. * lr_t <- learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)
1805. * m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g
1806. * v_t <- beta_2 * v_{t-1} + (1 - beta_2) * g * g
1807. * variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
  1808. *
  1809. *@attention Constraints:
  1810. * *The input tensors must have the same shape.*
  1811. *
  1812. *@par Inputs:
  1813. *@li var: A mutable Tensor of the type TensorType::NumberType().
  1814. * Should be from a Variable().
  1815. *@li m: A mutable Tensor of the same type as "var".
  1816. * Should be from a Variable().
  1817. *@li v: A mutable Tensor of the same type as "var".
  1818. * Should be from a Variable().
  1819. *@li beta1_power: A scalar of the same type as "var".
  1820. *@li beta2_power: A scalar of the same type as "var".
  1821. *@li lr: learning_rate. A scalar of the same type as "var".
  1822. *@li beta1: A scalar of the same type as "var".
  1823. *@li beta2: A scalar of the same type as "var".
  1824. *@li epsilon: A scalar of the same type as "var".
  1825. *@li grad: A Tensor of the same type as "var", for the gradient.
  1826. *
  1827. *@par Attributes:
  1828. *@li use_locking: An optional bool. Defaults to "False".
  1829. * If "True", updating of the "var", m", and "v" tensors will be protected
  1830. * by a lock; otherwise the behavior is undefined, but may exhibit less
  1831. * contention.
  1832. *@li use_nesterov: An optional bool. Defaults to "False".
  1833. If "True", uses the nesterov update.
  1834. *
  1835. *@par Outputs:
  1836. *@li var: A mutable tensor. Has the same type as input "var".
  1837. *@li m: A mutable tensor. Has the same type as input "m".
  1838. *@li v: A mutable tensor. Has the same type as input "v" . \n
  1839. *@par Third-party framework compatibility
  1840. *Compatible with the TensorFlow operator ApplyAdam.
  1841. *
  1842. * @par Restrictions:
  1843. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdam instead.
  1844. */
  1845. REG_OP(ApplyAdamD)
  1846. .INPUT(var, TensorType::NumberType())
  1847. .INPUT(m, TensorType::NumberType())
  1848. .INPUT(v, TensorType::NumberType())
  1849. .INPUT(beta1_power, TensorType::NumberType())
  1850. .INPUT(beta2_power, TensorType::NumberType())
  1851. .INPUT(lr, TensorType::NumberType())
  1852. .INPUT(beta1, TensorType::NumberType())
  1853. .INPUT(beta2, TensorType::NumberType())
  1854. .INPUT(epsilon, TensorType::NumberType())
  1855. .INPUT(grad, TensorType::NumberType())
  1856. .OUTPUT(var, TensorType::NumberType())
  1857. .OUTPUT(m, TensorType::NumberType())
  1858. .OUTPUT(v, TensorType::NumberType())
  1859. .ATTR(use_locking, Bool, false)
  1860. .ATTR(use_nesterov, Bool, false)
  1861. .OP_END_FACTORY_REG(ApplyAdamD)
  1862. /**
  1863. *@brief Updates "var" according to the proximal adadelta scheme . \n
  1864. *@par Inputs:
  1865. *Seven inputs, including:
  1866. * @li var: A mutable Tensor of type TensorType::NumberType().
  1867. * Should be a Variable Tensor.
  1868. * @li accum: A mutable Tensor of the same type as "var".
  1869. * Should be a Variable Tensor.
  1870. * @li accum_update: A mutable Tensor of the same type as "var".
  1871. * Should be a Variable Tensor.
  1872. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1873. * @li rho: A scalar of the same type as "var", for the decay factor.
  1874. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1875. * @li grad: A Tensor of the same type as "var", for the gradient . \n
  1876. *@par Attributes:
  1877. *use_locking: An optional bool. Defaults to "False".
  1878. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1879. * protected by a lock; otherwise the behavior is undefined,
  1880. * but may exhibit less contention . \n
  1881. *@par Outputs:
  1882. *var: A mutable Tensor. Has the same type as "var" . \n
  1883. *@par Third-party framework compatibility
  1884. * Compatible with the TensorFlow operator ApplyAdadelta.
  1885. */
  1886. REG_OP(ApplyAdadelta)
  1887. .INPUT(var, TensorType::NumberType())
  1888. .INPUT(accum, TensorType::NumberType())
  1889. .INPUT(accum_update, TensorType::NumberType())
  1890. .INPUT(lr, TensorType::NumberType())
  1891. .INPUT(rho, TensorType::NumberType())
  1892. .INPUT(epsilon, TensorType::NumberType())
  1893. .INPUT(grad, TensorType::NumberType())
  1894. .OUTPUT(var, TensorType::NumberType())
  1895. .ATTR(use_locking, Bool, false)
  1896. .OP_END_FACTORY_REG(ApplyAdadelta)
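/*
 * Illustrative sketch (editor's addition): scalar C++ form of the Adadelta
 * update this operator documents (accumulated squared gradient and accumulated
 * squared update with decay "rho"), assuming float data.
 *
 *   #include <cmath>
 *   inline void apply_adadelta_step(float &var, float &accum, float &accum_update,
 *                                   float lr, float rho, float epsilon, float grad) {
 *     accum = rho * accum + (1.0f - rho) * grad * grad;
 *     float update = std::sqrt(accum_update + epsilon) /
 *                    std::sqrt(accum + epsilon) * grad;
 *     accum_update = rho * accum_update + (1.0f - rho) * update * update;
 *     var -= lr * update;
 *   }
 */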
  1897. /**
  1898. *@brief Updates "var" according to the proximal adadelta scheme . \n
  1899. *@par Inputs:
  1900. *Seven inputs, including:
  1901. * @li var: A mutable Tensor of type TensorType::NumberType().
  1902. * Should be a Variable Tensor.
  1903. * @li accum: A mutable Tensor of the same type as "var".
  1904. * Should be a Variable Tensor.
  1905. * @li accum_update: A mutable Tensor of the same type as "var".
  1906. * Should be a Variable Tensor.
  1907. * @li lr: A scalar of the same type as "var", for the scaling factor.
  1908. * @li rho: A scalar of the same type as "var", for the decay factor.
  1909. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  1910. * @li grad: A Tensor of the same type as "var", for the gradient . \n
  1911. *@par Attributes:
  1912. *use_locking: An optional bool. Defaults to "False".
  1913. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  1914. * protected by a lock; otherwise the behavior is undefined,
  1915. * but may exhibit less contention . \n
  1916. *@par Outputs:
  1917. *@li var: A mutable Tensor. Has the same type as "var".
  1918. *@li accum: A mutable Tensor. Has the same type as "var".
  1919. *@li accum_update: A mutable Tensor. Has the same type as "var" . \n
  1920. *@par Third-party framework compatibility
  1921. * Compatible with the TensorFlow operator ApplyAdadelta.
  1922. * @par Restrictions:
  1923. * Warning: THIS FUNCTION IS DEPRECATED. Please use ApplyAdadelta instead.
  1924. */
  1925. REG_OP(ApplyAdadeltaD)
  1926. .INPUT(var, TensorType::NumberType())
  1927. .INPUT(accum, TensorType::NumberType())
  1928. .INPUT(accum_update, TensorType::NumberType())
  1929. .INPUT(lr, TensorType::NumberType())
  1930. .INPUT(rho, TensorType::NumberType())
  1931. .INPUT(epsilon, TensorType::NumberType())
  1932. .INPUT(grad, TensorType::NumberType())
  1933. .OUTPUT(var, TensorType::NumberType())
  1934. .OUTPUT(accum, TensorType::NumberType())
  1935. .OUTPUT(accum_update, TensorType::NumberType())
  1936. .ATTR(use_locking, Bool, false)
  1937. .OP_END_FACTORY_REG(ApplyAdadeltaD)
  1938. /**
  1939. *@brief Updates "var" according to the ApplyMomentum algorithm.
  1940. * accum = accum * momentum + x1 * x2
  1941. * if use_nesterov is True:
  1942. * var -= x1 * x2 * lr + accum * momentum * lr
  1943. * else: var -= accum * lr
  1944. *
  1945. *@par Inputs:
  1946. * Six inputs, including:
  1947. *@li var: A mutable Tensor has type TensorType::NumberType().
  1948. * Should be a Variable Tensor.
  1949. *@li accum: A mutable Tensor has the same type as "var".
  1950. * Should be a Variable Tensor.
  1951. *@li lr: A scalar has the same type as "var", for the scaling factor.
  1952. *@li x1: A Tensor has type TensorType::NumberType().
  1953. *@li momentum: A scalar has the same type as "var".
  1954. *@li x2: A scalar has the same type as "var". \n
  1955. *
  1956. *@par Attributes:
  1957. * Two attributes, including:
  1958. *@li use_nesterov: An optional bool. Defaults to "False".
  1959. * If True, the tensor passed to compute grad will be
  1960. * var - lr * momentum * accum, so in the end,
  1961. * the var you get is actually var - lr * momentum * accum.
  1962. *@li use_locking: An optional bool. Defaults to "False".
  1963. * If "True", updating of the "var", m", and "v" tensors will be protected
  1964. * by a lock; otherwise the behavior is undefined, but may exhibit
  1965. * less contention. \n
  1966. *
  1967. *@par Outputs:
  1968. * Two outputs, including:
  1969. *@li var: A mutable Tensor has the same type as "var".
  1970. *@li accum: A mutable Tensor has the same type as "var". \n
  1971. *@par Restrictions:
  1972. * Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  1973. */
  1974. REG_OP(FusedMulApplyMomentum)
  1975. .INPUT(var, TensorType::NumberType())
  1976. .INPUT(accum, TensorType::NumberType())
  1977. .INPUT(lr, TensorType::NumberType())
  1978. .INPUT(x1, TensorType::NumberType())
  1979. .INPUT(momentum, TensorType::NumberType())
  1980. .INPUT(x2, TensorType::NumberType())
  1981. .OUTPUT(var, TensorType::NumberType())
  1982. .OUTPUT(accum, TensorType::NumberType())
  1983. .ATTR(use_nesterov, Bool, false)
  1984. .ATTR(use_locking, Bool, false)
  1985. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
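/*
 * Illustrative sketch (editor's addition): the fused op folds the element-wise
 * product x1 * x2 (the gradient) into the classic momentum update shown above.
 * Scalar C++ form, assuming float data.
 *
 *   inline void fused_mul_apply_momentum_step(float &var, float &accum, float lr,
 *                                             float x1, float momentum, float x2,
 *                                             bool use_nesterov = false) {
 *     float grad = x1 * x2;
 *     accum = accum * momentum + grad;                  // accum = accum * momentum + x1 * x2
 *     if (use_nesterov) {
 *       var -= grad * lr + accum * momentum * lr;       // var -= x1 * x2 * lr + accum * momentum * lr
 *     } else {
 *       var -= accum * lr;                              // var -= accum * lr
 *     }
 *   }
 */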
  1986. /**
  1987. * @brief Updates "var" according to the ApplyMomentum algorithm.
  1988. * accum = accum * momentum + x1 * x2
  1989. * if use_nesterov is True:
  1990. * var -= x1 * x2 * lr + accum * momentum * lr
  1991. * else:
  1992. * var -= accum * lr
  1993. *
  1994. * @par Inputs:
  1995. * Seven inputs, including:
  1996. * @li var: A mutable Tensor of type float32.
  1997. * Should be a Variable Tensor.
  1998. * @li accum: A mutable Tensor has type TensorType::NumberType().
  1999. * Should be a Variable Tensor.
  2000. * @li lr: A scalar has the same type as "accum", for the scaling factor.
  2001. * @li x1: A Tensor has the same type as "accum".
  2002. * @li momentum: A scalar has the same type as "accum".
  2003. * @li x2: A scalar has the same type as "accum".
  2004. * @li var_copy: A Tensor has type float16.
  2005. *
  2006. * @par Attributes:
  2007. * Two Attributes, including:
  2008. * @li use_nesterov: An optional bool. Defaults to "False".
  2009. * If True, the tensor passed to compute grad will be var - lr * momentum * accum,
  2010. * so in the end, the var you get is actually var - lr * momentum * accum.
  2011. * @li use_locking: An optional bool. Defaults to "False".
  2012. * If "True", updating of the "var", m", and "v" tensors will be protected
  2013. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  2014. *
  2015. * @par Outputs:
  2016. * Three outputs, including:
  2017. * @li var: A Tensor has the type float32.
  2018. * @li var_copy: A Tensor has the type float16.
  2019. * @li accum: A Tensor has the same type as input "accum".
  2020. *@par Restrictions:
  2021. *Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
  2022. */
  2023. REG_OP(FusedMulApplyMomentumExtern)
  2024. .INPUT(var, TensorType(DT_FLOAT))
  2025. .INPUT(accum, TensorType::NumberType())
  2026. .INPUT(lr, TensorType::NumberType())
  2027. .INPUT(x1, TensorType::NumberType())
  2028. .INPUT(momentum, TensorType::NumberType())
  2029. .INPUT(x2, TensorType::NumberType())
  2030. .INPUT(var_copy, TensorType(DT_FLOAT16))
  2031. .OUTPUT(var, TensorType(DT_FLOAT))
  2032. .OUTPUT(var_copy, TensorType(DT_FLOAT16))
  2033. .OUTPUT(accum, TensorType::NumberType())
  2034. .ATTR(use_nesterov, Bool, false)
  2035. .ATTR(use_locking, Bool, false)
  2036. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  2037. /**
  2038. *@brief Updates '*var' according to the momentum scheme.
  2039. * accum = accum * momentum - x1 * x2 * lr
  2040. * if use_nesterov is True:
  2041. * var += accum * momentum - x1 * x2 * lr
  2042. * else:
  2043. * var += accum
  2044. *
  2045. *@par Inputs:
  2046. *@li var: A mutable tensor. Must be one of the data types defined in
  2047. * TensorType::NumberType(). Should be from a Variable().
  2048. *@li accum: A mutable tensor. Has the same type as "var". Should be from a
  2049. * Variable().
  2050. *@li lr: A tensor for the learning rate. Has the same type as "var". Should be
  2051. * from a Variable().
  2052. *@li x1: A Tensor has type TensorType::NumberType().
  2053. *@li momentum: A scalar. Has the same type as "var".
  2054. *@li x2: A scalar has the same type as "var".
  2055. *
  2056. *@par Attributes:
  2057. *@li use_nesterov: An optional bool. Defaults to "False".
  2058. * If "True", var will be updated by using Nesterov momentum.
  2059. *@li use_locking: An optional bool. Defaults to "False".
  2060. * If "True", updating of the "var" tensor is protected by a lock;
  2061. * otherwise the behavior is undefined, but may exhibit less contention.
  2062. *
  2063. *@par Outputs:
  2064. * @li var: A mutable tensor. Has the same type as input "var".
  2065. * @li accum: A mutable tensor. Has the same type as input "accum".
  2066. *
  2070. *
  2071. *@par Third-party framework compatibility
  2072. * Compatible with the TensorFlow operator ResourceApplyKerasMomentum.
  2073. *
  2074. */
  2075. REG_OP(FusedMulApplyKerasMomentum)
  2076. .INPUT(var, TensorType::NumberType())
  2077. .INPUT(accum, TensorType::NumberType())
  2078. .INPUT(lr, TensorType::NumberType())
  2079. .INPUT(x1, TensorType::NumberType())
  2080. .INPUT(momentum, TensorType::NumberType())
  2081. .INPUT(x2, TensorType::NumberType())
  2082. .OUTPUT(var, TensorType::NumberType())
  2083. .OUTPUT(accum, TensorType::NumberType())
  2084. .ATTR(use_locking, Bool, false)
  2085. .ATTR(use_nesterov, Bool, false)
  2086. .OP_END_FACTORY_REG(FusedMulApplyKerasMomentum)
  2087. /**
  2088. *@brief Update "g" according to the LARS algorithm . \n
  2089. *@par Inputs:
  2090. *Four inputs, including:
  2091. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  2092. * @li g: A Tensor of the same type and shape as "w".
2093. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
2094. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar . \n
  2095. *@par Attributes:
  2096. *Three Attributes, including:
  2097. * @li hyperpara: An optional float. Default value is 0.001.
2098. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  2099. * @li use_clip: An optional bool. Defaults to "False".
  2100. * If "True", updating learning rate . \n
  2101. *@par Outputs:
  2102. *g_new: Tensor of the same type as "w".
  2103. */
  2104. REG_OP(LarsV2)
  2105. .INPUT(w, TensorType(DT_FLOAT))
  2106. .INPUT(g, TensorType(DT_FLOAT))
  2107. .INPUT(weight_decay, TensorType(DT_FLOAT))
  2108. .INPUT(learning_rate, TensorType(DT_FLOAT))
  2109. .OUTPUT(g_new, TensorType(DT_FLOAT))
  2110. .ATTR(hyperpara, Float, 0.001)
  2111. .ATTR(epsilon, Float, 0.00001)
  2112. .ATTR(use_clip, Bool, false)
  2113. .OP_END_FACTORY_REG(LarsV2)
  2114. /**
  2115. *@brief Update "g" according to the LARS algorithm . \n
  2116. *@par Inputs:
  2117. *Six inputs, including:
  2118. * @li w: A Tensor. Must be of type TensorType::DT_FLOAT.
  2119. * @li g: A Tensor of the same type and shape as "w".
2120. * @li w_square_sum: A Tensor holding square_sum(w). Has the same type as "w". Must be a scalar.
2121. * @li g_square_sum: A Tensor holding square_sum(g). Has the same type as "w". Must be a scalar.
2122. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
2123. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar . \n
  2124. *@par Attributes:
  2125. *Three Attributes, including:
  2126. * @li hyperpara: An optional float. Default value is 0.001.
2127. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  2128. * @li use_clip: An optional bool. Defaults to "False".
  2129. * If "True", updating learning rate . \n
  2130. *@par Outputs:
  2131. *g_new: Tensor of the same type as "w".
  2132. */
  2133. REG_OP(LarsV2Update)
  2134. .INPUT(w, TensorType(DT_FLOAT))
  2135. .INPUT(g, TensorType(DT_FLOAT))
  2136. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  2137. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  2138. .INPUT(weight_decay, TensorType(DT_FLOAT))
  2139. .INPUT(learning_rate, TensorType(DT_FLOAT))
  2140. .OUTPUT(g_new, TensorType(DT_FLOAT))
  2141. .ATTR(hyperpara, Float, 0.001)
  2142. .ATTR(epsilon, Float, 0.00001)
  2143. .ATTR(use_clip, Bool, false)
  2144. .OP_END_FACTORY_REG(LarsV2Update)
/**
* @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme . \n
* @par Inputs:
* Nine inputs, including:
* @li var: A mutable Tensor. Must be of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li accum: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor. The value of accum must be greater than 0.
* @li linear: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li grad: A Tensor of the same type as "var", for the gradient.
* @li indices: A vector of indices into the first dimension of var and accum.
* The value of indices must be unique. Otherwise, the result is unpredictable.
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
* @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
* @par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyFtrl.
*/
REG_OP(SparseApplyFtrl)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(linear, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(l1, TensorType({DT_FLOAT}))
    .INPUT(l2, TensorType({DT_FLOAT}))
    .INPUT(lr_power, TensorType({DT_FLOAT}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyFtrl)

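/* Example (illustrative only, not part of this header): a minimal host-side
   sketch of the FTRL-proximal update described above, following the formulas
   used by the TensorFlow SparseApplyFtrl kernel and assuming contiguous
   float32 buffers of shape [rows, d] with int32 indices. The helper name and
   signature are hypothetical.

#include <cmath>
#include <cstddef>
#include <cstdint>

static void sparse_apply_ftrl_reference(
    float* var, float* accum, float* linear,        // [rows][d], updated in place
    const float* grad, const int32_t* indices,      // grad: [num_indices][d]
    std::size_t num_indices, std::size_t d,
    float lr, float l1, float l2, float lr_power) {
  for (std::size_t k = 0; k < num_indices; ++k) {
    const std::size_t row = static_cast<std::size_t>(indices[k]);
    for (std::size_t j = 0; j < d; ++j) {
      const std::size_t p = row * d + j;
      const float g = grad[k * d + j];
      const float accum_new = accum[p] + g * g;
      // sigma = (accum_new^(-lr_power) - accum^(-lr_power)) / lr
      const float sigma =
          (std::pow(accum_new, -lr_power) - std::pow(accum[p], -lr_power)) / lr;
      linear[p] += g - sigma * var[p];
      const float quadratic = std::pow(accum_new, -lr_power) / lr + 2.0f * l2;
      var[p] = (std::fabs(linear[p]) > l1)
                   ? (std::copysign(l1, linear[p]) - linear[p]) / quadratic
                   : 0.0f;
      accum[p] = accum_new;
    }
  }
}
*/
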
/**
* @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme . \n
* @par Inputs:
* Five inputs, including:
* @li var: A mutable Tensor. Must be of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li accum: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor. The value of accum must be greater than 0.
* @li linear: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li grad: A Tensor of the same type as "var", for the gradient.
* @li indices: A vector of indices into the first dimension of var and accum.
* The value of indices must be unique. Otherwise, the result is unpredictable . \n
* @par Attributes:
* @li lr: A required float, for the scaling factor.
* @li l1: A required float, for L1 regularization.
* @li l2: A required float, for L2 regularization.
* @li lr_power: A required float, for the scaling factor.
* @li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
* @par Outputs:
* @li var: A Tensor. Has the same type and format as input "var".
* @li accum: A Tensor. Has the same type and format as input "accum".
* @li linear: A Tensor. Has the same type and format as input "linear" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyFtrl.
*
*@par Restrictions:
*Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyFtrl instead.
*/
REG_OP(SparseApplyFtrlD)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(linear, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .OUTPUT(linear, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .REQUIRED_ATTR(l1, Float)
    .REQUIRED_ATTR(l2, Float)
    .REQUIRED_ATTR(lr_power, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyFtrlD)

/**
* @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
* That is, for the rows for which "grad" is provided, "var", "accum" and "linear" are updated . \n
* @par Inputs:
* Ten inputs, including:
* @li var: A mutable Tensor. Must be of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li accum: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li linear: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li grad: A Tensor of the same type as "var", for the gradient.
* @li indices: A vector of indices into the first dimension of "var" and "accum".
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
* @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
* @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
* @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar . \n
* @par Attributes:
* use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
* @par Outputs:
* var: A Tensor. Has the same type and format as input "var" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyFtrlV2.
*/
REG_OP(SparseApplyFtrlV2)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(linear, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .INPUT(lr, TensorType({DT_FLOAT}))
    .INPUT(l1, TensorType({DT_FLOAT}))
    .INPUT(l2, TensorType({DT_FLOAT}))
    .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
    .INPUT(lr_power, TensorType({DT_FLOAT}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyFtrlV2)

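/* Example (illustrative only, not part of this header): a single-element
   sketch of how the V2 update differs from SparseApplyFtrl above, following
   the TensorFlow semantics: "accum" still accumulates the raw squared
   gradient, while the "linear" term uses a gradient adjusted by the L2
   shrinkage. The helper name and signature are hypothetical.

#include <cmath>

static void ftrl_v2_element_step(float& var, float& accum, float& linear, float grad,
                                 float lr, float l1, float l2, float l2_shrinkage,
                                 float lr_power) {
  const float grad_shrink = grad + 2.0f * l2_shrinkage * var;  // shrinkage-adjusted gradient
  const float accum_new = accum + grad * grad;                 // raw grad^2, as in V1
  const float sigma = (std::pow(accum_new, -lr_power) - std::pow(accum, -lr_power)) / lr;
  linear += grad_shrink - sigma * var;
  const float quadratic = std::pow(accum_new, -lr_power) / lr + 2.0f * l2;
  var = (std::fabs(linear) > l1) ? (std::copysign(l1, linear) - linear) / quadratic : 0.0f;
  accum = accum_new;
}
*/
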
/**
* @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
* That is, for the rows for which "grad" is provided, "var", "accum" and "linear" are updated . \n
* @par Inputs:
* Five inputs, including:
* @li var: A mutable Tensor. Must be of type TensorType::NumberType().
* Should be a Variable Tensor.
* @li accum: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li linear: A mutable Tensor of the same type as "var".
* Should be a Variable Tensor.
* @li grad: A Tensor of the same type as "var", for the gradient.
* @li indices: A vector of indices into the first dimension of "var" and "accum" . \n
* @par Attributes:
* @li lr: A required float, for the scaling factor.
* @li l1: A required float, for L1 regularization.
* @li l2: A required float, for L2 regularization.
* @li l2_shrinkage: A required float, for L2 shrinkage regularization.
* @li lr_power: A required float, for the scaling factor.
* @li use_locking: An optional bool. Defaults to "False".
* If "True", updating of the "var" and "accum" tensors will be
* protected by a lock; otherwise the behavior is undefined,
* but may exhibit less contention . \n
* @par Outputs:
* @li var: A Tensor. Has the same type and format as input "var".
* @li accum: A Tensor. Has the same type and format as input "accum".
* @li linear: A Tensor. Has the same type and format as input "linear" . \n
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyFtrlV2.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyFtrlV2 instead.
*/
REG_OP(SparseApplyFtrlV2D)
    .INPUT(var, TensorType({DT_FLOAT}))
    .INPUT(accum, TensorType({DT_FLOAT}))
    .INPUT(linear, TensorType({DT_FLOAT}))
    .INPUT(grad, TensorType({DT_FLOAT}))
    .INPUT(indices, TensorType({DT_INT32}))
    .OUTPUT(var, TensorType({DT_FLOAT}))
    .OUTPUT(accum, TensorType({DT_FLOAT}))
    .OUTPUT(linear, TensorType({DT_FLOAT}))
    .REQUIRED_ATTR(lr, Float)
    .REQUIRED_ATTR(l1, Float)
    .REQUIRED_ATTR(l2, Float)
    .REQUIRED_ATTR(l2_shrinkage, Float)
    .REQUIRED_ATTR(lr_power, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyFtrlV2D)

/**
* @brief Updates "var" at the specified indices according to the RMSProp algorithm.
* mean_square = decay * mean_square + (1-decay) * gradient ** 2
* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
* var <- var - mom
*
* @par Inputs:
* Nine inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li ms: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li mom: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li momentum: A scalar. Must have the same type as "var".
* @li epsilon: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "mom" and "ms".
*
* @par Attributes:
* use_locking: An optional "bool". Defaults to "False". If "True", updating of
* the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the
* behavior is undefined, but may exhibit less contention.
*
* @par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "ms" and "mom" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "ms", and "mom" must have the same shape.
*
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyRMSProp.
*/
REG_OP(SparseApplyRMSProp)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(momentum, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyRMSProp)

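/* Example (illustrative only, not part of this header): a minimal host-side
   sketch of the sparse RMSProp step, following the formulas in the comment
   above and assuming contiguous float32 buffers of shape [rows, d] with int32
   indices. The helper name and signature are hypothetical.

#include <cmath>
#include <cstddef>
#include <cstdint>

static void sparse_apply_rms_prop_reference(
    float* var, float* ms, float* mom,            // [rows][d], updated in place
    const float* grad, const int32_t* indices,    // grad: [num_indices][d]
    std::size_t num_indices, std::size_t d,
    float lr, float rho, float momentum, float epsilon) {
  for (std::size_t k = 0; k < num_indices; ++k) {
    const std::size_t row = static_cast<std::size_t>(indices[k]);
    for (std::size_t j = 0; j < d; ++j) {
      const std::size_t p = row * d + j;
      const float g = grad[k * d + j];
      ms[p] = rho * ms[p] + (1.0f - rho) * g * g;                      // ms <- rho*ms + (1-rho)*g^2
      mom[p] = momentum * mom[p] + lr * g / std::sqrt(ms[p] + epsilon);
      var[p] -= mom[p];                                                // var <- var - mom
    }
  }
}
*/
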
/**
* @brief Updates "var" at the specified indices according to the RMSProp algorithm.
* Const inputs are converted to attributes.
* mean_square = decay * mean_square + (1-decay) * gradient ** 2
* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
* ms <- rho * ms_{t-1} + (1-rho) * grad * grad
* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
* var <- var - mom
*
* @par Inputs:
* Six inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li ms: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li mom: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "ms" and "mom".
*
* @par Attributes:
* @li use_locking: An optional "bool". Defaults to "False". If "True",
* updating of the "var", "ms", and "mom" tensors will be protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
* @li rho: A required float.
* @li momentum: A required float.
* @li epsilon: A required float.
*
* @par Outputs:
* @li var: A mutable tensor. Must have the same type as input "var".
* @li ms: A mutable tensor. Must have the same type as input "ms".
* @li mom: A mutable tensor. Must have the same type as input "mom".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "ms" and "mom" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "ms" and "mom" must have the same shape.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyRMSProp instead.
*/
REG_OP(SparseApplyRMSPropD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(ms, TensorType::NumberType())
    .INPUT(mom, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(ms, TensorType::NumberType())
    .OUTPUT(mom, TensorType::NumberType())
    .REQUIRED_ATTR(rho, Float)
    .REQUIRED_ATTR(momentum, Float)
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyRMSPropD)

/**
* @brief Updates "var" at the specified indices according to the Adadelta algorithm.
* accum <- rho * accum + (1 - rho) * grad.square()
* update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad
* var <- var - update * lr
* accum_update <- rho * accum_update + (1 - rho) * update.square()
*
* @par Inputs:
* Eight inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li accum: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li epsilon: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
*
* @par Attributes:
* use_locking: An optional "bool". Defaults to "False". If "True", updating of
* the "var", "accum", and "accum_update" tensors will be protected by a lock; otherwise the
* behavior is undefined, but may exhibit less contention.
*
* @par Outputs:
* var: A mutable tensor. Has the same type as input "var".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "accum" and "accum_update" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "accum", and "accum_update" must have the same shape.
*
* @par Third-party framework compatibility
* Compatible with the TensorFlow operator SparseApplyAdadelta.
*/
REG_OP(SparseApplyAdadelta)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(accum_update, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(epsilon, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyAdadelta)

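/* Example (illustrative only, not part of this header): a minimal host-side
   sketch of the sparse Adadelta step, following the formulas in the comment
   above and assuming contiguous float32 buffers of shape [rows, d] with int32
   indices. The helper name and signature are hypothetical.

#include <cmath>
#include <cstddef>
#include <cstdint>

static void sparse_apply_adadelta_reference(
    float* var, float* accum, float* accum_update,   // [rows][d], updated in place
    const float* grad, const int32_t* indices,       // grad: [num_indices][d]
    std::size_t num_indices, std::size_t d,
    float lr, float rho, float epsilon) {
  for (std::size_t k = 0; k < num_indices; ++k) {
    const std::size_t row = static_cast<std::size_t>(indices[k]);
    for (std::size_t j = 0; j < d; ++j) {
      const std::size_t p = row * d + j;
      const float g = grad[k * d + j];
      accum[p] = rho * accum[p] + (1.0f - rho) * g * g;
      const float update = std::sqrt(accum_update[p] + epsilon) /
                           std::sqrt(accum[p] + epsilon) * g;
      var[p] -= update * lr;
      accum_update[p] = rho * accum_update[p] + (1.0f - rho) * update * update;
    }
  }
}
*/
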
/**
* @brief Updates "var" at the specified indices according to the Adadelta algorithm.
* Const inputs are converted to attributes.
* accum <- rho * accum + (1 - rho) * grad.square()
* update <- (accum_update + epsilon).sqrt() * (accum + epsilon).rsqrt() * grad
* var <- var - update * lr
* accum_update <- rho * accum_update + (1 - rho) * update.square()
*
* @par Inputs:
* Seven inputs, including:
* @li var: A mutable tensor. Must be one of the data types defined in
* TensorType::NumberType(). Should be from a Variable().
* @li accum: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a
* Variable().
* @li lr: A scalar. Must have the same type as "var".
* @li rho: A scalar. Must have the same type as "var".
* @li grad: A tensor, specifying the gradient.
* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update".
*
* @par Attributes:
* @li use_locking: An optional "bool". Defaults to "False". If "True",
* updating of the "var", "accum", and "accum_update" tensors will be protected by a lock;
* otherwise the behavior is undefined, but may exhibit less contention.
* @li epsilon: A required float.
*
* @par Outputs:
* @li var: A mutable tensor. Must have the same type as input "var".
* @li accum: A mutable tensor. Must have the same type as input "accum".
* @li accum_update: A mutable tensor. Must have the same type as input "accum_update".
*
* @attention Constraints:
* @li Note that in this sparse implementation, "accum" and "accum_update" will not update
* in iterations during which "grad" is 0.
* @li The input tensors "var", "accum" and "accum_update" must have the same shape.
*
* @par Restrictions:
* Warning: THIS FUNCTION IS DEPRECATED. Please use SparseApplyAdadelta instead.
*/
REG_OP(SparseApplyAdadeltaD)
    .INPUT(var, TensorType::NumberType())
    .INPUT(accum, TensorType::NumberType())
    .INPUT(accum_update, TensorType::NumberType())
    .INPUT(lr, TensorType::NumberType())
    .INPUT(rho, TensorType::NumberType())
    .INPUT(grad, TensorType::NumberType())
    .INPUT(indices, TensorType::IndexNumberType())
    .OUTPUT(var, TensorType::NumberType())
    .OUTPUT(accum, TensorType::NumberType())
    .OUTPUT(accum_update, TensorType::NumberType())
    .REQUIRED_ATTR(epsilon, Float)
    .ATTR(use_locking, Bool, false)
    .OP_END_FACTORY_REG(SparseApplyAdadeltaD)

/**
*@brief Cleans the memory of the workspace list . \n
*@par Attributes:
* @li automic_add_mem_size: sizes of the workspaces . \n
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(AtomicAddrClean)
    .ATTR(automic_add_mem_size, ListInt, {})
    .OP_END_FACTORY_REG(AtomicAddrClean)

/**
*@brief Cleans the memory of the workspace list . \n
*@par Attributes:
* @li automic_add_mem_size: sizes of the workspaces . \n
*@par Restrictions:
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use.
*/
REG_OP(DynamicAtomicAddrClean)
    .ATTR(automic_add_mem_size, ListInt, {})
    .OP_END_FACTORY_REG(DynamicAtomicAddrClean)

} // namespace ge
#endif // OPS_BUILT_IN_OP_PROTO_INC_NN_TRAINING_OPS_H_

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module (ME) and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.