You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nn_training_ops.h 48 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_OP_TRAINING_OPS_H
  17. #define GE_OP_TRAINING_OPS_H
  18. #include "../../../inc/external/graph/operator_reg.h"
  19. #include "../graph/operator_reg.h"
  20. namespace ge {
  21. /**
  22. *@brief Updates "var" according to the AdaMax algorithm.\n
  23. * t-1 means the previous period.
  24. * m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n
  25. * v_t <- max(beta2 * v_{t-1}, abs(grad))\n
  26. * var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon)
  27. *
  28. *@attention Constraints:\n
  29. * "var", "m", "v" and "grad" must have the same shape; the remaining inputs are scalars.
  30. *
  31. *@par Inputs:
  32. *@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType().
  33. * Should be from a Variable().
  34. *@li m: A mutable tensor. Has the same type as "var".
  35. * Should be from a Variable().
  36. *@li v: A mutable tensor. Has the same type as "var".
  37. * Should be from a Variable().
  38. *@li beta1_power: A scalar. Has the same type as "var". Presumably the beta1^t term of the formula above (confirm).
  39. *@li lr: The learning rate. A scalar. Has the same type as "var".
  40. *@li beta1: A scalar. Has the same type as "var".
  41. *@li beta2: A scalar. Has the same type as "var".
  42. *@li epsilon: A scalar. Has the same type as "var".
  43. *@li grad: A tensor for the gradient. Has the same type as "var".
  44. *
  45. *@par Attributes:\n
  46. * use_locking: An optional bool. Defaults to "False".
  47. * If "True", updating of the "var", "m", and "v" tensors is protected
  48. * by a lock; otherwise the behavior is undefined, but may exhibit less
  49. * contention.
  50. *
  51. *@par Outputs:
  52. * var: A mutable tensor. Has the same type as input "var".
  53. *
  54. */
  55. REG_OP(ApplyAdaMax)
  56. .INPUT(var, TensorType::NumberType())
  57. .INPUT(m, TensorType::NumberType())
  58. .INPUT(v, TensorType::NumberType())
  59. .INPUT(beta1_power, TensorType::NumberType())
  60. .INPUT(lr, TensorType::NumberType())
  61. .INPUT(beta1, TensorType::NumberType())
  62. .INPUT(beta2, TensorType::NumberType())
  63. .INPUT(epsilon, TensorType::NumberType())
  64. .INPUT(grad, TensorType::NumberType())
  65. .OUTPUT(var, TensorType::NumberType())
  66. .ATTR(use_locking, Bool, false)
  67. .OP_END_FACTORY_REG(ApplyAdaMax)
  68. /**
  69. *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you
  70. * want to use Nesterov momentum.\n
  71. * computing process: \n
  72. * accum = accum * momentum + grad\n
  73. * var -= lr * accum
  74. *
  75. *@attention Constraints:\n
  76. * "var", "accum" and "grad" must have the same shape.
  77. *
  78. *@par Inputs:
  79. *@li var: A mutable tensor. Should be from a Variable().
  80. *@li accum: A mutable tensor. Has the same type as "var".
  81. * Should be from a Variable().
  82. *@li lr: A scalar. Has the same type as "var".
  83. *@li grad: A tensor for the gradient. Has the same type as "var".
  84. *@li momentum: The momentum factor used in the formula above. Has the same type as "var".
  85. *@par Attributes:
  86. *@li use_nesterov: An optional bool. Defaults to "False".
  87. * If "True", the tensor passed to compute grad will be
  88. * var - lr * momentum * accum, so in the end, the var you get is actually
  89. * var - lr * momentum * accum.
  90. *
  91. *@li use_locking: An optional bool. Defaults to "False".\n
  92. * If "True", updating of the "var" and "accum" tensors is protected by a lock;
  93. * otherwise the behavior is undefined, but may exhibit less contention.
  94. *
  95. *@par Outputs:
  96. * var: A mutable tensor. Has the same type as input "var".
  97. *
  98. */
  99. REG_OP(ApplyMomentum)
  100. .INPUT(var, TensorType::NumberType())
  101. .INPUT(accum, TensorType::NumberType())
  102. .INPUT(lr, TensorType::NumberType())
  103. .INPUT(grad, TensorType::NumberType())
  104. .INPUT(momentum, TensorType::NumberType())
  105. .OUTPUT(var, TensorType::NumberType())
  106. .ATTR(use_nesterov, Bool, false)
  107. .ATTR(use_locking, Bool, false)
  108. .OP_END_FACTORY_REG(ApplyMomentum)
  109. /**
  110. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme.
  111. *@par Inputs:
  112. * Five inputs, including:
  113. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  114. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  115. *@li lr: An NCHW, NHWC, or ND Tensor of type float32.
  116. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  117. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  118. *@par Attributes:
  119. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  120. *@li update_slots: NOTE(review): previously documented as an optional bool defaulting to "True", but no such attribute is registered below; confirm and either register it or drop this entry.
  121. *@par Outputs:
  122. *var: A Tensor. Has the same type and format as input "var".
  123. */
  124. REG_OP(SparseApplyAdagrad)
  125. .INPUT(var, TensorType({DT_FLOAT}))
  126. .INPUT(accum, TensorType({DT_FLOAT}))
  127. .INPUT(lr, TensorType({DT_FLOAT}))
  128. .INPUT(grad, TensorType({DT_FLOAT}))
  129. .INPUT(indices, TensorType({DT_INT32}))
  130. .OUTPUT(var, TensorType({DT_FLOAT}))
  131. .ATTR(use_locking, Bool, false)
  132. .OP_END_FACTORY_REG(SparseApplyAdagrad)
  133. /**
  134. *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme; same as SparseApplyAdagrad except that "lr" is a required float attribute rather than an input.
  135. *@par Inputs:
  136. * Four inputs, including:
  137. *@li var: An NCHW, NHWC, or ND Tensor of type float32.
  138. *@li accum: An NCHW, NHWC, or ND Tensor of type float32.
  139. *@li grad: An NCHW, NHWC, or ND Tensor of type float32.
  140. *@li indices: An NCHW, NHWC, or ND Tensor of type int32.
  141. *@par Attributes:
  142. *@li lr: A required float. The learning rate used for the update.
  143. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock.
  144. *@li update_slots: NOTE(review): previously documented as an optional bool defaulting to "True", but no such attribute is registered below; confirm and either register it or drop this entry.
  145. *@par Outputs:
  146. *var: A Tensor. Has the same type and format as input "var".
  147. */
  148. REG_OP(SparseApplyAdagradD)
  149. .INPUT(var, TensorType({DT_FLOAT}))
  150. .INPUT(accum, TensorType({DT_FLOAT}))
  151. .INPUT(grad, TensorType({DT_FLOAT}))
  152. .INPUT(indices, TensorType({DT_INT32}))
  153. .OUTPUT(var, TensorType({DT_FLOAT}))
  154. .REQUIRED_ATTR(lr, Float)
  155. .ATTR(use_locking, Bool, false)
  156. .OP_END_FACTORY_REG(SparseApplyAdagradD)
  157. REG_OP(ApplyMomentumCCE)  // Registered with the same inputs, output and attributes (and defaults) as ApplyMomentum.
  158. .INPUT(var, TensorType::NumberType())  // NOTE(review): undocumented op; presumably a CCE-kernel variant of ApplyMomentum -- confirm semantics.
  159. .INPUT(accum, TensorType::NumberType())
  160. .INPUT(lr, TensorType::NumberType())
  161. .INPUT(grad, TensorType::NumberType())
  162. .INPUT(momentum, TensorType::NumberType())
  163. .OUTPUT(var, TensorType::NumberType())
  164. .ATTR(use_nesterov, Bool, false)
  165. .ATTR(use_locking, Bool, false)
  166. .OP_END_FACTORY_REG(ApplyMomentumCCE)
  167. /**
  168. *@brief Updates "var" according to the PowerSign update.\n
  169. * t-1 means the previous period.
  170. * m_t <- beta * m_{t-1} + (1 - beta) * grad\n
  171. * update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n
  172. * var <- var - lr * update
  173. *
  174. *@attention Constraints:\n
  175. * "var", "m" and "grad" must have the same shape; the remaining inputs are scalars.
  176. *
  177. *@par Inputs:
  178. *@li var: A mutable tensor. Should be from a Variable().
  179. *@li m: A mutable tensor. Has the same type as "var".
  180. * Should be from a Variable().
  181. *@li lr: A scalar. Has the same type as "var".
  182. *@li logbase: A scalar. Has the same type as "var".
  183. *@li sign_decay: A scalar. Has the same type as "var".
  184. *@li beta: A scalar. Has the same type as "var".
  185. *@li grad: A tensor for the gradient. Has the same type as "var".
  186. *
  187. *@par Attributes:
  188. * use_locking: An optional bool. Defaults to "False".
  189. * If "True", updating of the "var" and "m" tensors is protected
  190. * by a lock; otherwise the behavior is undefined, but may exhibit less
  191. * contention.
  192. *
  193. *@par Outputs:
  194. * var: A mutable tensor. Has the same type as input "var".
  195. *
  196. */
  197. REG_OP(ApplyPowerSign)
  198. .INPUT(var, TensorType::NumberType())
  199. .INPUT(m, TensorType::NumberType())
  200. .INPUT(lr, TensorType::NumberType())
  201. .INPUT(logbase, TensorType::NumberType())
  202. .INPUT(sign_decay, TensorType::NumberType())
  203. .INPUT(beta, TensorType::NumberType())
  204. .INPUT(grad, TensorType::NumberType())
  205. .OUTPUT(var, TensorType::NumberType())
  206. .ATTR(use_locking, Bool, false)
  207. .OP_END_FACTORY_REG(ApplyPowerSign)
  208. /**
  209. *@brief Updates "var" according to the FOBOS algorithm with a fixed learning rate.\n
  210. * prox_v = var - alpha * delta\n
  211. * var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
  212. *
  213. *@attention Constraints:\n
  214. * "var" and "delta" must have the same shape; "alpha", "l1" and "l2" are scalars.
  215. *
  216. *@par Inputs:
  217. *@li var: A mutable tensor. Should be from a Variable().
  218. *@li alpha: A scalar. Has the same type as "var".
  219. *@li l1: A scalar. Has the same type as "var".
  220. *@li l2: A scalar. Has the same type as "var".
  221. *@li delta: A tensor. Has the same type as "var". The change.
  222. *
  223. *@par Attributes:
  224. * use_locking: An optional bool. Defaults to "False".
  225. * If "True", updating of the "var" tensor is protected
  226. * by a lock; otherwise the behavior is undefined, but may exhibit less
  227. * contention.
  228. *
  229. *@par Outputs:
  230. * var: A mutable tensor. Has the same type as input "var".
  231. *
  232. */
  233. REG_OP(ApplyProximalGradientDescent)
  234. .INPUT(var, TensorType::NumberType())
  235. .INPUT(alpha, TensorType::NumberType())
  236. .INPUT(l1, TensorType::NumberType())
  237. .INPUT(l2, TensorType::NumberType())
  238. .INPUT(delta, TensorType::NumberType())
  239. .OUTPUT(var, TensorType::NumberType())
  240. .ATTR(use_locking, Bool, false)
  241. .OP_END_FACTORY_REG(ApplyProximalGradientDescent)
  242. /**
  243. *@brief Updates "var" (and the mutable accumulator "m") according to the AddSign update.
  244. *@par Inputs:
  245. *Seven inputs, including:
  246. * @li var: A mutable Tensor of type TensorType::NumberType().
  247. * Should be a Variable Tensor.
  248. * @li m: A mutable Tensor of the same type as "var".
  249. * Should be a Variable Tensor.
  250. * @li lr: A Tensor of the same type as "var", the learning rate (scaling factor). Must be a scalar.
  251. * @li alpha: A Tensor of the same type as "var". Must be a scalar.
  252. * @li sign_decay: A Tensor of the same type as "var". Must be a scalar.
  253. * @li beta: A Tensor of the same type as "var". Must be a scalar.
  254. * @li grad: A Tensor of the same type as "var", for the gradient.
  255. *@par Attributes:
  256. *use_locking: An optional bool. Defaults to "False".
  257. * If "True", updating of the "var" and "m" tensors will be
  258. * protected by a lock; otherwise the behavior is undefined,
  259. * but may exhibit less contention.
  260. *@par Outputs:
  261. *var: A mutable Tensor. Has the same type as "var".
  262. */
  263. REG_OP(ApplyAddSign)
  264. .INPUT(var, TensorType::NumberType())
  265. .INPUT(m, TensorType::NumberType())
  266. .INPUT(lr, TensorType::NumberType())
  267. .INPUT(alpha, TensorType::NumberType())
  268. .INPUT(sign_decay, TensorType::NumberType())
  269. .INPUT(beta, TensorType::NumberType())
  270. .INPUT(grad, TensorType::NumberType())
  271. .OUTPUT(var, TensorType::NumberType())
  272. .ATTR(use_locking, Bool, false)
  273. .OP_END_FACTORY_REG(ApplyAddSign)
  274. /**
  275. *@brief Updates "var" according to the centered RMSProp algorithm.\n
  276. * The centered RMSProp algorithm uses an estimate of the centered second moment
  277. * (i.e., the variance) for normalization, as opposed to regular RMSProp, which
  278. * uses the (uncentered) second moment. This often helps with training, but is
  279. * slightly more expensive in terms of computation and memory.
  280. *
  281. * t-1 means the previous period.
  282. * mg <- rho * mg_{t-1} + (1-rho) * grad\n
  283. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  284. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n
  285. * var <- var - mom\n
  286. *
  287. *@attention Constraints:\n
  288. *@li in dense implementation of this algorithm, mg, ms, and mom will
  289. * update even if the grad is zero, but in this sparse implementation, mg, ms,
  290. * and mom will not update in iterations during which the grad is zero.
  291. *@li "var", "mg", "ms", "mom" and "grad" must have the same shape.
  292. *
  293. *@par Inputs:
  294. *@li var: A mutable tensor. Should be from a Variable().
  295. *@li mg: A mutable tensor. Has the same type as "var".
  296. * Should be from a Variable().
  297. *@li ms: A mutable tensor. Has the same type as "var".
  298. * Should be from a Variable().
  299. *@li mom: A mutable tensor. Has the same type as "var".
  300. * Should be from a Variable().
  301. *@li lr: A scalar. Has the same type as "var".
  302. *@li rho: A scalar. Has the same type as "var".
  303. *@li momentum: A tensor. Has the same type as "var".
  304. *@li epsilon: A scalar. Has the same type as "var".
  305. *@li grad: A tensor for the gradient. Has the same type as "var".
  306. *
  307. *@par Attributes:
  308. * use_locking: An optional bool. Defaults to "False".
  309. * If "True", updating of the "var", "mg", "ms", and "mom" tensors is protected
  310. * by a lock; otherwise the behavior is undefined, but may exhibit less
  311. * contention.
  312. *
  313. *@par Outputs:
  314. * var: A mutable tensor. Has the same type as input "var".
  315. *
  316. */
  317. REG_OP(ApplyCenteredRMSProp)
  318. .INPUT(var, TensorType::NumberType())
  319. .INPUT(mg, TensorType::NumberType())
  320. .INPUT(ms, TensorType::NumberType())
  321. .INPUT(mom, TensorType::NumberType())
  322. .INPUT(lr, TensorType::NumberType())
  323. .INPUT(rho, TensorType::NumberType())
  324. .INPUT(momentum, TensorType::NumberType())
  325. .INPUT(epsilon, TensorType::NumberType())
  326. .INPUT(grad, TensorType::NumberType())
  327. .OUTPUT(var, TensorType::NumberType())
  328. .ATTR(use_locking, Bool, false)
  329. .OP_END_FACTORY_REG(ApplyCenteredRMSProp)
  330. /**
  331. *@brief Updates "var" by subtracting "alpha" * "delta" from it.\n
  332. * var -= delta * alpha
  333. *
  334. *@attention Constraints:\n
  335. * "var" and "delta" must have the same shape; "alpha" is a scalar.
  336. *
  337. *@par Inputs:
  338. *@li var: A mutable tensor. Should be from a Variable().
  339. *@li alpha: A scalar. Has the same type as "var".
  340. *@li delta: A tensor for the change. Has the same type as "var".
  341. *
  342. *@par Attributes:
  343. * use_locking: An optional bool. Defaults to "False".
  344. * If "True", updating of the "var" tensor is protected
  345. * by a lock; otherwise the behavior is undefined, but may exhibit less
  346. * contention.
  347. *
  348. *@par Outputs:
  349. * var: A mutable tensor. Has the same type as input "var".
  350. *
  351. */
  352. REG_OP(ApplyGradientDescent)
  353. .INPUT(var, TensorType::NumberType())
  354. .INPUT(alpha, TensorType::NumberType())
  355. .INPUT(delta, TensorType::NumberType())
  356. .OUTPUT(var, TensorType::NumberType())
  357. .ATTR(use_locking, Bool, false)
  358. .OP_END_FACTORY_REG(ApplyGradientDescent)
  359. /**
  360. *@brief Updates "var" according to the adagrad scheme.\n
  361. * accum += grad * grad\n
  362. * var -= lr * grad * (1 / sqrt(accum))
  363. *
  364. *@attention Constraints:\n
  365. * "var", "accum" and "grad" must have the same shape; "lr" is a scalar.
  366. *
  367. *@par Inputs:
  368. *@li var: A mutable tensor. Should be from a Variable().
  369. *@li accum: A mutable tensor. Has the same type as "var".
  370. * Should be from a Variable().
  371. *@li lr: A scalar. Has the same type as "var".
  372. *@li grad: A tensor for the gradient. Has the same type as "var".
  373. *
  374. *@par Attributes:
  375. *@li update_slots: An optional bool. Defaults to "True". NOTE(review): presumably gates the "accum" update shown above; confirm.
  376. *@li use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and
  377. * "accum" tensors is protected by a lock; otherwise the behavior is undefined, but may exhibit less
  378. * contention.
  379. *
  380. *@par Outputs:
  381. * var: A mutable tensor. Has the same type as input "var".
  382. *
  383. */
  384. REG_OP(ApplyAdagrad)
  385. .INPUT(var, TensorType::NumberType())
  386. .INPUT(accum, TensorType::NumberType())
  387. .INPUT(lr, TensorType::NumberType())
  388. .INPUT(grad, TensorType::NumberType())
  389. .OUTPUT(var, TensorType::NumberType())
  390. .ATTR(update_slots, Bool, true)
  391. .ATTR(use_locking, Bool, false)
  392. .OP_END_FACTORY_REG(ApplyAdagrad)
  393. /**
  394. *@brief Updates "var" according to the proximal adagrad scheme.
  395. *@par Inputs:
  396. *Eight inputs, including:
  397. * @li var: A mutable Tensor. Must be one of the following types:
  398. * TensorType::NumberType(). Should be a Variable Tensor.
  399. * @li gradient_accumulator: A mutable Tensor. Must have the same
  400. * type as "var". Should be a Variable Tensor.
  401. * @li gradient_squared_accumulator: A mutable Tensor of the same type as "var".
  402. * Should be a Variable Tensor.
  403. * @li grad: A Tensor of the same type as "var", for the gradient.
  404. * @li lr: A Tensor of the same type as "var".
  405. * Scaling factor. Must be a scalar.
  406. * @li l1: A Tensor of the same type as "var".
  407. * L1 regularization. Must be a scalar.
  408. * @li l2: A Tensor of the same type as "var".
  409. * L2 regularization. Must be a scalar.
  410. * @li global_step: A Tensor of type int32 or int64.
  411. * Training step number. Must be a scalar.
  412. *@par Attributes:
  413. *use_locking: An optional bool. Defaults to "False".
  414. * If "True", updating of the "var", "gradient_accumulator" and "gradient_squared_accumulator" tensors will be
  415. * protected by a lock; otherwise the behavior is undefined,
  416. * but may exhibit less contention.
  417. *@par Outputs:
  418. *var: A mutable Tensor. Has the same type as "var".
  419. */
  420. REG_OP(ApplyAdagradDA)
  421. .INPUT(var, TensorType::NumberType())
  422. .INPUT(gradient_accumulator, TensorType::NumberType())
  423. .INPUT(gradient_squared_accumulator, TensorType::NumberType())
  424. .INPUT(grad, TensorType::NumberType())
  425. .INPUT(lr, TensorType::NumberType())
  426. .INPUT(l1, TensorType::NumberType())
  427. .INPUT(l2, TensorType::NumberType())
  428. .INPUT(global_step, TensorType({DT_INT32, DT_INT64}))
  429. .OUTPUT(var, TensorType::NumberType())
  430. .ATTR(use_locking, Bool, false)
  431. .OP_END_FACTORY_REG(ApplyAdagradDA)
  432. /**
  433. *@brief Returns the dimension index in the destination data format given the one in
  434. * the source data format.
  435. *
  436. *@par Inputs:
  437. * x: A tensor of type int32 or int64.
  438. * Each element is a dimension index in the source data format.
  439. * Must be in the range [-4, 4).
  440. *
  441. *@par Attributes:
  442. *@li src_format: An optional string. Defaults to "NHWC".
  443. * The source data format.
  444. *@li dst_format: An optional string. Defaults to "NCHW".
  445. * The destination data format.
  446. *
  447. *@par Outputs:
  448. * y: A tensor. Has the same type as "x".
  449. *
  450. */
  451. REG_OP(DataFormatDimMap)
  452. .INPUT(x, TensorType::IndexNumberType())
  453. .ATTR(src_format, String, "NHWC")
  454. .ATTR(dst_format, String, "NCHW")
  455. .OUTPUT(y, TensorType::IndexNumberType())
  456. .OP_END_FACTORY_REG(DataFormatDimMap)
  457. /**
  458. * @brief Implements stochastic gradient descent (optionally with momentum).\n
  459. * Nesterov momentum is based on the formula from
  460. * On the importance of initialization and momentum in deep learning.\n
  461. * @par Inputs:
  462. * @li parameters: A mutable tensor of type float16 or float32.\n
  463. * Specifies the iterable of parameters to optimize or dicts defining parameter
  464. * groups.
  465. * @li gradient: A tensor of type float16 or float32.\n
  466. * Specifies the gradient of training step.
  467. * @li learning_rate: A tensor of type float16 or float32.\n
  468. * Specifies the learing_rate of training step.
  469. * @li accum: A tensor of type float16 or float32.
  470. * Specifies the velocity of training step.
  471. * @li momentum: A tensor of type float16 or float32.
  472. * Specifies the momentum factor.
  473. * @li stat: A tensor of type float16 or float32.
  474. * Specifies the status representing the first step or not.
  475. * @par Attributes:
  476. * @li dampening: An optional float, specifying the dampening for momentum.
  477. * Defaults to "0.0".
  478. * @li weight_decay: An optional float, specifying the L2 penalty. Defaults to
  479. * "0.0".
  480. * @li nesterov: An optional bool, specifying whether to enable Nesterov
  481. * momentum. Defaults to "False".
  482. * @par Outputs:
  483. * parameters: A mutable tensor same as input "parameters".
  484. * @see ApplyMomentum()
  485. */
  486. REG_OP(SGD)
  487. .INPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  488. .INPUT(gradient, TensorType(DT_FLOAT, DT_FLOAT16))
  489. .INPUT(learning_rate, TensorType(DT_FLOAT, DT_FLOAT16))
  490. .INPUT(accum, TensorType(DT_FLOAT, DT_FLOAT16))
  491. .INPUT(momentum, TensorType(DT_FLOAT, DT_FLOAT16))
  492. .INPUT(stat, TensorType(DT_FLOAT, DT_FLOAT16))
  493. .OUTPUT(parameters, TensorType(DT_FLOAT, DT_FLOAT16))
  494. .ATTR(dampening, Float, 0.0)
  495. .ATTR(weight_decay, Float, 0.0)
  496. .ATTR(nesterov, Bool, false)
  497. .OP_END_FACTORY_REG(SGD)
  498. /**
  499. * @brief Updates "var" according to the RMSProp algorithm.\n
  500. * mean_square (ms) = decay (rho) * mean_square + (1-decay) * gradient ** 2\n
  501. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  502. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  503. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  504. * var <- var - mom\n
  505. *
  506. * @par Inputs:
  507. * @li var: A mutable tensor. Must be one of the data types defined in\n
  508. * TensorType::NumberType(). Should be from a Variable().
  509. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  510. * Variable().
  511. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  512. * Variable().
  513. * @li lr: A scalar. Must have the same type as "var".
  514. * @li rho: A scalar. Must have the same type as "var".
  515. * @li momentum: A scalar. Must have the same type as "var".
  516. * @li epsilon: A scalar. Must have the same type as "var".
  517. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  518. *
  519. * @par Attributes:
  520. * use_locking: An optional "bool". Defaults to "False". If "True", updating of\n
  521. * the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the\n
  522. * behavior is undefined, but may exhibit less contention.
  523. *
  524. * @par Outputs:
  525. * var: A mutable tensor. Has the same type as input "var".
  526. *
  527. * @attention Constraints:
  528. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
  529. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
  530. * will not update in iterations during which "grad" is 0.\n
  531. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  532. */
  533. REG_OP(ApplyRMSProp)
  534. .INPUT(var, TensorType::NumberType())
  535. .INPUT(ms, TensorType::NumberType())
  536. .INPUT(mom, TensorType::NumberType())
  537. .INPUT(lr, TensorType::NumberType())
  538. .INPUT(rho, TensorType::NumberType())
  539. .INPUT(momentum, TensorType::NumberType())
  540. .INPUT(epsilon, TensorType::NumberType())
  541. .INPUT(grad, TensorType::NumberType())
  542. .OUTPUT(var, TensorType::NumberType())
  543. .ATTR(use_locking, Bool, false)
  544. .OP_END_FACTORY_REG(ApplyRMSProp)
  545. /**
  546. * @brief Updates "var" according to the RMSProp algorithm. Unlike ApplyRMSProp, "rho",
  547. * "momentum" and "epsilon" are passed as float attributes rather than as inputs.\n
  548. * mean_square = decay * mean_square + (1-decay) * gradient ** 2\n
  549. * Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n
  550. * ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n
  551. * mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n
  552. * var <- var - mom
  553. *
  554. * @par Inputs:
  555. * @li var: A mutable tensor. Must be one of the data types defined in\n
  556. * TensorType::NumberType(). Should be from a Variable().
  557. * @li ms: A mutable tensor. Must have the same type as "var". Should be from a
  558. * Variable().
  559. * @li mom: A mutable tensor. Must have the same type as "var". Should be from a
  560. * Variable().
  561. * @li lr: A scalar. Must have the same type as "var".
  562. * @li grad: A tensor, specifying the gradient. Must have the same type as "var".
  563. *
  564. * @par Attributes:
  565. * @li use_locking: An optional "bool". Defaults to "False". If "True", updating\n
  566. * of the "var", "ms", and "mom" tensors will be protected by a lock; otherwise
  567. * the behavior is undefined, but may exhibit less contention.
  568. * @li rho: A required float.
  569. * @li momentum: A required float.
  570. * @li epsilon: A required float.
  571. *
  572. * @par Outputs:
  573. * var: A mutable tensor. Must have the same type as input "var".
  574. * @attention Constraints:
  575. * @li Note that in dense implementation of this algorithm, "ms" and "mom" will\n
  576. * update even if "grad" is 0, but in this sparse implementation, "ms" and "mom"\n
  577. * will not update in iterations during which "grad" is 0.
  578. * @li The input tensors "var", "ms", "mom" and "grad" must have the same shape.
  579. */
  580. REG_OP(ApplyRMSPropD)
  581. .INPUT(var, TensorType::NumberType())
  582. .INPUT(ms, TensorType::NumberType())
  583. .INPUT(mom, TensorType::NumberType())
  584. .INPUT(lr, TensorType::NumberType())
  585. .INPUT(grad, TensorType::NumberType())
  586. .OUTPUT(var, TensorType::NumberType())
  587. .REQUIRED_ATTR(rho, Float)
  588. .REQUIRED_ATTR(momentum, Float)
  589. .REQUIRED_ATTR(epsilon, Float)
  590. .ATTR(use_locking, Bool, false)
  591. .OP_END_FACTORY_REG(ApplyRMSPropD)
  592. /**
  593. *@brief Updates "var" and "accum" according to FOBOS with Adagrad learning rate.
  594. *@par Inputs:
  595. *Six inputs, including:
  596. * @li var: A mutable Tensor of type TensorType::NumberType().
  597. * Should be from a Variable().
  598. * @li accum: A mutable Tensor of the same type as "var". Should be from a Variable().
  599. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  600. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  601. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  602. * @li grad: A Tensor of the same type as "var", for the gradient.
  603. *@par Attributes:
  604. *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  605. *@par Outputs:
  606. *var: A mutable Tensor. Has the same type as "var".
  607. */
  608. REG_OP(ApplyProximalAdagrad)
  609. .INPUT(var, TensorType::NumberType())
  610. .INPUT(accum, TensorType::NumberType())
  611. .INPUT(lr, TensorType::NumberType())
  612. .INPUT(l1, TensorType::NumberType())
  613. .INPUT(l2, TensorType::NumberType())
  614. .INPUT(grad, TensorType::NumberType())
  615. .OUTPUT(var, TensorType::NumberType())
  616. .ATTR(use_locking, Bool, false)
  617. .OP_END_FACTORY_REG(ApplyProximalAdagrad)
  618. /**
  619. *@brief Updates entries in "var" and "accum" according to the Proximal Adagrad algorithm.\n
  620. * Compared with op ApplyProximalAdagrad, an additional index tensor is input:
  621. * only the entries of "var" and "accum" selected by "indices" along the first dimension are updated.
  622. *@par Inputs:
  623. * Seven inputs, including:\n
  624. * @li var: A mutable Tensor of type TensorType::NumberType().\n
  625. * Should be a Variable Tensor.
  626. * @li accum: A mutable Tensor of the same type as "var".\n
  627. * Should be a Variable Tensor.
  628. * @li lr: A Tensor of the same type as "var".\n
  629. * Scaling factor. Must be a scalar.
  630. * @li l1: A Tensor of the same type as "var".\n
  631. * L1 regularization. Must be a scalar.
  632. * @li l2: A Tensor of the same type as "var".\n
  633. * L2 regularization. Must be a scalar.
  634. * @li grad: A Tensor. Has the same type as "var".\n
  635. * The gradient.
  636. * @li indices: A vector of indices into the first dimension of "var" and "accum".\n
  637. * Must be of type TensorType::IndexNumberType().
  638. *@par Attributes:
  639. *use_locking: An optional bool. Defaults to "False".\n
  640. * If "True", updating of the "var" and "accum" tensors will be protected by a lock;\n
  641. * if "False", the behavior is undefined, but may exhibit less contention.
  642. *@par Outputs:
  643. *var: A mutable Tensor. Has the same type as "var".
  644. */
  645. REG_OP(SparseApplyProximalAdagrad)
  646. .INPUT(var, TensorType::NumberType())
  647. .INPUT(accum, TensorType::NumberType())
  648. .INPUT(lr, TensorType::NumberType())
  649. .INPUT(l1, TensorType::NumberType())
  650. .INPUT(l2, TensorType::NumberType())
  651. .INPUT(grad, TensorType::NumberType())
  652. .INPUT(indices, TensorType::IndexNumberType())
  653. .OUTPUT(var, TensorType::NumberType())
  654. .ATTR(use_locking, Bool, false)
  655. .OP_END_FACTORY_REG(SparseApplyProximalAdagrad)
  656. /**
  657. *@brief Updates "var" according to the Ftrl-proximal scheme.
  658. *@par Inputs:
  659. *Eight inputs, including:
  660. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  661. * Should be a Variable Tensor.
  662. * @li accum: A mutable Tensor of the same type as "var".
  663. * Should be a Variable Tensor.
  664. * @li linear: A mutable Tensor of the same type as "var".
  665. * Should be a Variable Tensor.
  666. * @li grad: A Tensor of the same type as "var", for the gradient.
  667. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  668. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  669. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  670. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  671. *@par Attributes:
  672. *use_locking: An optional bool. Defaults to "False".
  673. * If "True", updating of the "var" and "accum" tensors will be
  674. * protected by a lock; otherwise the behavior is undefined,
  675. * but may exhibit less contention.
  676. *@par Outputs:
  677. *var: A mutable Tensor. Has the same type as "var".
  678. */
  679. REG_OP(ApplyFtrl)
  680. .INPUT(var, TensorType::NumberType())
  681. .INPUT(accum, TensorType::NumberType())
  682. .INPUT(linear, TensorType::NumberType())
  683. .INPUT(grad, TensorType::NumberType())
  684. .INPUT(lr, TensorType::NumberType())
  685. .INPUT(l1, TensorType::NumberType())
  686. .INPUT(l2, TensorType::NumberType())
  687. .INPUT(lr_power, TensorType::NumberType())
  688. .OUTPUT(var, TensorType::NumberType())
  689. .ATTR(use_locking, Bool, false)
  690. .OP_END_FACTORY_REG(ApplyFtrl)
  691. /**
  692. *@brief Updates "var" according to the Ftrl-proximal scheme.
  693. *@par Inputs:
  694. *Nine inputs, including:
  695. * @li var: A mutable Tensor. Must be of type TensorType::NumberType().
  696. * Should be a Variable Tensor.
  697. * @li accum: A mutable Tensor of the same type as "var".
  698. * Should be a Variable Tensor.
  699. * @li linear: A mutable Tensor of the same type as "var".
  700. * Should be a Variable Tensor.
  701. * @li grad: A Tensor of the same type as "var", for the gradient.
  702. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  703. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  704. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  705. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
  706. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  707. *@par Attributes:
  708. *use_locking: An optional bool. Defaults to "False".
  709. * If "True", updating of the "var" and "accum" tensors will be
  710. * protected by a lock; otherwise the behavior is undefined,
  711. * but may exhibit less contention.
  712. *@par Outputs:
  713. *var: A mutable Tensor. Has the same type as "var".
  714. */
  715. REG_OP(ApplyFtrlV2)
  716. .INPUT(var, TensorType::NumberType())
  717. .INPUT(accum, TensorType::NumberType())
  718. .INPUT(linear, TensorType::NumberType())
  719. .INPUT(grad, TensorType::NumberType())
  720. .INPUT(lr, TensorType::NumberType())
  721. .INPUT(l1, TensorType::NumberType())
  722. .INPUT(l2, TensorType::NumberType())
  723. .INPUT(l2_shrinkage, TensorType::NumberType())
  724. .INPUT(lr_power, TensorType::NumberType())
  725. .OUTPUT(var, TensorType::NumberType())
  726. .ATTR(use_locking, Bool, false)
  727. .OP_END_FACTORY_REG(ApplyFtrlV2)
  728. /**
  729. *@brief Updates "var" according to the Adam algorithm.\n
  730. * lr_t <- learning_rate * sqrt(1 - beta_2^t) / (1 - beta_1^t)\n
  731. * m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g\n
  732. * v_t <- beta_2 * v_{t-1} + (1 - beta_2) * g * g\n
  733. * variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
  734. *
  735. *@attention Constraints:\n
  736. * *The input tensors "var", "m", "v" and "grad" must have the same shape.*
  737. *
  738. *@par Inputs:
  739. *@li var: A mutable Tensor of the type TensorType::NumberType().
  740. * Should be from a Variable().
  741. *@li m: A mutable Tensor of the same type as "var".
  742. * Should be from a Variable().
  743. *@li v: A mutable Tensor of the same type as "var".
  744. * Should be from a Variable().
  745. *@li beta1_power: A scalar of the same type as "var".
  746. *@li beta2_power: A scalar of the same type as "var".
  747. *@li lr: learning_rate. A scalar of the same type as "var".
  748. *@li beta1: A scalar of the same type as "var".
  749. *@li beta2: A scalar of the same type as "var".
  750. *@li epsilon: A scalar of the same type as "var".
  751. *@li grad: A Tensor of the same type as "var", for the gradient.
  752. *
  753. *@par Attributes:\n
  754. *@li use_locking: An optional bool. Defaults to "False".
  755. * If "True", updating of the "var", "m", and "v" tensors will be protected
  756. * by a lock; otherwise the behavior is undefined, but may exhibit less
  757. * contention.
  758. *@li use_nesterov: An optional bool. Defaults to "False".
  759. * If "True", uses the Nesterov update.
  760. *@par Outputs:
  761. *@li var: A mutable Tensor. Has the same type as input "var".
  762. *@li m, v: Mutable Tensors holding the updated "m" and "v". Have the same type as input "var".
  763. */
  764. REG_OP(ApplyAdam)
  765. .INPUT(var, TensorType::NumberType())
  766. .INPUT(m, TensorType::NumberType())
  767. .INPUT(v, TensorType::NumberType())
  768. .INPUT(beta1_power, TensorType::NumberType())
  769. .INPUT(beta2_power, TensorType::NumberType())
  770. .INPUT(lr, TensorType::NumberType())
  771. .INPUT(beta1, TensorType::NumberType())
  772. .INPUT(beta2, TensorType::NumberType())
  773. .INPUT(epsilon, TensorType::NumberType())
  774. .INPUT(grad, TensorType::NumberType())
  775. .OUTPUT(var, TensorType::NumberType())
  776. .OUTPUT(m, TensorType::NumberType())
  777. .OUTPUT(v, TensorType::NumberType())
  778. .ATTR(use_locking, Bool, false)
  779. .ATTR(use_nesterov, Bool, false)
  780. .OP_END_FACTORY_REG(ApplyAdam)
  781. /**
  782. *@brief Updates "var" according to the Adadelta scheme.
  783. *@par Inputs:
  784. *Seven inputs, including:
  785. * @li var: A mutable Tensor of type TensorType::NumberType().
  786. * Should be a Variable Tensor.
  787. * @li accum: A mutable Tensor of the same type as "var".
  788. * Should be a Variable Tensor.
  789. * @li accum_update: A mutable Tensor of the same type as "var".
  790. * Should be a Variable Tensor.
  791. * @li lr: A scalar of the same type as "var", for the scaling factor.
  792. * @li rho: A scalar of the same type as "var", for the decay factor.
  793. * @li epsilon: A scalar of the same type as "var", for the constant factor.
  794. * @li grad: A Tensor of the same type as "var", for the gradient.
  795. *@par Attributes:
  796. *use_locking: An optional bool. Defaults to "False".
  797. * If "True", updating of the "var", "accum" and "accum_update" tensors will be
  798. * protected by a lock; otherwise the behavior is undefined,
  799. * but may exhibit less contention.
  800. *@par Outputs:
  801. *var: A mutable Tensor. Has the same type as "var".
  802. */
  803. REG_OP(ApplyAdadelta)
  804. .INPUT(var, TensorType::NumberType())
  805. .INPUT(accum, TensorType::NumberType())
  806. .INPUT(accum_update, TensorType::NumberType())
  807. .INPUT(lr, TensorType::NumberType())
  808. .INPUT(rho, TensorType::NumberType())
  809. .INPUT(epsilon, TensorType::NumberType())
  810. .INPUT(grad, TensorType::NumberType())
  811. .OUTPUT(var, TensorType::NumberType())
  812. .ATTR(use_locking, Bool, false)
  813. .OP_END_FACTORY_REG(ApplyAdadelta)
  814. /**
  815. *@brief Updates "var" according to the ApplyMomentum algorithm. \n
  816. * accum = accum * momentum + x1 * x2
  817. * if use_nesterov is True:
  818. * var -= x1 * x2 * lr + accum * momentum * lr
  819. * else:
  820. * var -= accum * lr
  821. *@par Inputs:
  822. *Six inputs, including:
  823. * @li var: A mutable Tensor of type TensorType::NumberType().
  824. * Should be a Variable Tensor.
  825. * @li accum: A mutable Tensor of the same type as "var".
  826. * Should be a Variable Tensor.
  827. * @li lr: A scalar of the same type as "var", for the scaling factor.
  828. * @li x1: A Tensor of type TensorType::NumberType(). Multiplied by "x2" to form the gradient term.
  829. * @li momentum: A scalar of the same type as "var".
  830. * @li x2: A Tensor of the same type as "var".
  831. *@par Attributes:
  832. *Two Attributes, including:
  833. *@li use_nesterov: An optional bool. Defaults to "False". \n
  834. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  835. * so in the end, the var you get is actually var - lr * momentum * accum.
  836. *@li use_locking: An optional bool. Defaults to "False". \n
  837. * If "True", updating of the "var" and "accum" tensors will be protected \n
  838. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  839. *@par Outputs:
  840. *var: A mutable Tensor. Has the same type as "var".
  841. */
  842. REG_OP(FusedMulApplyMomentum)
  843. .INPUT(var, TensorType::NumberType())
  844. .INPUT(accum, TensorType::NumberType())
  845. .INPUT(lr, TensorType::NumberType())
  846. .INPUT(x1, TensorType::NumberType())
  847. .INPUT(momentum, TensorType::NumberType())
  848. .INPUT(x2, TensorType::NumberType())
  849. .OUTPUT(var, TensorType::NumberType())
  850. .ATTR(use_nesterov, Bool, false)
  851. .ATTR(use_locking, Bool, false)
  852. .OP_END_FACTORY_REG(FusedMulApplyMomentum)
  853. /**
  854. *@brief Updates "var" according to the ApplyMomentum algorithm. \n
  855. * accum = accum * momentum + x1 * x2
  856. * if use_nesterov is True:
  857. * var -= x1 * x2 * lr + accum * momentum * lr
  858. * else:
  859. * var -= accum * lr
  860. *@par Inputs:
  861. *Seven inputs, including:
  862. * @li var: A mutable Tensor of type TensorType::NumberType(). Should be a Variable Tensor.
  863. * @li accum: A mutable Tensor of the same type as "var".
  864. * Should be a Variable Tensor.
  865. * @li lr: A scalar of the same type as "var", for the scaling factor.
  866. * @li x1: A Tensor of type TensorType::NumberType(). Multiplied by "x2" to form the gradient term.
  867. * @li momentum: A scalar of the same type as "var".
  868. * @li x2: A Tensor of the same type as "var".
  869. * @li var_copy: A Tensor of type TensorType::NumberType(). Also returned as output "var_copy".
  870. *@par Attributes:
  871. *Two Attributes, including:
  872. *@li use_nesterov: An optional bool. Defaults to "False". \n
  873. * If True, the tensor passed to compute grad will be var - lr * momentum * accum, \n
  874. * so in the end, the var you get is actually var - lr * momentum * accum.
  875. *@li use_locking: An optional bool. Defaults to "False". \n
  876. * If "True", updating of the "var" and "accum" tensors will be protected \n
  877. * by a lock; otherwise the behavior is undefined, but may exhibit less contention.
  878. *@par Outputs:
  879. *Two outputs, including:
  880. *@li var: A Tensor. Has the same type as "var".
  881. *@li var_copy: A Tensor. Has the same type as "var".
  882. */
  883. REG_OP(FusedMulApplyMomentumExtern)
  884. .INPUT(var, TensorType::NumberType())
  885. .INPUT(accum, TensorType::NumberType())
  886. .INPUT(lr, TensorType::NumberType())
  887. .INPUT(x1, TensorType::NumberType())
  888. .INPUT(momentum, TensorType::NumberType())
  889. .INPUT(x2, TensorType::NumberType())
  890. .INPUT(var_copy, TensorType::NumberType())
  891. .OUTPUT(var, TensorType::NumberType())
  892. .OUTPUT(var_copy, TensorType::NumberType())
  893. .ATTR(use_nesterov, Bool, false)
  894. .ATTR(use_locking, Bool, false)
  895. .OP_END_FACTORY_REG(FusedMulApplyMomentumExtern)
  896. /**
  897. *@brief Updates "g" according to the LARS algorithm.
  898. *@par Inputs:
  899. *Four inputs, including:
  900. * @li w: A Tensor. Must be of type float32.
  901. * @li g: A Tensor of the same type and shape as "w".
  902. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
  903. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  904. *@par Attributes:
  905. *Three Attributes, including:
  906. * @li hyperpara: An optional float. Defaults to 0.001.
  907. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  908. * @li use_clip: An optional bool. Defaults to "False".\n
  909. * If "True", the learning rate is updated (presumably clipped; confirm against the kernel).
  910. *@par Outputs:
  911. *g_new: A Tensor of the same type as "w".
  912. */
  913. REG_OP(LarsV2)
  914. .INPUT(w, TensorType(DT_FLOAT))
  915. .INPUT(g, TensorType(DT_FLOAT))
  916. .INPUT(weight_decay, TensorType(DT_FLOAT))
  917. .INPUT(learning_rate, TensorType(DT_FLOAT))
  918. .OUTPUT(g_new, TensorType(DT_FLOAT))
  919. .ATTR(hyperpara, Float, 0.001)
  920. .ATTR(epsilon, Float, 0.00001)
  921. .ATTR(use_clip, Bool, false)
  922. .OP_END_FACTORY_REG(LarsV2)
  923. /**
  924. *@brief Updates "g" according to the LARS algorithm.
  925. *@par Inputs:
  926. *Six inputs, including:
  927. * @li w: A Tensor. Must be of type float32.
  928. * @li g: A Tensor of the same type and shape as "w".
  929. * @li w_square_sum: A Tensor holding the sum of squares of "w". Has the same type as "w". Must be a scalar.
  930. * @li g_square_sum: A Tensor holding the sum of squares of "g". Has the same type as "w". Must be a scalar.
  931. * @li weight_decay: A Tensor of the same type as "w". Must be a scalar.
  932. * @li learning_rate: A Tensor of the same type as "w". Must be a scalar.
  933. *@par Attributes:
  934. *Three Attributes, including:
  935. * @li hyperpara: An optional float. Defaults to 0.001.
  936. * @li epsilon: An optional float. Defaults to 1e-5. Used to avoid a zero denominator.
  937. * @li use_clip: An optional bool. Defaults to "False".\n
  938. * If "True", the learning rate is updated (presumably clipped; confirm against the kernel).
  939. *@par Outputs:
  940. *g_new: A Tensor of the same type as "w".
  941. */
  942. REG_OP(LarsV2Update)
  943. .INPUT(w, TensorType(DT_FLOAT))
  944. .INPUT(g, TensorType(DT_FLOAT))
  945. .INPUT(w_square_sum, TensorType(DT_FLOAT))
  946. .INPUT(g_square_sum, TensorType(DT_FLOAT))
  947. .INPUT(weight_decay, TensorType(DT_FLOAT))
  948. .INPUT(learning_rate, TensorType(DT_FLOAT))
  949. .OUTPUT(g_new, TensorType(DT_FLOAT))
  950. .ATTR(hyperpara, Float, 0.001)
  951. .ATTR(epsilon, Float, 0.00001)
  952. .ATTR(use_clip, Bool, false)
  953. .OP_END_FACTORY_REG(LarsV2Update)
  954. /**
  955. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  956. * @par Inputs:
  957. * Nine inputs, including:
  958. * @li var: A mutable Tensor. Must be of type float32.
  959. * Should be a Variable Tensor.
  960. * @li accum: A mutable Tensor of the same type as "var".
  961. * Should be a Variable Tensor.
  962. * @li linear: A mutable Tensor of the same type as "var".
  963. * Should be a Variable Tensor.
  964. * @li grad: A Tensor of the same type as "var", for the gradient.
  965. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type int32.
  966. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  967. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  968. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  969. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  970. * @par Attributes:
  971. * use_locking: An optional bool. Defaults to "False".
  972. * If "True", updating of the "var" and "accum" tensors will be
  973. * protected by a lock; otherwise the behavior is undefined,
  974. * but may exhibit less contention.
  975. * @par Outputs:
  976. * var: A Tensor. Has the same type and format as input "var".
  977. */
  978. REG_OP(SparseApplyFtrl)
  979. .INPUT(var, TensorType({DT_FLOAT}))
  980. .INPUT(accum, TensorType({DT_FLOAT}))
  981. .INPUT(linear, TensorType({DT_FLOAT}))
  982. .INPUT(grad, TensorType({DT_FLOAT}))
  983. .INPUT(indices, TensorType({DT_INT32}))
  984. .INPUT(lr, TensorType({DT_FLOAT}))
  985. .INPUT(l1, TensorType({DT_FLOAT}))
  986. .INPUT(l2, TensorType({DT_FLOAT}))
  987. .INPUT(lr_power, TensorType({DT_FLOAT}))
  988. .OUTPUT(var, TensorType({DT_FLOAT}))
  989. .ATTR(use_locking, Bool, false)
  990. .OP_END_FACTORY_REG(SparseApplyFtrl)
  991. /**
  992. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  993. * @par Inputs:
  994. * Five inputs, including:
  995. * @li var: A mutable Tensor. Must be of type float32.
  996. * Should be a Variable Tensor.
  997. * @li accum: A mutable Tensor of the same type as "var".
  998. * Should be a Variable Tensor.
  999. * @li linear: A mutable Tensor of the same type as "var".
  1000. * Should be a Variable Tensor.
  1001. * @li grad: A Tensor of the same type as "var", for the gradient.
  1002. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type int32.
  1003. * @par Attributes:
  1004. * @li lr: A required float, for the scaling factor.
  1005. * @li l1: A required float, for L1 regularization.
  1006. * @li l2: A required float, for L2 regularization.
  1007. * @li lr_power: A required float, for the scaling factor.
  1008. * @li use_locking: An optional bool. Defaults to "False".
  1009. * If "True", updating of the "var" and "accum" tensors will be
  1010. * protected by a lock; otherwise the behavior is undefined,
  1011. * but may exhibit less contention.
  1012. * @par Outputs:
  1013. * var: A Tensor. Has the same type and format as input "var".
  1014. */
  1015. REG_OP(SparseApplyFtrlD)
  1016. .INPUT(var, TensorType({DT_FLOAT}))
  1017. .INPUT(accum, TensorType({DT_FLOAT}))
  1018. .INPUT(linear, TensorType({DT_FLOAT}))
  1019. .INPUT(grad, TensorType({DT_FLOAT}))
  1020. .INPUT(indices, TensorType({DT_INT32}))
  1021. .OUTPUT(var, TensorType({DT_FLOAT}))
  1022. .REQUIRED_ATTR(lr, Float)
  1023. .REQUIRED_ATTR(l1, Float)
  1024. .REQUIRED_ATTR(l2, Float)
  1025. .REQUIRED_ATTR(lr_power, Float)
  1026. .ATTR(use_locking, Bool, false)
  1027. .OP_END_FACTORY_REG(SparseApplyFtrlD)
  1028. /**
  1029. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  1030. * That is, for the rows we have "grad" for, "var", "accum" and "linear" are updated.
  1031. * @par Inputs:
  1032. * Ten inputs, including:
  1033. * @li var: A mutable Tensor. Must be of type float32.
  1034. * Should be a Variable Tensor.
  1035. * @li accum: A mutable Tensor of the same type as "var".
  1036. * Should be a Variable Tensor.
  1037. * @li linear: A mutable Tensor of the same type as "var".
  1038. * Should be a Variable Tensor.
  1039. * @li grad: A Tensor of the same type as "var", for the gradient.
  1040. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type int32.
  1041. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1042. * @li l1: A Tensor of the same type as "var", for L1 regularization. Must be a scalar.
  1043. * @li l2: A Tensor of the same type as "var", for L2 regularization. Must be a scalar.
  1044. * @li l2_shrinkage: A Tensor of the same type as "var", for L2 shrinkage regularization. Must be a scalar.
  1045. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar.
  1046. * @par Attributes:
  1047. * use_locking: An optional bool. Defaults to "False".
  1048. * If "True", updating of the "var" and "accum" tensors will be
  1049. * protected by a lock; otherwise the behavior is undefined,
  1050. * but may exhibit less contention.
  1051. * @par Outputs:
  1052. * var: A Tensor. Has the same type and format as input "var".
  1053. */
  1054. REG_OP(SparseApplyFtrlV2)
  1055. .INPUT(var, TensorType({DT_FLOAT}))
  1056. .INPUT(accum, TensorType({DT_FLOAT}))
  1057. .INPUT(linear, TensorType({DT_FLOAT}))
  1058. .INPUT(grad, TensorType({DT_FLOAT}))
  1059. .INPUT(indices, TensorType({DT_INT32}))
  1060. .INPUT(lr, TensorType({DT_FLOAT}))
  1061. .INPUT(l1, TensorType({DT_FLOAT}))
  1062. .INPUT(l2, TensorType({DT_FLOAT}))
  1063. .INPUT(l2_shrinkage, TensorType({DT_FLOAT}))
  1064. .INPUT(lr_power, TensorType({DT_FLOAT}))
  1065. .OUTPUT(var, TensorType({DT_FLOAT}))
  1066. .ATTR(use_locking, Bool, false)
  1067. .OP_END_FACTORY_REG(SparseApplyFtrlV2)
  1068. /**
  1069. * @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme.
  1070. * That is, for the rows we have "grad" for, "var", "accum" and "linear" are updated.
  1071. * @par Inputs:
  1072. * Five inputs, including:
  1073. * @li var: A mutable Tensor. Must be of type float32.
  1074. * Should be a Variable Tensor.
  1075. * @li accum: A mutable Tensor of the same type as "var".
  1076. * Should be a Variable Tensor.
  1077. * @li linear: A mutable Tensor of the same type as "var".
  1078. * Should be a Variable Tensor.
  1079. * @li grad: A Tensor of the same type as "var", for the gradient.
  1080. * @li indices: A vector of indices into the first dimension of "var" and "accum". Must be of type int32.
  1081. * @par Attributes:
  1082. * @li lr: A required float, for the scaling factor.
  1083. * @li l1: A required float, for L1 regularization.
  1084. * @li l2: A required float, for L2 regularization.
  1085. * @li l2_shrinkage: A required float, for L2 shrinkage regularization.
  1086. * @li lr_power: A required float, for the scaling factor.
  1087. * @li use_locking: An optional bool. Defaults to "False".
  1088. * If "True", updating of the "var" and "accum" tensors will be
  1089. * protected by a lock; otherwise the behavior is undefined,
  1090. * but may exhibit less contention.
  1091. * @par Outputs:
  1092. * var: A Tensor. Has the same type and format as input "var".
  1093. */
  1094. REG_OP(SparseApplyFtrlV2D)
  1095. .INPUT(var, TensorType({DT_FLOAT}))
  1096. .INPUT(accum, TensorType({DT_FLOAT}))
  1097. .INPUT(linear, TensorType({DT_FLOAT}))
  1098. .INPUT(grad, TensorType({DT_FLOAT}))
  1099. .INPUT(indices, TensorType({DT_INT32}))
  1100. .OUTPUT(var, TensorType({DT_FLOAT}))
  1101. .REQUIRED_ATTR(lr, Float)
  1102. .REQUIRED_ATTR(l1, Float)
  1103. .REQUIRED_ATTR(l2, Float)
  1104. .REQUIRED_ATTR(l2_shrinkage, Float)
  1105. .REQUIRED_ATTR(lr_power, Float)
  1106. .ATTR(use_locking, Bool, false)
  1107. .OP_END_FACTORY_REG(SparseApplyFtrlV2D)
  1108. /**
  1109. *@brief Cleans the memory of a list of workspaces.
  1110. *@par Attributes:
  1111. * @li automic_add_mem_size: An optional list of ints giving the sizes of the workspaces. Defaults to an empty list. (The "automic" spelling is the registered attribute name.)
  1112. */
  1113. REG_OP(AtomicAddrClean)
  1114. .ATTR(automic_add_mem_size, ListInt, {})
  1115. .OP_END_FACTORY_REG(AtomicAddrClean)
  1116. } // namespace ge
  1117. #endif // GE_OP_TRAINING_OPS_H

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：